diff --git a/src/ophyd_async/core/device.py b/src/ophyd_async/core/device.py index ce7beaeecb..9b3631bc40 100644 --- a/src/ophyd_async/core/device.py +++ b/src/ophyd_async/core/device.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import sys from functools import cached_property from logging import LoggerAdapter, getLogger @@ -33,6 +34,10 @@ class Device(HasName): #: The parent Device if it exists parent: Optional[Device] = None + # Previous connect was successful, False on initialization + # since there was no previous connect + _previous_connect_success: bool = False + def __init__(self, name: str = "") -> None: self.set_name(name) @@ -71,7 +76,12 @@ def set_name(self, name: str): child.set_name(child_name) child.parent = self - async def connect(self, sim: bool = False, timeout: float = DEFAULT_TIMEOUT): + async def connect( + self, + sim: bool = False, + timeout: float = DEFAULT_TIMEOUT, + force_reconnect=False, + ): """Connect self and all child Devices. Contains a timeout that gets propagated to child.connect methods. @@ -83,12 +93,18 @@ async def connect(self, sim: bool = False, timeout: float = DEFAULT_TIMEOUT): timeout: Time to wait before failing with a TimeoutError. """ - coros = { - name: child_device.connect(sim, timeout=timeout) - for name, child_device in self.children() - } - if coros: - await wait_for_connection(**coros) + + if force_reconnect or not self._previous_connect_success: + # Kick off a connection + coros = { + name: child_device.connect(sim, timeout=timeout) + for name, child_device in self.children() + } + connect_task = asyncio.create_task(wait_for_connection(**coros)) + + # Wait for it to complete + await connect_task + self._previous_connect_success = not connect_task.exception() VT = TypeVar("VT", bound=Device) diff --git a/src/ophyd_async/epics/pvi/pvi.py b/src/ophyd_async/epics/pvi/pvi.py index e68cf3b132..d953a19f6f 100644 --- a/src/ophyd_async/epics/pvi/pvi.py +++ b/src/ophyd_async/epics/pvi/pvi.py @@ -91,7 +91,7 @@ def _verify_common_blocks(entry: PVIEntry, common_device: Type[Device]): return common_sub_devices = get_type_hints(common_device) for sub_name, sub_device in common_sub_devices.items(): - if sub_name in ("_name", "parent"): + if sub_name.startswith("_") or sub_name == "parent": continue assert entry.sub_entries device_t, is_optional = _strip_union(sub_device) @@ -161,7 +161,7 @@ def _sim_common_blocks(device: Device, stripped_type: Optional[Type] = None): sub_devices = ( (field, field_type) for field, field_type in get_type_hints(device_t).items() - if field not in ("_name", "parent") + if not field.startswith("_") and field != "parent" ) for device_name, device_cls in sub_devices: diff --git a/src/ophyd_async/planstubs/__init__.py b/src/ophyd_async/planstubs/__init__.py index cc409ce3a1..d97ec112e2 100644 --- a/src/ophyd_async/planstubs/__init__.py +++ b/src/ophyd_async/planstubs/__init__.py @@ -1,5 +1,9 @@ +from .ensure_connected import ensure_connected from .prepare_trigger_and_dets import ( prepare_static_seq_table_flyer_and_detectors_with_same_trigger, ) -__all__ = ["prepare_static_seq_table_flyer_and_detectors_with_same_trigger"] +__all__ = [ + "prepare_static_seq_table_flyer_and_detectors_with_same_trigger", + "ensure_connected", +] diff --git a/src/ophyd_async/planstubs/ensure_connected.py b/src/ophyd_async/planstubs/ensure_connected.py new file mode 100644 index 0000000000..78607be36c --- /dev/null +++ b/src/ophyd_async/planstubs/ensure_connected.py @@ -0,0 +1,22 @@ +import bluesky.plan_stubs as bps + +from ophyd_async.core.device import Device +from ophyd_async.core.utils import DEFAULT_TIMEOUT, wait_for_connection + + +def ensure_connected( + *devices: Device, + sim: bool = False, + timeout: float = DEFAULT_TIMEOUT, + force_reconnect=False, +): + yield from bps.wait_for( + [ + lambda: wait_for_connection( + **{ + device.name: device.connect(sim, timeout, force_reconnect) + for device in devices + } + ) + ] + ) diff --git a/tests/core/test_device.py b/tests/core/test_device.py index 97184aed54..4b8d0dd010 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -11,10 +11,14 @@ NotConnected, wait_for_connection, ) +from ophyd_async.epics.motion import motor +from ophyd_async.planstubs.ensure_connected import ( + ensure_connected, +) class DummyBaseDevice(Device): - def __init__(self) -> None: + def __init__(self): self.connected = False async def connect(self, sim=False, timeout=DEFAULT_TIMEOUT): @@ -117,3 +121,52 @@ async def test_device_log_has_correct_name(): assert device.log.extra["ophyd_async_device_name"] == "" device.set_name("device") assert device.log.extra["ophyd_async_device_name"] == "device" + + +async def test_device_lazily_connects(RE): + async with DeviceCollector(sim=True, connect=False): + sim_motor = motor.Motor("BLxxI-MO-TABLE-01:X") + + assert not sim_motor._previous_connect_success + + # When ready to connect + RE(ensure_connected(sim_motor, sim=True)) + + assert sim_motor._previous_connect_success + + +class MotorBundle(Device): + def __init__(self, name: str) -> None: + self.X = motor.Motor("BLxxI-MO-TABLE-01:X") + self.Y = motor.Motor("BLxxI-MO-TABLE-01:Y") + self.V: DeviceVector[motor.Motor] = DeviceVector( + { + 0: motor.Motor("BLxxI-MO-TABLE-21:X"), + 1: motor.Motor("BLxxI-MO-TABLE-21:Y"), + 2: motor.Motor("BLxxI-MO-TABLE-21:Z"), + } + ) + + +async def test_device_with_children_lazily_connects(RE): + parentMotor = MotorBundle("parentMotor") + + for device in [parentMotor, parentMotor.X, parentMotor.Y] + list( + parentMotor.V.values() + ): + assert not device._previous_connect_success + RE(ensure_connected(parentMotor, sim=True)) + + for device in [parentMotor, parentMotor.X, parentMotor.Y] + list( + parentMotor.V.values() + ): + assert device._previous_connect_success + + +async def test_device_with_device_collector_lazily_connects(): + sim_motor = motor.Motor("SOME_SIGNAL_WHICH_DOESN'T_EXIST:X") + with pytest.raises(NotConnected): + await sim_motor.connect(sim=False, timeout=0.01) + assert not sim_motor._previous_connect_success + await sim_motor.connect(sim=True, timeout=0.01) + assert sim_motor._previous_connect_success