Skip to content

Commit

Permalink
Merge pull request #138 from jorenham/feature/numpy-string-dtype
Browse files Browse the repository at this point in the history
Support for `numpy.StringDType`
  • Loading branch information
jorenham authored Aug 11, 2024
2 parents 85c4ac7 + 3543edf commit de361da
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 41 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:

- name: basedpyright --verifytypes
run: poetry run basedpyright --ignoreexternal --verifytypes optype
continue-on-error: true # TODO: remove after NumPy 2.1 is released

- name: markdownlint
uses: DavidAnson/markdownlint-cli2-action@v16
Expand Down
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2856,7 +2856,7 @@ that depends on the specific `np.dtype` instantiation.
</tr>
<tr>
<th><code>bool_</code></th>
<td rowspan="4"><code>generic</code></td>
<td rowspan="5"><code>generic</code></td>
<td><code>AnyBoolArray</code></td>
<td><code>AnyBoolDType</code></td>
</tr>
Expand All @@ -2875,6 +2875,11 @@ that depends on the specific `np.dtype` instantiation.
<td><code>AnyObjectArray</code></td>
<td><code>AnyObjectDType</code></td>
</tr>
<tr>
<td>missing</td>
<td><code>AnyStringArray</code></td>
<td><code>AnyStringDType</code></td>
</tr>
</table>
#### Low-level interfaces
Expand Down
12 changes: 5 additions & 7 deletions optype/numpy/_any_array.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from typing import Final, TypeAlias as _Type
from typing import TypeAlias as _Type

import numpy as np

Expand Down Expand Up @@ -84,9 +84,7 @@
'AnyTimeDelta64Array',
'AnyObjectArray',
]


_NP_V2: Final[bool] = np.__version__.startswith('2.')
__all__ += ['AnyStringArray']


T_co = TypeVar('T_co', covariant=True, bound=object)
Expand Down Expand Up @@ -166,14 +164,14 @@ def __getitem__(self, i: int, /) -> T_co | _PyArray[T_co]: ...
AnyInt16Array: _Type = _AnyNP[ND, np.int16] | _ct.Int16
AnyInt32Array: _Type = _AnyNP[ND, np.int32] | _ct.Int32
AnyInt64Array: _Type = _AnyNP[ND, np.int64] | _ct.Int64
if _NP_V2:
if _x.NP2:
AnyIntPArray: _Type = _Any4[ND, np.int64, int, _ct.IntP]
else:
AnyIntPArray: _Type = _AnyNP[ND, np.int64] | _ct.IntP
AnyByteArray: _Type = _AnyNP[ND, np.byte] | _ct.Byte
AnyShortArray: _Type = _AnyNP[ND, np.short] | _ct.Short
AnyIntCArray: _Type = _AnyNP[ND, np.intc] | _ct.IntC
if _NP_V2:
if _x.NP2:
AnyLongArray: _Type = _AnyNP[ND, _x.Long] | _ct.Long
else:
AnyLongArray: _Type = _Any4[ND, _x.Long, int, _ct.Long]
Expand Down Expand Up @@ -244,4 +242,4 @@ def __getitem__(self, i: int, /) -> T_co | _PyArray[T_co]: ...
AnyObjectArray: _Type = _Any4[ND, np.object_, object, _ct.Object]

# generic :> {StringDType.type}
# TODO
AnyStringArray: _Type = np.ndarray[ND, _x.StringDType] # type: ignore[type-var]
17 changes: 9 additions & 8 deletions optype/numpy/_any_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import sys
from typing import Any, Final, Literal as _Lit, TypeAlias as _Type
from typing import Any, Literal as _Lit, TypeAlias as _Type

import numpy as np
import numpy.typing as npt
Expand All @@ -15,6 +15,8 @@
import optype.numpy._ctype as _ct
import optype.numpy._dtype as _dt

from ._compat import StringDType as AnyStringDType


if sys.version_info >= (3, 13):
from typing import TypeVar
Expand Down Expand Up @@ -90,8 +92,7 @@
'AnyTimeDelta64DType',
'AnyObjectDType',
]

_NP_V2: Final = np.__version__.startswith('2.')
__all__ += ['AnyStringDType']


# helper aliases
Expand Down Expand Up @@ -134,7 +135,7 @@
AnyUInt64DType: _Type = _Any2[np.uint64, _s.AnyUInt64] | _UInt64Code

# uintp (assuming that `uint_ptr_t == size_t`, as done in `numpy.typing`)
if _NP_V2:
if _x.NP2:
_UIntPName: _Type = _Lit['uintp', 'uint']
_UIntPChar: _Type = _Lit['N', '|N', '=N', '<N', '>N']
_UIntPCode: _Type = _UIntPName | _UIntPChar
Expand Down Expand Up @@ -171,7 +172,7 @@

# ulong (uint if numpy<2)
_ULongChar: _Type = _Lit['L', '|L', '=L', '<L', '>L']
if _NP_V2:
if _x.NP2:
_ULongName: _Type = _Lit['ulong']
_ULongCode: _Type = _ULongName | _ULongChar
AnyULongDType: _Type = _Any2[_x.ULong, _s.AnyULong] | _ULongCode
Expand Down Expand Up @@ -220,7 +221,7 @@

# intp
# (`AnyIntPDType` must be inside each block, for valid typing)
if _NP_V2:
if _x.NP2:
_IntPName: _Type = _Lit['intp', 'int', 'int_']
_IntPChar: _Type = _Lit['n', '|n', '=n', '<n', '>n']
_IntPCode: _Type = _IntPName | _IntPChar
Expand Down Expand Up @@ -251,7 +252,7 @@

# long (or int_ if numpy<2)
_LongChar: _Type = _Lit['l', '|l', '=l', '<l', '>l']
if _NP_V2:
if _x.NP2:
_LongName: _Type = _Lit['long']
_LongCode: _Type = _LongName | _LongChar
AnyLongDType: _Type = _Any2[_x.Long, _s.AnyLong] | _LongCode
Expand Down Expand Up @@ -552,7 +553,7 @@
) # fmt: skip

# this duplicated mess is needed for valid types and numpy 1/2 compat
if _NP_V2:
if _x.NP2:
_UnsignedIntegerChar: _Type = (
_UnsignedIntegerCharCommon
| _Lit['N', '|N', '=N', '<N', '>N'] # numpy>=2 only
Expand Down
9 changes: 3 additions & 6 deletions optype/numpy/_any_scalar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import datetime as dt
from typing import Any, Final, TypeAlias as _Type
from typing import Any, TypeAlias as _Type

import numpy as np

Expand Down Expand Up @@ -62,9 +62,6 @@
)


_NP_V2: Final[bool] = np.__version__.startswith('2.')


# integer - unsigned
AnyUInt8: _Type = np.uint8 | _ct.UInt8
AnyUInt16: _Type = np.uint16 | _ct.UInt16
Expand All @@ -84,14 +81,14 @@
AnyInt16: _Type = np.int16 | _ct.Int16
AnyInt32: _Type = np.int32 | _ct.Int32
AnyInt64: _Type = np.int64 | _ct.Int64
if _NP_V2:
if _x.NP2:
AnyIntP: _Type = int | np.intp | _ct.IntP
else:
AnyIntP: _Type = np.intp | _ct.IntP
AnyByte: _Type = np.byte | _ct.Byte
AnyShort: _Type = np.short | _ct.Short
AnyIntC: _Type = np.intc | _ct.IntC
if _NP_V2:
if _x.NP2:
AnyLong: _Type = _x.Long | _ct.Long
else:
AnyLong: _Type = int | _x.Long | _ct.Long
Expand Down
13 changes: 4 additions & 9 deletions optype/numpy/_array.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Final, TypeAlias
from typing import TYPE_CHECKING, Any, TypeAlias

import numpy as np

import optype.numpy._compat as _x


if sys.version_info >= (3, 13):
from typing import (
Expand Down Expand Up @@ -37,13 +39,6 @@
]


_NP_VERSION: Final = np.__version__
_NP_V2: Final = _NP_VERSION.startswith('2.')

if not _NP_V2:
assert _NP_VERSION.startswith('1.'), f'numpy {_NP_VERSION} is unsupported'


_AnyShape: TypeAlias = tuple[int, ...]

_ShapeT = TypeVar('_ShapeT', bound=_AnyShape, default=_AnyShape)
Expand Down Expand Up @@ -93,7 +88,7 @@ def __array_finalize__(self, obj: _T_contra, /) -> None: ...

@runtime_checkable
class CanArrayWrap(Protocol):
if _NP_V2:
if _x.NP2:
def __array_wrap__(
self,
array: np.ndarray[_ShapeT, _DT],
Expand Down
26 changes: 23 additions & 3 deletions optype/numpy/_compat.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
from __future__ import annotations

import sys
from typing import Final, TypeAlias

import numpy as np


__all__ = ['Bool', 'Long', 'ULong']
__all__ = [
'NP2',
'NP20',
'NP21',
'Bool',
'Long',
'StringDType',
'ULong',
]


_NP_V2: Final[bool] = np.__version__.startswith('2.')
NP2: Final[bool] = np.__version__.startswith('2.')
NP20: Final[bool] = np.__version__.startswith('2.0')
NP21: Final[bool] = np.__version__.startswith('2.1')


if _NP_V2:
if NP2:
Bool: TypeAlias = np.bool
ULong: TypeAlias = np.ulong
Long: TypeAlias = np.long
else:
Bool: TypeAlias = np.bool_
ULong: TypeAlias = np.uint
Long: TypeAlias = np.int_

if NP21:
StringDType: TypeAlias = np.dtypes.StringDType
elif NP2:
StringDType: TypeAlias = np.dtype[str] # type: ignore[type-var] # pyright: ignore[reportInvalidTypeArguments]
elif sys.version_info >= (3, 13):
from typing import Never as StringDType
else:
from typing_extensions import Never as StringDType
9 changes: 4 additions & 5 deletions optype/numpy/_ufunc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Final, Literal, TypeAlias as _Type
from typing import TYPE_CHECKING, Any, Literal, TypeAlias as _Type

import numpy as np

import optype.numpy._compat as _x


if sys.version_info >= (3, 13):
from typing import Protocol, TypeVar, runtime_checkable
Expand All @@ -22,9 +24,6 @@
__all__ = ['CanArrayFunction', 'CanArrayUFunc', 'UFunc']


_NP_V2: Final[bool] = np.__version__.startswith('2.')


_FT_co = TypeVar(
'_FT_co',
bound='CanCall[..., Any]',
Expand Down Expand Up @@ -139,7 +138,7 @@ def outer(self, /) -> CanCall[..., Any] | None: ...
'accumulate',
'outer',
]
if _NP_V2:
if _x.NP2:
_UFuncMethod: _Type = _UFuncMethodCommon | Literal['at']
else:
_UFuncMethod: _Type = _UFuncMethodCommon | Literal['inner']
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ venv = ".venv"
pythonVersion = "3.10"
pythonPlatform = "All"
typeCheckingMode = "all"
defineConstant = {"_NP_V2" = true, "byteorder" = "little"}
defineConstant = {NP2 = true, NP20 = true, NP21 = false}

reportAny = false # blame typeshed
reportUnusedCallResult = false # https://github.com/microsoft/pyright/issues/8650
Expand All @@ -89,7 +89,8 @@ reportUnusedVariable = false # dupe of F841

[tool.mypy]
python_version = "3.10"
always_true = "_NP_V2"
always_true = "NP2,NP20"
always_false = "NP21"
modules = ["optype"]
exclude = ["^.venv/.*", "^examples/.*", "^tests/.*"]
allow_redefinition = true
Expand Down

0 comments on commit de361da

Please sign in to comment.