Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update core and tango tests to match structure of epics tests #723

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ forbidden_modules = ["ophyd_async.testing", "ophyd_async.sim"]
source_modules = [
"ophyd_async.plan_stubs",
"ophyd_async.fast.*",
"ophyd_async.epics",
"ophyd_async.tango",
"ophyd_async.epics.*",
"ophyd_async.tango.*",
]
ignore_imports = ["ophyd_async.tango.testing.* -> ophyd_async.testing"]
70 changes: 66 additions & 4 deletions src/ophyd_async/tango/core/_tango_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import abstractmethod
from collections.abc import Callable, Coroutine
from enum import Enum
from typing import Any, TypeVar, cast
from typing import Any, Generic, TypeVar, cast

import numpy as np
from bluesky.protocols import Descriptor, Reading
Expand All @@ -15,6 +15,7 @@
NotConnected,
SignalBackend,
SignalDatatypeT,
StrictEnum,
get_dtype,
get_unique,
wait_for_connection,
Expand Down Expand Up @@ -617,6 +618,60 @@ async def get_tango_trl(
raise RuntimeError(f"{trl_name} cannot be found in {device_proxy.name()}")


class TangoConverter(Generic[SignalDatatypeT]):
def write_value(self, value: Any) -> Any:
return value

def value(self, value: Any) -> Any:
return value


class TangoEnumConverter(TangoConverter):
def __init__(self, labels: list[str]):
self._labels = labels

def write_value(self, value: str):
if not isinstance(value, str):
raise TypeError("TangoEnumConverter expects str value")
return self._labels.index(value)

def value(self, value: int):
return self._labels[value]


class TangoEnumSpectrumConverter(TangoEnumConverter):
def write_value(self, value: np.ndarray[Any, str | StrictEnum]):
# should return array of ints
return np.array([self._labels.index(v) for v in value])

def value(self, value: np.ndarray[Any, int]):
# should return array of strs
return np.array([self._labels[v] for v in value])


class TangoEnumImageConverter(TangoEnumConverter):
def write_value(self, value: np.ndarray[Any, str | StrictEnum]):
# should return array of ints
return np.vstack([[self._labels.index(v) for v in row] for row in value])

def value(self, value: np.ndarray[Any, int]):
# should return array of strs
return np.vstack([[self._labels[v] for v in row] for row in value])


def make_converter(info: AttributeInfoEx | CommandInfo) -> TangoConverter:
if isinstance(info, AttributeInfoEx):
if info.enum_labels: # enum_labels should be discarded for non enum types
if info.data_format == AttrDataFormat.SCALAR:
return TangoEnumConverter(list(info.enum_labels))
elif info.data_format == AttrDataFormat.SPECTRUM:
return TangoEnumSpectrumConverter(list(info.enum_labels))
elif info.data_format == AttrDataFormat.IMAGE:
return TangoEnumImageConverter(list(info.enum_labels))
# default case return trivial converter
return TangoConverter()


class TangoSignalBackend(SignalBackend[SignalDatatypeT]):
def __init__(
self,
Expand All @@ -642,6 +697,7 @@ def __init__(
)
self.support_events: bool = True
self.status: AsyncStatus | None = None
self.converter = TangoConverter() # gets replaced at connect
super().__init__(datatype)

@classmethod
Expand Down Expand Up @@ -687,11 +743,13 @@ async def connect(self, timeout: float) -> None:
# The same, so only need to connect one
await self._connect_and_store_config(self.read_trl, timeout)
self.proxies[self.read_trl].set_polling(*self._polling) # type: ignore
self.converter = make_converter(self.trl_configs[self.read_trl])
self.descriptor = get_trl_descriptor(
self.datatype, self.read_trl, self.trl_configs
)

async def put(self, value: SignalDatatypeT | None, wait=True, timeout=None) -> None:
value = self.converter.write_value(value)
if self.proxies[self.write_trl] is None:
raise NotConnected(f"Not connected to {self.write_trl}")
self.status = None
Expand All @@ -704,23 +762,27 @@ async def get_datakey(self, source: str) -> Descriptor:
async def get_reading(self) -> Reading[SignalDatatypeT]:
if self.proxies[self.read_trl] is None:
raise NotConnected(f"Not connected to {self.read_trl}")
return await self.proxies[self.read_trl].get_reading() # type: ignore
reading = await self.proxies[self.read_trl].get_reading() # type: ignore
reading["value"] = self.converter.value(reading["value"])
return reading

async def get_value(self) -> SignalDatatypeT:
if self.proxies[self.read_trl] is None:
raise NotConnected(f"Not connected to {self.read_trl}")
proxy = self.proxies[self.read_trl]
if proxy is None:
raise NotConnected(f"Not connected to {self.read_trl}")
return cast(SignalDatatypeT, await proxy.get())
value = await proxy.get()
return cast(SignalDatatypeT, self.converter.value(value))

async def get_setpoint(self) -> SignalDatatypeT:
if self.proxies[self.write_trl] is None:
raise NotConnected(f"Not connected to {self.write_trl}")
proxy = self.proxies[self.write_trl]
if proxy is None:
raise NotConnected(f"Not connected to {self.write_trl}")
return cast(SignalDatatypeT, await proxy.get_w_value())
w_value = await proxy.get_w_value()
return cast(SignalDatatypeT, self.converter.value(w_value))

def set_callback(self, callback: Callback | None) -> None:
if self.proxies[self.read_trl] is None:
Expand Down
3 changes: 3 additions & 0 deletions src/ophyd_async/tango/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._one_of_everything import ExampleStrEnum, OneOfEverythingTangoDevice

__all__ = ["ExampleStrEnum", "OneOfEverythingTangoDevice"]
141 changes: 141 additions & 0 deletions src/ophyd_async/tango/testing/_one_of_everything.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import numpy as np

from ophyd_async.core import (
DTypeScalar_co,
StrictEnum,
)
from ophyd_async.testing import float_array_value, int_array_value
from tango import AttrDataFormat, AttrWriteType
from tango.server import Device, attribute


class ExampleStrEnum(StrictEnum):
A = "AAA"
B = "BBB"
C = "CCC"


def int_image_value(
dtype: type[DTypeScalar_co],
):
# how do we type this?
array_1d = int_array_value(dtype)
return np.vstack((array_1d, array_1d))


def float_image_value(
dtype: type[DTypeScalar_co],
):
# how do we type this?
array_1d = float_array_value(dtype)
return np.vstack((array_1d, array_1d))


_dtypes = {
"str": "DevString",
"bool": "DevBoolean",
"enum": "DevEnum",
"strenum": "DevEnum",
"int8": "DevShort",
"uint8": "DevUChar",
"int16": "DevShort",
"uint16": "DevUShort",
"int32": "DevLong",
"uint32": "DevULong",
"int64": "DevLong64",
"uint64": "DevULong64",
"float32": "DevFloat",
"float64": "DevDouble",
}

_initial_values = {
AttrDataFormat.SCALAR: {
"str": "test_string",
"bool": True,
"strenum": 1, # Tango devices must use ints for enums
"int8": 1,
"uint8": 1,
"int16": 1,
"uint16": 1,
"int32": 1,
"uint32": 1,
"int64": 1,
"uint64": 1,
"float32": 1.234,
"float64": 1.234,
},
AttrDataFormat.SPECTRUM: {
"str": ["one", "two", "three"],
"bool": [False, True],
"strenum": [0, 1, 2], # Tango devices must use ints for enums
"int8": int_array_value(np.int8),
"uint8": int_array_value(np.uint8),
"int16": int_array_value(np.int16),
"uint16": int_array_value(np.uint16),
"int32": int_array_value(np.int32),
"uint32": int_array_value(np.uint32),
"int64": int_array_value(np.int64),
"uint64": int_array_value(np.uint64),
"float32": float_array_value(np.float32),
"float64": float_array_value(np.float64),
},
AttrDataFormat.IMAGE: {
"str": np.array([["one", "two", "three"], ["one", "two", "three"]]),
"bool": np.array([[False, True], [False, True]]),
"strenum": np.array(
[[0, 1, 2], [0, 1, 2]]
), # Tango devices must use ints for enums
"int8": int_image_value(np.int8),
"uint8": int_image_value(np.uint8),
"int16": int_image_value(np.int16),
"uint16": int_image_value(np.uint16),
"int32": int_image_value(np.int32),
"uint32": int_image_value(np.uint32),
"int64": int_image_value(np.int64),
"uint64": int_image_value(np.uint64),
"float32": float_image_value(np.float32),
"float64": float_image_value(np.float64),
},
}


class OneOfEverythingTangoDevice(Device):
attr_values = {}

def initialize_dynamic_attributes(self):
for dformat, initial_values in _initial_values.items():
if dformat == AttrDataFormat.SPECTRUM:
suffix = "_spectrum"
elif dformat == AttrDataFormat.IMAGE:
suffix = "_image"
else:
suffix = "" # scalar
for prefix, value in initial_values.items():
name = prefix + suffix
self.attr_values[name] = value
if prefix == "strenum":
labels = [e.value for e in ExampleStrEnum]
else:
labels = []
attr = attribute(
name=name,
dtype=_dtypes[prefix],
dformat=dformat,
access=AttrWriteType.READ_WRITE,
fget=self.read,
fset=self.write,
max_dim_x=100,
max_dim_y=2,
enum_labels=labels,
)
self.add_attribute(attr)
self.set_change_event(name, True, False)

def read(self, attr):
value = self.attr_values[attr.get_name()]
attr.set_value(value) # fails with enums

def write(self, attr):
new_value = attr.get_write_value()
self.attr_values[attr.get_name()] = new_value
self.push_change_event(attr.get_name(), new_value)
4 changes: 4 additions & 0 deletions src/ophyd_async/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
ExampleTable,
OneOfEverythingDevice,
ParentOfEverythingDevice,
float_array_value,
int_array_value,
)
from ._wait_for_pending import wait_for_pending_wakeups

Expand All @@ -49,4 +51,6 @@
"ParentOfEverythingDevice",
"MonitorQueue",
"ApproxTable",
"int_array_value",
"float_array_value",
]
45 changes: 21 additions & 24 deletions src/ophyd_async/testing/_assert.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import asyncio
import time
from contextlib import AbstractContextManager
from typing import Any
from unittest.mock import ANY

import pytest
from bluesky.protocols import Reading
from event_model import DataKey

from ophyd_async.core import (
Expand Down Expand Up @@ -41,7 +40,8 @@ async def assert_value(signal: SignalR[SignalDatatypeT], value: Any) -> None:


async def assert_reading(
readable: AsyncReadable, expected_reading: dict[str, Reading]
readable: AsyncReadable,
expected_reading: dict[str, dict[str, Any]],
) -> None:
"""Assert readings from readable.

Expand All @@ -60,16 +60,25 @@ async def assert_reading(

"""
actual_reading = await readable.read()
_assert_readings_approx_equal(expected_reading, actual_reading)


def _assert_readings_approx_equal(expected, actual):
approx_expected_reading = {
k: dict(v, value=approx_value(expected_reading[k]["value"]))
for k, v in expected_reading.items()
k: dict(
v,
value=approx_value(expected[k]["value"]),
timestamp=pytest.approx(expected[k].get("timestamp", ANY), rel=0.1),
alarm_severity=pytest.approx(expected[k].get("alarm_severity", ANY)),
)
for k, v in expected.items()
}
assert actual_reading == approx_expected_reading
assert actual == approx_expected_reading


async def assert_configuration(
configurable: AsyncConfigurable,
configuration: dict[str, Reading],
configuration: dict[str, dict[str, Any]],
) -> None:
"""Assert readings from Configurable.

Expand All @@ -88,11 +97,7 @@ async def assert_configuration(

"""
actual_configuration = await configurable.read_configuration()
approx_expected_configuration = {
k: dict(v, value=approx_value(configuration[k]["value"]))
for k, v in configuration.items()
}
assert actual_configuration == approx_expected_configuration
_assert_readings_approx_equal(configuration, actual_configuration)


async def assert_describe_signal(signal: SignalR, /, **metadata):
Expand Down Expand Up @@ -146,27 +151,19 @@ def __eq__(self, value):
class MonitorQueue(AbstractContextManager):
def __init__(self, signal: SignalR):
self.signal = signal
self.updates: asyncio.Queue[dict[str, Reading]] = asyncio.Queue()
self.signal.subscribe(self.updates.put_nowait)
self.updates: asyncio.Queue[dict[str, dict[str, Any]]] = asyncio.Queue()

async def assert_updates(self, expected_value):
# Get an update, value and reading
expected_type = type(expected_value)
expected_value = approx_value(expected_value)
update = await self.updates.get()
value = await self.signal.get_value()
reading = await self.signal.read()
# Check they match what we expected
assert value == expected_value
assert type(value) is expected_type
await assert_value(self.signal, expected_value)
expected_reading = {
self.signal.name: {
"value": expected_value,
"timestamp": pytest.approx(time.time(), rel=0.1),
"alarm_severity": 0,
}
}
assert reading == update == expected_reading
await assert_reading(self.signal, expected_reading)
_assert_readings_approx_equal(expected_reading, update)

def __enter__(self):
self.signal.subscribe(self.updates.put_nowait)
Expand Down
Loading
Loading