diff --git a/src/ophyd_async/core/device.py b/src/ophyd_async/core/device.py index ce7beaeecb..6036d2e4f0 100644 --- a/src/ophyd_async/core/device.py +++ b/src/ophyd_async/core/device.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import sys from functools import cached_property from logging import LoggerAdapter, getLogger @@ -32,6 +33,8 @@ class Device(HasName): _name: str = "" #: The parent Device if it exists parent: Optional[Device] = None + # None if connect hasn't started, an Event if it has, a set Event if it's done + _connect_task: Optional[asyncio.Task] = None def __init__(self, name: str = "") -> None: self.set_name(name) @@ -71,7 +74,12 @@ def set_name(self, name: str): child.set_name(child_name) child.parent = self - async def connect(self, sim: bool = False, timeout: float = DEFAULT_TIMEOUT): + async def connect( + self, + sim: bool = False, + timeout: float = DEFAULT_TIMEOUT, + force_reconnect=False, + ): """Connect self and all child Devices. Contains a timeout that gets propagated to child.connect methods. @@ -83,12 +91,21 @@ async def connect(self, sim: bool = False, timeout: float = DEFAULT_TIMEOUT): timeout: Time to wait before failing with a TimeoutError. """ - coros = { - name: child_device.connect(sim, timeout=timeout) - for name, child_device in self.children() - } - if coros: - await wait_for_connection(**coros) + previous_connect_ok = ( + self._connect_task + and self._connect_task.done() + and not self._connect_task.exception() + ) + if force_reconnect or not previous_connect_ok: + # Kick off a connection + coros = { + name: child_device.connect(sim, timeout=timeout) + for name, child_device in self.children() + } + self._connect_task = asyncio.create_task(wait_for_connection(**coros)) + # Wait for it to complete + assert self._connect_task + await self._connect_task VT = TypeVar("VT", bound=Device) diff --git a/src/ophyd_async/epics/pvi/pvi.py b/src/ophyd_async/epics/pvi/pvi.py index e68cf3b132..d953a19f6f 100644 --- a/src/ophyd_async/epics/pvi/pvi.py +++ b/src/ophyd_async/epics/pvi/pvi.py @@ -91,7 +91,7 @@ def _verify_common_blocks(entry: PVIEntry, common_device: Type[Device]): return common_sub_devices = get_type_hints(common_device) for sub_name, sub_device in common_sub_devices.items(): - if sub_name in ("_name", "parent"): + if sub_name.startswith("_") or sub_name == "parent": continue assert entry.sub_entries device_t, is_optional = _strip_union(sub_device) @@ -161,7 +161,7 @@ def _sim_common_blocks(device: Device, stripped_type: Optional[Type] = None): sub_devices = ( (field, field_type) for field, field_type in get_type_hints(device_t).items() - if field not in ("_name", "parent") + if not field.startswith("_") and field != "parent" ) for device_name, device_cls in sub_devices: diff --git a/src/ophyd_async/planstubs/__init__.py b/src/ophyd_async/planstubs/__init__.py index cc409ce3a1..d97ec112e2 100644 --- a/src/ophyd_async/planstubs/__init__.py +++ b/src/ophyd_async/planstubs/__init__.py @@ -1,5 +1,9 @@ +from .ensure_connected import ensure_connected from .prepare_trigger_and_dets import ( prepare_static_seq_table_flyer_and_detectors_with_same_trigger, ) -__all__ = ["prepare_static_seq_table_flyer_and_detectors_with_same_trigger"] +__all__ = [ + "prepare_static_seq_table_flyer_and_detectors_with_same_trigger", + "ensure_connected", +] diff --git a/src/ophyd_async/planstubs/ensure_connected.py b/src/ophyd_async/planstubs/ensure_connected.py new file mode 100644 index 0000000000..cb4b5caa5f --- /dev/null +++ b/src/ophyd_async/planstubs/ensure_connected.py @@ -0,0 +1,17 @@ +import bluesky.plan_stubs as bps + +from ophyd_async.core.device import Device +from ophyd_async.core.utils import DEFAULT_TIMEOUT, wait_for_connection + + +def ensure_connected(*devices: Device, + sim: bool = False, + timeout: float = DEFAULT_TIMEOUT, + force_reconnect=False): + yield from bps.wait_for([lambda : + wait_for_connection(**{ + device.name: + device.connect(sim, timeout, force_reconnect) + for device in devices + }) + ]) diff --git a/tests/core/test_device.py b/tests/core/test_device.py index 97184aed54..eebf1b265b 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -11,6 +11,10 @@ NotConnected, wait_for_connection, ) +from ophyd_async.epics.motion import motor +from ophyd_async.planstubs.ensure_connected import ( + ensure_connected, +) class DummyBaseDevice(Device): @@ -117,3 +121,57 @@ async def test_device_log_has_correct_name(): assert device.log.extra["ophyd_async_device_name"] == "" device.set_name("device") assert device.log.extra["ophyd_async_device_name"] == "device" + + +async def test_device_lazily_connects(RE): + async with DeviceCollector(sim=True, connect=False): + sim_motor = motor.Motor("BLxxI-MO-TABLE-01:X") + + assert sim_motor._connect_task is None + + # When ready to connect + RE(ensure_connected(sim_motor, sim=True)) + + assert ( + sim_motor._connect_task + and sim_motor._connect_task.done() + and not sim_motor._connect_task.exception() + ) +class MotorBundle(Device): + def __init__(self, name: str) -> None: + self.X = motor.Motor("BLxxI-MO-TABLE-01:X") + self.Y = motor.Motor("BLxxI-MO-TABLE-01:Y") + self.V : DeviceVector[motor.Motor] = DeviceVector( + {0: motor.Motor("BLxxI-MO-TABLE-21:X"),1: motor.Motor("BLxxI-MO-TABLE-21:Y"),2: motor.Motor("BLxxI-MO-TABLE-21:Z")} + ) + + +async def test_device_with_children_lazily_connects(RE): + parentMotor = MotorBundle("parentMotor") + + assert parentMotor._connect_task is None and parentMotor.X._connect_task is None and parentMotor.Y._connect_task is None + + + RE(ensure_connected(parentMotor, sim=True)) + + assert ( + parentMotor.X._connect_task + and parentMotor.X._connect_task.done() + and not parentMotor.X._connect_task.exception() + ) + assert ( + parentMotor.Y._connect_task + and parentMotor.Y._connect_task.done() + and not parentMotor.Y._connect_task.exception() + ) + assert ( + parentMotor._connect_task + and parentMotor._connect_task.done() + and not parentMotor._connect_task.exception() + ) + for motor in parentMotor.V.values(): + assert ( + motor._connect_task + and motor._connect_task.done() + and not motor._connect_task.exception() + ) \ No newline at end of file