From d2ed67e074aba99cacf9f64c23a6718e9e79e806 Mon Sep 17 00:00:00 2001 From: "Tom C (DLS)" <101418278+coretl@users.noreply.github.com> Date: Mon, 20 May 2024 15:49:08 +0100 Subject: [PATCH] Move timeout handling to functions wrapped with AsyncStatus (#318) * Make callbacks fire on AsyncStatus cancellation * Move timeout handling to AsyncStatus coros * Allow timeouts to be overridden * Make test clearer * Unify timeout handling between signal and motor --- src/ophyd_async/core/__init__.py | 4 + src/ophyd_async/core/async_status.py | 39 ++---- src/ophyd_async/core/signal.py | 40 +++--- src/ophyd_async/core/utils.py | 11 ++ src/ophyd_async/epics/demo/__init__.py | 56 +++++---- src/ophyd_async/epics/motion/motor.py | 64 +++++----- src/ophyd_async/sim/demo/sim_motor.py | 142 +++++++++------------- tests/core/test_async_status.py | 9 +- tests/core/test_device.py | 4 +- tests/core/test_device_collector.py | 2 + tests/core/test_watchable_async_status.py | 8 +- tests/epics/motion/test_motor.py | 28 +++++ tests/sim/demo/test_sim_motor.py | 39 +----- 13 files changed, 211 insertions(+), 235 deletions(-) diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index f49e127df2..bc12836223 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -57,6 +57,8 @@ from .standard_readable import ConfigSignal, HintedSignal, StandardReadable from .utils import ( DEFAULT_TIMEOUT, + CalculatableTimeout, + CalculateTimeout, Callback, NotConnected, ReadingValueCallback, @@ -108,6 +110,8 @@ "TriggerInfo", "TriggerLogic", "HardwareTriggeredFlyable", + "CalculateTimeout", + "CalculatableTimeout", "DEFAULT_TIMEOUT", "Callback", "NotConnected", diff --git a/src/ophyd_async/core/async_status.py b/src/ophyd_async/core/async_status.py index 19e3f1d201..78d0a4ab93 100644 --- a/src/ophyd_async/core/async_status.py +++ b/src/ophyd_async/core/async_status.py @@ -9,7 +9,6 @@ Awaitable, Callable, Generic, - SupportsFloat, Type, TypeVar, cast, @@ -27,15 +26,11 @@ class AsyncStatusBase(Status): """Convert asyncio awaitable to bluesky Status interface""" - def __init__(self, awaitable: Awaitable, timeout: SupportsFloat | None = None): - if isinstance(timeout, SupportsFloat): - timeout = float(timeout) + def __init__(self, awaitable: Awaitable): if isinstance(awaitable, asyncio.Task): self.task = awaitable else: - self.task = asyncio.create_task( - asyncio.wait_for(awaitable, timeout=timeout) - ) + self.task = asyncio.create_task(awaitable) self.task.add_done_callback(self._run_callbacks) self._callbacks: list[Callback[Status]] = [] @@ -49,9 +44,8 @@ def add_callback(self, callback: Callback[Status]): self._callbacks.append(callback) def _run_callbacks(self, task: asyncio.Task): - if not task.cancelled(): - for callback in self._callbacks: - callback(self) + for callback in self._callbacks: + callback(self) def exception(self, timeout: float | None = 0.0) -> BaseException | None: if timeout != 0.0: @@ -93,14 +87,11 @@ def __repr__(self) -> str: class AsyncStatus(AsyncStatusBase): @classmethod def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]: + """Wrap an async function in an AsyncStatus.""" + @functools.wraps(f) def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: - # We can't type this more properly because Concatenate/ParamSpec doesn't - # yet support keywords - # https://peps.python.org/pep-0612/#concatenating-keyword-parameters - timeout = kwargs.get("timeout") - assert isinstance(timeout, SupportsFloat) or timeout is None - return cls(f(*args, **kwargs), timeout=timeout) + return cls(f(*args, **kwargs)) # type is actually functools._Wrapped[P, Awaitable, P, AS] # but functools._Wrapped is not necessarily available @@ -110,15 +101,11 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: class WatchableAsyncStatus(AsyncStatusBase, Generic[T]): """Convert AsyncIterator of WatcherUpdates to bluesky Status interface.""" - def __init__( - self, - iterator: AsyncIterator[WatcherUpdate[T]], - timeout: SupportsFloat | None = None, - ): + def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]): self._watchers: list[Watcher] = [] self._start = time.monotonic() self._last_update: WatcherUpdate[T] | None = None - super().__init__(self._notify_watchers_from(iterator), timeout) + super().__init__(self._notify_watchers_from(iterator)) async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]): async for update in iterator: @@ -146,14 +133,10 @@ def wrap( cls: Type[WAS], f: Callable[P, AsyncIterator[WatcherUpdate[T]]], ) -> Callable[P, WAS]: - """Wrap an AsyncIterator in a WatchableAsyncStatus. If it takes - 'timeout' as an argument, this must be a float or None, and it - will be propagated to the status.""" + """Wrap an AsyncIterator in a WatchableAsyncStatus.""" @functools.wraps(f) def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS: - timeout = kwargs.get("timeout") - assert isinstance(timeout, SupportsFloat) or timeout is None - return cls(f(*args, **kwargs), timeout=timeout) + return cls(f(*args, **kwargs)) return cast(Callable[P, WAS], wrap_f) diff --git a/src/ophyd_async/core/signal.py b/src/ophyd_async/core/signal.py index de72252918..4842017aa5 100644 --- a/src/ophyd_async/core/signal.py +++ b/src/ophyd_async/core/signal.py @@ -32,7 +32,7 @@ from .device import Device from .signal_backend import SignalBackend from .soft_signal_backend import SoftSignalBackend -from .utils import DEFAULT_TIMEOUT, Callback, T +from .utils import DEFAULT_TIMEOUT, CalculatableTimeout, CalculateTimeout, Callback, T def _add_timeout(func): @@ -213,15 +213,14 @@ async def unstage(self) -> None: self._del_cache(self._get_cache().set_staged(False)) -USE_DEFAULT_TIMEOUT = "USE_DEFAULT_TIMEOUT" - - class SignalW(Signal[T], Movable): """Signal that can be set""" - def set(self, value: T, wait=True, timeout=USE_DEFAULT_TIMEOUT) -> AsyncStatus: + def set( + self, value: T, wait=True, timeout: CalculatableTimeout = CalculateTimeout + ) -> AsyncStatus: """Set the value and return a status saying when it's done""" - if timeout is USE_DEFAULT_TIMEOUT: + if timeout is CalculateTimeout: timeout = self._timeout async def do_set(): @@ -248,9 +247,11 @@ async def locate(self) -> Location: class SignalX(Signal): """Signal that puts the default value""" - def trigger(self, wait=True, timeout=USE_DEFAULT_TIMEOUT) -> AsyncStatus: + def trigger( + self, wait=True, timeout: CalculatableTimeout = CalculateTimeout + ) -> AsyncStatus: """Trigger the action and return a status saying when it's done""" - if timeout is USE_DEFAULT_TIMEOUT: + if timeout is CalculateTimeout: timeout = self._timeout coro = self._backend.put(None, wait=wait, timeout=timeout) return AsyncStatus(coro) @@ -270,7 +271,7 @@ def soft_signal_r_and_setter( datatype: Optional[Type[T]] = None, initial_value: Optional[T] = None, name: str = "", -) -> Tuple[SignalR[T], Callable[[T]]]: +) -> Tuple[SignalR[T], Callable[[T], None]]: """Returns a tuple of a read-only Signal and a callable through which the signal can be internally modified within the device. Use soft_signal_rw if you want a device that is externally modifiable @@ -394,7 +395,7 @@ def assert_emitted(docs: Mapping[str, list[dict]], **numbers: int): async def observe_value( - signal: SignalR[T], timeout=None, done_status: Status | None = None + signal: SignalR[T], timeout: float | None = None, done_status: Status | None = None ) -> AsyncGenerator[T, None]: """Subscribe to the value of a signal so it can be iterated from. @@ -403,8 +404,12 @@ async def observe_value( signal: Call subscribe_value on this at the start, and clear_sub on it at the end + timeout: + If given, how long to wait for each updated value in seconds. If an update + is not produced in this time then raise asyncio.TimeoutError done_status: If this status is complete, stop observing and make the iterator return. + If it raises an exception then this exception will be raised by the iterator. Notes ----- @@ -414,9 +419,7 @@ async def observe_value( do_something_with(value) """ - class StatusIsDone: ... - - q: asyncio.Queue[T | StatusIsDone] = asyncio.Queue() + q: asyncio.Queue[T | Status] = asyncio.Queue() if timeout is None: get_value = q.get else: @@ -425,16 +428,19 @@ async def get_value(): return await asyncio.wait_for(q.get(), timeout) if done_status is not None: - done_status.add_callback(lambda _: q.put_nowait(StatusIsDone())) + done_status.add_callback(q.put_nowait) signal.subscribe_value(q.put_nowait) try: while True: item = await get_value() - if not isinstance(item, StatusIsDone): - yield item + if done_status and item is done_status: + if exc := done_status.exception(): + raise exc + else: + break else: - break + yield item finally: signal.clear_sub(q.put_nowait) diff --git a/src/ophyd_async/core/utils.py b/src/ophyd_async/core/utils.py index b09b9322a1..f5098ce717 100644 --- a/src/ophyd_async/core/utils.py +++ b/src/ophyd_async/core/utils.py @@ -31,6 +31,17 @@ ErrorText = Union[str, Dict[str, Exception]] +class CalculateTimeout: + """Sentinel class used to implement ``myfunc(timeout=CalculateTimeout)`` + + This signifies that the function should calculate a suitable non-zero + timeout itself + """ + + +CalculatableTimeout = float | None | Type[CalculateTimeout] + + class NotConnected(Exception): """Exception to be raised if a `Device.connect` is cancelled""" diff --git a/src/ophyd_async/epics/demo/__init__.py b/src/ophyd_async/epics/demo/__init__.py index 4e183d8548..13d73bb4b3 100644 --- a/src/ophyd_async/epics/demo/__init__.py +++ b/src/ophyd_async/epics/demo/__init__.py @@ -6,7 +6,6 @@ import string import subprocess import sys -from dataclasses import replace from enum import Enum from pathlib import Path @@ -22,7 +21,13 @@ WatchableAsyncStatus, observe_value, ) -from ophyd_async.core.utils import WatcherUpdate +from ophyd_async.core.async_status import AsyncStatus +from ophyd_async.core.utils import ( + DEFAULT_TIMEOUT, + CalculatableTimeout, + CalculateTimeout, + WatcherUpdate, +) from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x @@ -66,11 +71,9 @@ def __init__(self, prefix: str, name="") -> None: # Define some signals with self.add_children_as_readables(HintedSignal): self.readback = epics_signal_r(float, prefix + "Readback") - with self.add_children_as_readables(ConfigSignal): self.velocity = epics_signal_rw(float, prefix + "Velocity") self.units = epics_signal_r(str, prefix + "Readback.EGU") - self.setpoint = epics_signal_rw(float, prefix + "Setpoint") self.precision = epics_signal_r(int, prefix + "Readback.PREC") # Signals that collide with standard methods should have a trailing underscore @@ -85,41 +88,42 @@ def set_name(self, name: str): # Readback should be named the same as its parent in read() self.readback.set_name(name) - async def _move(self, new_position: float): + @WatchableAsyncStatus.wrap + async def set( + self, new_position: float, timeout: CalculatableTimeout = CalculateTimeout + ): self._set_success = True - # time.monotonic won't go backwards in case of NTP corrections - old_position, units, precision = await asyncio.gather( + old_position, units, precision, velocity = await asyncio.gather( self.setpoint.get_value(), self.units.get_value(), self.precision.get_value(), + self.velocity.get_value(), ) + if timeout is CalculateTimeout: + assert velocity > 0, "Mover has zero velocity" + timeout = abs(new_position - old_position) / velocity + DEFAULT_TIMEOUT + # Make an Event that will be set on completion, and a Status that will + # error if not done in time + done = asyncio.Event() + done_status = AsyncStatus(asyncio.wait_for(done.wait(), timeout)) # Wait for the value to set, but don't wait for put completion callback - move_status = self.setpoint.set(new_position, wait=True) - if not self._set_success: - raise RuntimeError("Motor was stopped") - # return a template to set() which it can use to yield progress updates - return ( - WatcherUpdate( + await self.setpoint.set(new_position, wait=False) + async for current_position in observe_value( + self.readback, done_status=done_status + ): + yield WatcherUpdate( + current=current_position, initial=old_position, - current=old_position, target=new_position, + name=self.name, unit=units, precision=precision, - ), - move_status, - ) - - @WatchableAsyncStatus.wrap # uses the timeout argument from the function it wraps - async def set(self, new_position: float, timeout: float | None = None): - update, _ = await self._move(new_position) - async for current_position in observe_value(self.readback): - yield replace( - update, - name=self.name, - current=current_position, ) if np.isclose(current_position, new_position): + done.set() break + if not self._set_success: + raise RuntimeError("Motor was stopped") async def stop(self, success=True): self._set_success = success diff --git a/src/ophyd_async/epics/motion/motor.py b/src/ophyd_async/epics/motion/motor.py index 4310a58628..c6a20d300c 100644 --- a/src/ophyd_async/epics/motion/motor.py +++ b/src/ophyd_async/epics/motion/motor.py @@ -1,17 +1,20 @@ import asyncio -from dataclasses import replace from bluesky.protocols import Movable, Stoppable from ophyd_async.core import ( - AsyncStatus, ConfigSignal, HintedSignal, StandardReadable, WatchableAsyncStatus, ) from ophyd_async.core.signal import observe_value -from ophyd_async.core.utils import WatcherUpdate +from ophyd_async.core.utils import ( + DEFAULT_TIMEOUT, + CalculatableTimeout, + CalculateTimeout, + WatcherUpdate, +) from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x @@ -47,42 +50,45 @@ def set_name(self, name: str): # Readback should be named the same as its parent in read() self.user_readback.set_name(name) - async def _move( - self, new_position: float - ) -> tuple[WatcherUpdate[float], AsyncStatus]: + @WatchableAsyncStatus.wrap + async def set( + self, new_position: float, timeout: CalculatableTimeout = CalculateTimeout + ): self._set_success = True - old_position, units, precision = await asyncio.gather( + ( + old_position, + units, + precision, + velocity, + acceleration_time, + ) = await asyncio.gather( self.user_setpoint.get_value(), self.motor_egu.get_value(), self.precision.get_value(), + self.velocity.get_value(), + self.acceleration_time.get_value(), ) - move_status = self.user_setpoint.set(new_position, wait=True) - if not self._set_success: - raise RuntimeError("Motor was stopped") - return ( - WatcherUpdate( - initial=old_position, - current=old_position, - target=new_position, - unit=units, - precision=precision, - ), - move_status, - ) - - @WatchableAsyncStatus.wrap - async def set(self, new_position: float, timeout: float | None = None): - update, move_status = await self._move(new_position) + if timeout is CalculateTimeout: + assert velocity > 0, "Motor has zero velocity" + timeout = ( + abs(new_position - old_position) / velocity + + 2 * acceleration_time + + DEFAULT_TIMEOUT + ) + move_status = self.user_setpoint.set(new_position, wait=True, timeout=timeout) async for current_position in observe_value( self.user_readback, done_status=move_status ): - if not self._set_success: - raise RuntimeError("Motor was stopped") - yield replace( - update, - name=self.name, + yield WatcherUpdate( current=current_position, + initial=old_position, + target=new_position, + name=self.name, + unit=units, + precision=precision, ) + if not self._set_success: + raise RuntimeError("Motor was stopped") async def stop(self, success=False): self._set_success = success diff --git a/src/ophyd_async/sim/demo/sim_motor.py b/src/ophyd_async/sim/demo/sim_motor.py index aeb4f38844..db2083ab02 100644 --- a/src/ophyd_async/sim/demo/sim_motor.py +++ b/src/ophyd_async/sim/demo/sim_motor.py @@ -1,6 +1,6 @@ import asyncio +import contextlib import time -from dataclasses import replace from bluesky.protocols import Movable, Stoppable @@ -25,109 +25,79 @@ def __init__(self, name="", instant=True) -> None: - name: str: name of device - instant: bool: whether to move instantly, or with a delay """ + # Define some signals with self.add_children_as_readables(HintedSignal): self.user_readback, self._user_readback_set = soft_signal_r_and_setter( float, 0 ) - with self.add_children_as_readables(ConfigSignal): - self.velocity = soft_signal_rw(float, 1.0) - self.egu = soft_signal_rw(str, "mm") - - self._instant = instant - self._move_status: AsyncStatus | None = None - - # Define some signals + self.velocity = soft_signal_rw(float, 0 if instant else 1.0) + self.units = soft_signal_rw(str, "mm") self.user_setpoint = soft_signal_rw(float, 0) - super().__init__(name=name) - # Whether set() should complete successfully or not self._set_success = True + self._move_status: AsyncStatus | None = None - def stop(self, success=False): - """ - Stop the motor if it is moving - """ - if self._move_status: - self._move_status.task.cancel() - self._move_status = None + super().__init__(name=name) - async def trigger_callbacks(): - await self.user_readback._backend.put( - await self.user_readback._backend.get_value() - ) + async def _move(self, old_position: float, new_position: float, move_time: float): + start = time.monotonic() + distance = abs(new_position - old_position) + while True: + time_elapsed = round(time.monotonic() - start, 2) - asyncio.create_task(trigger_callbacks()) + # update position based on time elapsed + if time_elapsed >= move_time: + # successfully reached our target position + self._user_readback_set(new_position) + break + else: + current_position = old_position + distance * time_elapsed / move_time - self._set_success = success + self._user_readback_set(current_position) + + # 10hz update loop + await asyncio.sleep(0.1) @WatchableAsyncStatus.wrap - async def set(self, new_position: float, timeout: float | None = None): + async def set(self, new_position: float): """ Asynchronously move the motor to a new position. """ - update, move_status = await self._move(new_position, timeout) - async for current_position in observe_value( - self.user_readback, done_status=move_status - ): - if not self._set_success: - raise RuntimeError("Motor was stopped") - yield replace( - update, - name=self.name, - current=current_position, - ) - - async def _move(self, new_position: float, timeout: float | None = None): - """ - Start the motor moving to a new position. - - If the motor is already moving, it will stop first. - If this is an instant motor the move will be instantaneous. - """ - self.stop() - start = time.monotonic() - self._set_success = True - - current_position = await self.user_readback.get_value() - distance = abs(new_position - current_position) - travel_time = 0 if self._instant else distance / await self.velocity.get_value() - - old_position, units = await asyncio.gather( + # Make sure any existing move tasks are stopped + await self.stop() + old_position, units, velocity = await asyncio.gather( self.user_setpoint.get_value(), - self.egu.get_value(), + self.units.get_value(), + self.velocity.get_value(), ) - - async def update_position(): - while True: - time_elapsed = round(time.monotonic() - start, 2) - - # update position based on time elapsed - if time_elapsed >= travel_time: - # successfully reached our target position - self._user_readback_set(new_position) - self._set_success = True - break - else: - current_position = ( - old_position + distance * time_elapsed / travel_time - ) - - self._user_readback_set(current_position) - - # 10hz update loop - await asyncio.sleep(0.1) - - # set up a task that updates the motor position at ~10hz - self._move_status = AsyncStatus(asyncio.wait_for(update_position(), timeout)) - - return ( - WatcherUpdate( - initial=old_position, - current=old_position, - target=new_position, - unit=units, - ), - self._move_status, + # If zero velocity, do instant move + move_time = abs(new_position - old_position) / velocity if velocity else 0 + self._move_status = AsyncStatus( + self._move(old_position, new_position, move_time) ) + # If stop is called then this will raise a CancelledError, ignore it + with contextlib.suppress(asyncio.CancelledError): + async for current_position in observe_value( + self.user_readback, done_status=self._move_status + ): + yield WatcherUpdate( + current=current_position, + initial=old_position, + target=new_position, + name=self.name, + unit=units, + ) + if not self._set_success: + raise RuntimeError("Motor was stopped") + + async def stop(self, success=True): + """ + Stop the motor if it is moving + """ + self._set_success = success + if self._move_status: + self._move_status.task.cancel() + self._move_status = None + await self.user_setpoint.set(await self.user_readback.get_value()) diff --git a/tests/core/test_async_status.py b/tests/core/test_async_status.py index 842df32136..3a0ce189a3 100644 --- a/tests/core/test_async_status.py +++ b/tests/core/test_async_status.py @@ -50,17 +50,18 @@ async def test_async_status_has_no_exception_if_coroutine_successful(normal_coro async def test_async_status_success_if_cancelled(normal_coroutine): + cbs = [] coro = normal_coroutine() status = AsyncStatus(coro) + status.add_callback(cbs.append) assert status.exception() is None status.task.cancel() + assert not cbs with pytest.raises(asyncio.CancelledError): await status + assert cbs == [status] assert status.success is False assert isinstance(status.exception(), asyncio.CancelledError) - # asyncio will RuntimeWarning us about this never being awaited if we don't. - # RunEngine handled this as a special case - await coro async def coroutine_to_wrap(time: float): @@ -126,7 +127,7 @@ def set(self, value) -> AsyncStatus: async def test_status_propogates_traceback_under_RE(RE) -> None: - expected_call_stack = ["wait_for", "_set", "_fail"] + expected_call_stack = ["_set", "_fail"] d = FailingMovable() with pytest.raises(FailedStatus) as ctx: RE(bps.mv(d, 3)) diff --git a/tests/core/test_device.py b/tests/core/test_device.py index ae8af75d9b..4b09aaf764 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -149,10 +149,10 @@ async def test_device_mock_and_back_again(RE): motor = SimMotor("motor") assert not motor._connect_task await motor.connect(mock=False) - assert isinstance(motor.egu._backend, SoftSignalBackend) + assert isinstance(motor.units._backend, SoftSignalBackend) assert motor._connect_task await motor.connect(mock=True) - assert isinstance(motor.egu._backend, MockSignalBackend) + assert isinstance(motor.units._backend, MockSignalBackend) class MotorBundle(Device): diff --git a/tests/core/test_device_collector.py b/tests/core/test_device_collector.py index 35c71508a9..f2f8d36fce 100644 --- a/tests/core/test_device_collector.py +++ b/tests/core/test_device_collector.py @@ -6,6 +6,7 @@ from super_state_machine.errors import TransitionError from ophyd_async.core import DEFAULT_TIMEOUT, Device, DeviceCollector, NotConnected +from ophyd_async.core.mock_signal_utils import set_mock_value from ophyd_async.epics.motion import motor @@ -86,6 +87,7 @@ def test_async_device_connector_run_engine_same_event_loop(): async def set_up_device(): async with DeviceCollector(mock=True): mock_motor = motor.Motor("BLxxI-MO-TABLE-01:X") + set_mock_value(mock_motor.velocity, 1) return mock_motor loop = asyncio.new_event_loop() diff --git a/tests/core/test_watchable_async_status.py b/tests/core/test_watchable_async_status.py index 721f191467..f7c6c04aeb 100644 --- a/tests/core/test_watchable_async_status.py +++ b/tests/core/test_watchable_async_status.py @@ -86,13 +86,9 @@ async def set(self, val): class ASTestDeviceTimeoutSet(ASTestDevice): @WatchableAsyncStatus.wrap async def set(self, val, timeout=0.01): - assert self._staged - await asyncio.sleep(0.01) - self._sig_setter(val - 1) # type: ignore - await asyncio.sleep(0.01) - yield WatcherUpdate(1, 1, 1) - await asyncio.sleep(0.01) + await asyncio.sleep(timeout) yield WatcherUpdate(1, 1, 1) + raise asyncio.TimeoutError() class ASTestDeviceIteratorSet(ASTestDevice): diff --git a/tests/epics/motion/test_motor.py b/tests/epics/motion/test_motor.py index 4b45ba66d1..dd4fcb0b4a 100644 --- a/tests/epics/motion/test_motor.py +++ b/tests/epics/motion/test_motor.py @@ -11,6 +11,7 @@ set_mock_put_proceeds, set_mock_value, ) +from ophyd_async.core.mock_signal_utils import callback_on_mock_put from ophyd_async.epics.motion import motor # Long enough for multiple asyncio event loop cycles to run so @@ -118,6 +119,33 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None: done.assert_called_once_with(s) +async def test_motor_move_timeout(sim_motor: motor.Motor): + class MyTimeout(Exception): + pass + + def do_timeout(value, wait=False, timeout=None): + # Check we were given the right timeout of move_time + DEFAULT_TIMEOUT + assert timeout == 10.3 + # Raise custom exception to be clear it bubbles up + raise MyTimeout() + + callback_on_mock_put(sim_motor.user_setpoint, do_timeout) + s = sim_motor.set(0.3) + watcher = Mock() + s.watch(watcher) + with pytest.raises(MyTimeout): + await s + watcher.assert_called_once_with( + name="sim_motor", + current=0.0, + initial=0.0, + target=0.3, + unit="mm", + precision=3, + time_elapsed=pytest.approx(0.0, abs=0.05), + ) + + async def test_motor_moving_stopped(sim_motor: motor.Motor): set_mock_value(sim_motor.motor_done_move, False) set_mock_put_proceeds(sim_motor.user_setpoint, False) diff --git a/tests/sim/demo/test_sim_motor.py b/tests/sim/demo/test_sim_motor.py index 7c6ec074c0..b8953df8c8 100644 --- a/tests/sim/demo/test_sim_motor.py +++ b/tests/sim/demo/test_sim_motor.py @@ -4,7 +4,6 @@ from bluesky.plans import spiral_square from bluesky.run_engine import RunEngine -from ophyd_async.core.async_status import AsyncStatusBase from ophyd_async.core.device import DeviceCollector from ophyd_async.sim.demo.sim_motor import SimMotor @@ -45,45 +44,11 @@ async def test_stop(): # this move should take 10 seconds but we will stop it after 0.2 move_status = m1.set(10) - while not m1._move_status: - # wait to actually get the move started - await asyncio.sleep(0) await asyncio.sleep(0.2) - m1.stop() - await asyncio.sleep(0) + await m1.stop(success=False) new_pos = await m1.user_readback.get_value() - - assert move_status.done - # move should not be successful as we stopped it - assert not move_status.success assert new_pos < 10 assert new_pos >= 0.1 - - -async def test_timeout(): - """ - Verify that timeout happens as expected for SimMotor moves. - - This test also verifies that the two tasks involved in the move are - completed as expected. - """ - async with DeviceCollector(): - m1 = SimMotor("M1", instant=False) - - # do a 10 sec move that will timeout before arriving - move_status = m1.set(10, timeout=0.1) - await asyncio.sleep(0.2) - - # check inner status - assert m1._move_status is not None - assert m1._move_status.done - assert not move_status.success - assert move_status.task.cancelled - - # check outer status - assert isinstance(move_status, AsyncStatusBase) + # move should not be successful as we stopped it assert move_status.done assert not move_status.success - assert move_status.task.cancelled - new_pos = await m1.user_readback.get_value() - assert new_pos < 10