From 04f472511929a1b28b9225164e5c63aade0cbd4d Mon Sep 17 00:00:00 2001 From: jorenham Date: Mon, 19 Aug 2024 02:33:16 +0200 Subject: [PATCH] fix typing errors with numpy 2.1 --- optype/numpy/_any_array.py | 5 ++++- optype/numpy/_any_dtype.py | 10 ++++------ optype/numpy/_array.py | 2 +- optype/numpy/_ufunc.py | 17 +++++++++++------ pyproject.toml | 1 - tests/numpy/test_array.py | 8 ++++---- 6 files changed, 24 insertions(+), 19 deletions(-) diff --git a/optype/numpy/_any_array.py b/optype/numpy/_any_array.py index 387e338..4f993b4 100644 --- a/optype/numpy/_any_array.py +++ b/optype/numpy/_any_array.py @@ -148,7 +148,10 @@ def __getitem__(self, i: int, /) -> _T_co | _PyArrray[_T_co]: ... AnyObjectArray: Alias = _Any2[np.object_, _ct.Object] if _x.NP2 and not _x.NP20: # `numpy>=2.1` - AnyStringArray: Alias = _a.CanArray[tuple[int, ...], np.dtypes.StringDType] + AnyStringArray: Alias = _a.CanArray[ # type: ignore[type-var] + tuple[int, ...], + np.dtypes.StringDType, # pyright: ignore[reportInvalidTypeArguments] + ] elif _x.NP2: # `numpy>=2,<2.1` AnyStringArray: Alias = _a.CanArray[tuple[int, ...], np.dtype[Never]] else: # `numpy<2` diff --git a/optype/numpy/_any_dtype.py b/optype/numpy/_any_dtype.py index b0dc6ee..f0de48f 100644 --- a/optype/numpy/_any_dtype.py +++ b/optype/numpy/_any_dtype.py @@ -456,13 +456,11 @@ # I (@jorenham) added them, see https://github.com/numpy/numpy/pull/27008). if not _x.NP20: # `numpy>=2.1` - AnyStringDType: Alias = _dt.HasDType[np.dtypes.StringDType] | _StringCode + _HasStringDType: Alias = _dt.HasDType[np.dtypes.StringDType] # type: ignore[type-var] # pyright: ignore[reportInvalidTypeArguments] - AnyDType: Alias = ( - _Any2[np.generic, object] - | _dt.HasDType[np.dtypes.StringDType] - | LiteralString - ) + AnyStringDType: Alias = _HasStringDType | _StringCode # type: ignore[type-var] + + AnyDType: Alias = _Any2[np.generic, object] | _HasStringDType | LiteralString else: AnyStringDType: Alias = np.dtype[Never] | _StringCode diff --git a/optype/numpy/_array.py b/optype/numpy/_array.py index 49af197..4c869b5 100644 --- a/optype/numpy/_array.py +++ b/optype/numpy/_array.py @@ -50,7 +50,7 @@ @runtime_checkable class CanArray(Protocol[_ShapeT_co, _DT_co]): @overload - def __array__(self, dtype: None = ..., /) -> Array[_ShapeT_co, _DT_co]: ... + def __array__(self, dtype: None = ..., /) -> np.ndarray[_ShapeT_co, _DT_co]: ... @overload def __array__(self, dtype: _DT, /) -> np.ndarray[_ShapeT_co, _DT]: ... else: diff --git a/optype/numpy/_ufunc.py b/optype/numpy/_ufunc.py index 6abb56b..498a0c3 100644 --- a/optype/numpy/_ufunc.py +++ b/optype/numpy/_ufunc.py @@ -119,15 +119,20 @@ def ntypes(self, /) -> int: ... def types(self, /) -> list[LiteralString]: ... # raises `ValueError` i.f.f. `nout != 1 or bool(signature)` - def at(self, /, *args: object, **kw: object) -> None: ... + @property + def at(self, /) -> CanCall[..., None]: ... # raises `ValueError` i.f.f. `nin != 2 or nout != 1 or bool(signature)` - def reduce(self, /, *args: object, **kw: object) -> object: ... + @property + def reduce(self, /) -> CanCall[..., object]: ... # raises `ValueError` i.f.f. `nin != 2 or nout != 1 or bool(signature)` - def reduceat(self, /, *args: object, **kw: object) -> _AnyArray: ... + @property + def reduceat(self, /) -> CanCall[..., _AnyArray]: ... # raises `ValueError` i.f.f. `nin != 2 or nout != 1 or bool(signature)` - def accumulate(self, /, *args: object, **kw: object) -> _AnyArray: ... + @property + def accumulate(self, /) -> CanCall[..., _AnyArray]: ... # raises `ValueError` i.f.f. `nin != 2 or nout != 1 or bool(signature)` - def outer(self, /, *args: object, **kw: object) -> object: ... + @property + def outer(self, /) -> CanCall[..., object]: ... else: # `numpy<2.1` @@ -177,7 +182,7 @@ def outer(self, /) -> CanCall[..., object] | None: ... _MethodCommon: Alias = L['__call__', 'reduce', 'reduceat', 'accumulate', 'outer'] -if _x.NP2: +if _x.NP2: # type: ignore[redundant-expr] _Method: Alias = L[_MethodCommon, 'at'] else: _Method: Alias = L[_MethodCommon, 'inner'] diff --git a/pyproject.toml b/pyproject.toml index bacce53..4aa2a3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,6 @@ venv = ".venv" pythonVersion = "3.10" pythonPlatform = "All" typeCheckingMode = "all" -# defineConstant = {NP2 = true, NP20 = false} defineConstant = {NP2 = true, NP20 = false} reportUnusedCallResult = false # https://github.com/microsoft/pyright/issues/8650 diff --git a/tests/numpy/test_array.py b/tests/numpy/test_array.py index 830b33d..1d6600d 100644 --- a/tests/numpy/test_array.py +++ b/tests/numpy/test_array.py @@ -24,17 +24,17 @@ def test_can_array() -> None: assert isinstance(scalar, onp.CanArray) assert not isinstance(42, onp.CanArray) - x_0d: onp.CanArray[_Shape0D, np.dtypes.UInt8DType] = np.array(42, sct) + x_0d: onp.CanArray[_Shape0D, np.dtype[np.uint8]] = np.array(42, sct) assert isinstance(x_0d, onp.CanArray) - x_1d: onp.CanArray[_Shape1D, np.dtypes.UInt8DType] = np.array([42], sct) + x_1d: onp.CanArray[_Shape1D, np.dtype[np.uint8]] = np.array([42], sct) assert isinstance(x_1d, onp.CanArray) assert not isinstance([42], onp.CanArray) - x_2d: onp.CanArray[_Shape2D, np.dtypes.UInt8DType] = np.array([[42]], sct) + x_2d: onp.CanArray[_Shape2D, np.dtype[np.uint8]] = np.array([[42]], sct) assert isinstance(x_2d, onp.CanArray) - mat: onp.CanArray[_Shape2D, np.dtypes.UInt8DType] = np.asmatrix(42, sct) + mat: onp.CanArray[_Shape2D, np.dtype[np.uint8]] = np.asmatrix(42, sct) assert isinstance(mat, onp.CanArray)