Skip to content

Commit

Permalink
fix typing errors with numpy 2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Aug 19, 2024
1 parent f31465c commit 04f4725
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 19 deletions.
5 changes: 4 additions & 1 deletion optype/numpy/_any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ def __getitem__(self, i: int, /) -> _T_co | _PyArrray[_T_co]: ...
AnyObjectArray: Alias = _Any2[np.object_, _ct.Object]

if _x.NP2 and not _x.NP20: # `numpy>=2.1`
AnyStringArray: Alias = _a.CanArray[tuple[int, ...], np.dtypes.StringDType]
AnyStringArray: Alias = _a.CanArray[ # type: ignore[type-var]
tuple[int, ...],
np.dtypes.StringDType, # pyright: ignore[reportInvalidTypeArguments]
]
elif _x.NP2: # `numpy>=2,<2.1`
AnyStringArray: Alias = _a.CanArray[tuple[int, ...], np.dtype[Never]]
else: # `numpy<2`
Expand Down
10 changes: 4 additions & 6 deletions optype/numpy/_any_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,13 +456,11 @@
# I (@jorenham) added them, see https://github.com/numpy/numpy/pull/27008).
if not _x.NP20:
# `numpy>=2.1`
AnyStringDType: Alias = _dt.HasDType[np.dtypes.StringDType] | _StringCode
_HasStringDType: Alias = _dt.HasDType[np.dtypes.StringDType] # type: ignore[type-var] # pyright: ignore[reportInvalidTypeArguments]

AnyDType: Alias = (
_Any2[np.generic, object]
| _dt.HasDType[np.dtypes.StringDType]
| LiteralString
)
AnyStringDType: Alias = _HasStringDType | _StringCode # type: ignore[type-var]

AnyDType: Alias = _Any2[np.generic, object] | _HasStringDType | LiteralString
else:
AnyStringDType: Alias = np.dtype[Never] | _StringCode

Expand Down
2 changes: 1 addition & 1 deletion optype/numpy/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
@runtime_checkable
class CanArray(Protocol[_ShapeT_co, _DT_co]):
@overload
def __array__(self, dtype: None = ..., /) -> Array[_ShapeT_co, _DT_co]: ...
def __array__(self, dtype: None = ..., /) -> np.ndarray[_ShapeT_co, _DT_co]: ...
@overload
def __array__(self, dtype: _DT, /) -> np.ndarray[_ShapeT_co, _DT]: ...
else:
Expand Down
17 changes: 11 additions & 6 deletions optype/numpy/_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,20 @@ def ntypes(self, /) -> int: ...
def types(self, /) -> list[LiteralString]: ...

# raises `ValueError` i.f.f. `nout != 1 or bool(signature)`
def at(self, /, *args: object, **kw: object) -> None: ...
@property
def at(self, /) -> CanCall[..., None]: ...
# raises `ValueError` i.f.f. `nin != 2 or nout != 1 or bool(signature)`
def reduce(self, /, *args: object, **kw: object) -> object: ...
@property
def reduce(self, /) -> CanCall[..., object]: ...
# raises `ValueError` i.f.f. `nin != 2 or nout != 1 or bool(signature)`
def reduceat(self, /, *args: object, **kw: object) -> _AnyArray: ...
@property
def reduceat(self, /) -> CanCall[..., _AnyArray]: ...
# raises `ValueError` i.f.f. `nin != 2 or nout != 1 or bool(signature)`
def accumulate(self, /, *args: object, **kw: object) -> _AnyArray: ...
@property
def accumulate(self, /) -> CanCall[..., _AnyArray]: ...
# raises `ValueError` i.f.f. `nin != 2 or nout != 1 or bool(signature)`
def outer(self, /, *args: object, **kw: object) -> object: ...
@property
def outer(self, /) -> CanCall[..., object]: ...

else:
# `numpy<2.1`
Expand Down Expand Up @@ -177,7 +182,7 @@ def outer(self, /) -> CanCall[..., object] | None: ...


_MethodCommon: Alias = L['__call__', 'reduce', 'reduceat', 'accumulate', 'outer']
if _x.NP2:
if _x.NP2: # type: ignore[redundant-expr]
_Method: Alias = L[_MethodCommon, 'at']
else:
_Method: Alias = L[_MethodCommon, 'inner']
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ venv = ".venv"
pythonVersion = "3.10"
pythonPlatform = "All"
typeCheckingMode = "all"
# defineConstant = {NP2 = true, NP20 = false}
defineConstant = {NP2 = true, NP20 = false}

reportUnusedCallResult = false # https://github.com/microsoft/pyright/issues/8650
Expand Down
8 changes: 4 additions & 4 deletions tests/numpy/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ def test_can_array() -> None:
assert isinstance(scalar, onp.CanArray)
assert not isinstance(42, onp.CanArray)

x_0d: onp.CanArray[_Shape0D, np.dtypes.UInt8DType] = np.array(42, sct)
x_0d: onp.CanArray[_Shape0D, np.dtype[np.uint8]] = np.array(42, sct)
assert isinstance(x_0d, onp.CanArray)

x_1d: onp.CanArray[_Shape1D, np.dtypes.UInt8DType] = np.array([42], sct)
x_1d: onp.CanArray[_Shape1D, np.dtype[np.uint8]] = np.array([42], sct)
assert isinstance(x_1d, onp.CanArray)
assert not isinstance([42], onp.CanArray)

x_2d: onp.CanArray[_Shape2D, np.dtypes.UInt8DType] = np.array([[42]], sct)
x_2d: onp.CanArray[_Shape2D, np.dtype[np.uint8]] = np.array([[42]], sct)
assert isinstance(x_2d, onp.CanArray)

mat: onp.CanArray[_Shape2D, np.dtypes.UInt8DType] = np.asmatrix(42, sct)
mat: onp.CanArray[_Shape2D, np.dtype[np.uint8]] = np.asmatrix(42, sct)
assert isinstance(mat, onp.CanArray)


Expand Down

0 comments on commit 04f4725

Please sign in to comment.