Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Widen Array1D to fix changes to numpy shape #736

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ophyd_async/core/_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from ._utils import Callback, StrictEnum, T

DTypeScalar_co = TypeVar("DTypeScalar_co", covariant=True, bound=np.generic)
Array1D = np.ndarray[tuple[int], np.dtype[DTypeScalar_co]]
# To be a 1D array shape should really be tuple[int], but np.array()
# currently produces tuple[int, ...] even when it has 1D input args
# https://github.com/numpy/numpy/issues/28077#issuecomment-2566485178
Array1D = np.ndarray[tuple[int, ...], np.dtype[DTypeScalar_co]]
Primitive = bool | int | float | str
# NOTE: if you change this union then update the docs to match
SignalDatatype = (
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/core/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def numpy_table(self, selection: slice | None = None) -> np.ndarray:
v = v[selection]
if array is None:
array = np.empty(v.shape, dtype=self.numpy_dtype())
array[k] = v
array[k] = v # type: ignore
if array is None:
msg = "No arrays found in table"
raise ValueError(msg)
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/epics/core/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_supported_values(


def format_datatype(datatype: Any) -> str:
if get_origin(datatype) is np.ndarray and get_args(datatype)[0] == tuple[int]:
if get_origin(datatype) is np.ndarray and get_args(datatype):
dtype = get_dtype(datatype)
return f"Array1D[np.{dtype.name}]"
elif get_origin(datatype) is Sequence:
Expand Down
4 changes: 3 additions & 1 deletion tests/core/test_soft_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,10 @@ async def test_soft_signal_backend_set_callback():
async def test_soft_signal_backend_with_numpy_typing():
soft_backend = SoftSignalBackend(Array1D[np.float64])
await soft_backend.connect(timeout=1)
await soft_backend.put(np.array([1, 2]), wait=True)
array = await soft_backend.get_value()
assert array.shape == (0,)
assert array.shape == (2,)
assert array[0] == 1


async def test_soft_signal_descriptor_fails_for_invalid_class():
Expand Down
4 changes: 4 additions & 0 deletions tests/core/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class MyTable(Table):
{"bool": [0, 1], "uint": [3, 4], "str": [44, ""]},
"Input should be a valid string [type=string_type, input_value=44,",
),
(
{"bool": [0, 1], "uint": [[3], [4]], "str": ["", ""]},
"Array 2-dimensional; the target dimensions is 1",
),
],
)
def test_table_wrong_types(kwargs, error_msg):
Expand Down
14 changes: 14 additions & 0 deletions tests/epics/signal/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import bluesky.plan_stubs as bps
import numpy as np
import numpy.typing as npt
import pytest
import yaml
from aioca import purge_channel_caches
Expand Down Expand Up @@ -38,6 +39,7 @@
epics_signal_w,
epics_signal_x,
)
from ophyd_async.epics.core._util import format_datatype # noqa: PLC2701
from ophyd_async.epics.testing import (
EpicsTestEnum,
EpicsTestIocAndDevices,
Expand Down Expand Up @@ -567,6 +569,18 @@ async def test_non_existent_errors(
await signal.connect(timeout=0.1)


@pytest.mark.parametrize(
"dt,expected",
[
(Array1D[np.int32], "Array1D[np.int32]"),
(np.ndarray, "ndarray"),
(npt.NDArray[np.float64], "Array1D[np.float64]"),
],
)
def test_format_error_message(dt, expected):
assert format_datatype(dt) == expected


def test_make_backend_fails_for_different_transports():
read_pv = "test"
write_pv = "pva://test"
Expand Down
Loading