diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 20bcd5a..757a647 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -41,8 +41,8 @@ jobs: name: artifact path: dist - name: Publish package on TestPyPi - uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 + uses: pypa/gh-action-pypi-publish@8a08d616893759ef8e1aa1f2785787c0b97e20d6 with: repository-url: https://test.pypi.org/legacy/ - name: Publish package on PyPi - uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 + uses: pypa/gh-action-pypi-publish@8a08d616893759ef8e1aa1f2785787c0b97e20d6 diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 7f208b4..e12e1a2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,9 +7,23 @@ Changelog ========= -0.9.0 (unreleased) + +0.9.0 (2024-08-30) ------------------ +**New features** + +- User-defined data types can now define how arrays with that dtype are constructed by implementing the ``make_array`` function. +- User-defined data types can now define how they are indexed (via ``__getitem__``) by implementing the ``getitem`` function. +- :class:`ndonnx.NullableCore` is now public, encapsulating the nullable variants of the ``CoreType`` data types exported by ndonnx. + +**Bug fixes** + +- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays. +- :func:`ndonnx.cumulative_sum` now correctly applies the ``include_initial`` parameter and works around missing onnxruntime kernels for unsigned integral types. +- :func:`ndonnx.additional.make_nullable` applies broadcasting to the provided null array (instead of reshaping it, as it did previously). This allows writing ``make_nullable(x, False)`` to turn an array into a nullable one. +- User-defined data types that implement :class:`ndonnx._core.UniformShapeOperations` may now implement :func:`ndonnx.where` without requiring that both data types be promotable. + **Breaking change** - Iterating over dynamic dimensions of :class:`~ndonnx.Array` is no longer allowed since it commonly lead to infinite loops when used without an explicit break condition. diff --git a/ndonnx/__init__.py b/ndonnx/__init__.py index 358b4b3..9587f99 100644 --- a/ndonnx/__init__.py +++ b/ndonnx/__init__.py @@ -14,6 +14,7 @@ Floating, Integral, Nullable, + NullableCore, NullableFloating, NullableIntegral, NullableNumerical, @@ -323,6 +324,7 @@ "Floating", "NullableIntegral", "Nullable", + "NullableCore", "Integral", "CoreType", "CastError", diff --git a/ndonnx/_array.py b/ndonnx/_array.py index 1a521dd..c06bf66 100644 --- a/ndonnx/_array.py +++ b/ndonnx/_array.py @@ -15,6 +15,7 @@ import ndonnx as ndx import ndonnx._data_types as dtypes from ndonnx.additional import shape +from ndonnx.additional._additional import _getitem as getitem from ndonnx.additional._additional import _static_shape as static_shape from ._corearray import _CoreArray @@ -47,7 +48,11 @@ def array( out : Array The new array. This represents an ONNX model input.
""" - return Array._construct(shape=shape, dtype=dtype) + if (out := dtype._ops.make_array(shape, dtype)) is not NotImplemented: + return out + raise ndx.UnsupportedOperationError( + f"No implementation of `make_array` for {dtype}" + ) def from_spox_var( @@ -154,17 +159,7 @@ def astype(self, to: CoreType | StructType) -> Array: return ndx.astype(self, to) def __getitem__(self, index: IndexType) -> Array: - if isinstance(index, Array) and not ( - isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool - ): - raise TypeError( - f"Index must be an integral or boolean 'Array', not `{index.dtype}`" - ) - - if isinstance(index, Array): - index = index._core() - - return self._transmute(lambda corearray: corearray[index]) + return getitem(self, index) def __setitem__( self, index: IndexType | Self, updates: int | bool | float | Array @@ -517,7 +512,7 @@ def size(self) -> ndx.Array: out: Array Scalar ``Array`` instance whose value is the number of elements in the original array. """ - return ndx.prod(self.shape) + return ndx.prod(shape(self)) @property def T(self) -> ndx.Array: # noqa: N802 diff --git a/ndonnx/_core/_boolimpl.py b/ndonnx/_core/_boolimpl.py index 9d7b946..fe56305 100644 --- a/ndonnx/_core/_boolimpl.py +++ b/ndonnx/_core/_boolimpl.py @@ -13,15 +13,16 @@ import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx +from ._coreimpl import CoreOperationsImpl +from ._interface import OperationsBlock from ._nullableimpl import NullableOperationsImpl -from ._shapeimpl import UniformShapeOperations from ._utils import binary_op, unary_op, validate_core if TYPE_CHECKING: from ndonnx import Array -class BooleanOperationsImpl(UniformShapeOperations): +class _BooleanOperationsImpl(OperationsBlock): @validate_core def equal(self, x, y) -> Array: return binary_op(x, y, opx.equal) @@ -99,7 +100,7 @@ def can_cast(self, from_, to) -> bool: @validate_core def all(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) @@ -109,7 +110,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): @validate_core def any(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) @@ -162,17 +163,8 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array: def nonzero(self, x) -> tuple[Array, ...]: return ndx.nonzero(x.astype(ndx.int8)) - @validate_core - def make_nullable(self, x, null): - if null.dtype != dtypes.bool: - raise TypeError("'null' must be a boolean array") - return ndx.Array._from_fields( - dtypes.into_nullable(x.dtype), - values=x.copy(), - null=ndx.reshape(null, x.shape), - ) +class BooleanOperationsImpl(CoreOperationsImpl, _BooleanOperationsImpl): ... -class NullableBooleanOperationsImpl(BooleanOperationsImpl, NullableOperationsImpl): - def make_nullable(self, x, null): - return NotImplemented + +class NullableBooleanOperationsImpl(NullableOperationsImpl, _BooleanOperationsImpl): ... 
diff --git a/ndonnx/_core/_coreimpl.py b/ndonnx/_core/_coreimpl.py new file mode 100644 index 0000000..5515425 --- /dev/null +++ b/ndonnx/_core/_coreimpl.py @@ -0,0 +1,58 @@ +# Copyright (c) QuantCo 2023-2024 +# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from spox import Tensor, argument + +import ndonnx as ndx +import ndonnx._data_types as dtypes +import ndonnx.additional as nda +from ndonnx._corearray import _CoreArray + +from ._shapeimpl import UniformShapeOperations +from ._utils import validate_core + +if TYPE_CHECKING: + from ndonnx._array import Array + from ndonnx._data_types import Dtype + + +class CoreOperationsImpl(UniformShapeOperations): + def make_array( + self, + shape: tuple[int | None | str, ...], + dtype: Dtype, + eager_value: np.ndarray | None = None, + ) -> Array: + if not isinstance(dtype, dtypes.CoreType): + return NotImplemented + return ndx.Array._from_fields( + dtype, + data=_CoreArray( + dtype._parse_input(eager_value)["data"] + if eager_value is not None + else argument(Tensor(dtype.to_numpy_dtype(), shape)) + ), + ) + + @validate_core + def make_nullable(self, x: Array, null: Array) -> Array: + if null.dtype != ndx.bool: + raise TypeError("'null' must be a boolean array") + + return ndx.Array._from_fields( + dtypes.into_nullable(x.dtype), + values=x.copy(), + null=ndx.broadcast_to(null, nda.shape(x)), + ) + + @validate_core + def where(self, condition, x, y): + if x.dtype != y.dtype: + target_dtype = ndx.result_type(x, y) + x = ndx.astype(x, target_dtype) + y = ndx.astype(y, target_dtype) + return super().where(condition, x, y) diff --git a/ndonnx/_core/_interface.py b/ndonnx/_core/_interface.py index 37fca9a..5340f4f 100644 --- a/ndonnx/_core/_interface.py +++ b/ndonnx/_core/_interface.py @@ -3,11 +3,17 @@ from __future__ import annotations -from typing import Literal +from typing import TYPE_CHECKING, Literal + +import numpy as np import ndonnx as ndx import ndonnx._data_types as dtypes +if TYPE_CHECKING: + from ndonnx._array import IndexType + from ndonnx._data_types import Dtype + class OperationsBlock: """Interface for data types to implement top-level functions exported by ndonnx.""" @@ -251,7 +257,7 @@ def cumulative_sum( x, *, axis: int | None = None, - dtype: ndx.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, include_initial: bool = False, ): return NotImplemented @@ -270,7 +276,7 @@ def prod( x, *, axis=None, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, keepdims: bool = False, ) -> ndx.Array: return NotImplemented @@ -293,7 +299,7 @@ def sum( x, *, axis=None, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, keepdims: bool = False, ) -> ndx.Array: return NotImplemented @@ -305,7 +311,7 @@ def var( axis=None, keepdims: bool = False, correction=0.0, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, ) -> ndx.Array: return NotImplemented @@ -352,7 +358,7 @@ def full_like(self, x, fill_value, dtype=None, device=None) -> ndx.Array: def ones( self, shape, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, device=None, ): return NotImplemented @@ -365,14 +371,12 @@ def ones_like( def zeros( self, shape, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, device=None, ): return NotImplemented - def zeros_like( - self, x, dtype: dtypes.CoreType | 
dtypes.StructType | None = None, device=None - ): + def zeros_like(self, x, dtype: Dtype | None = None, device=None): return NotImplemented def empty(self, shape, dtype=None, device=None) -> ndx.Array: @@ -413,3 +417,18 @@ def can_cast(self, from_, to) -> bool: def static_shape(self, x) -> tuple[int | None, ...]: return NotImplemented + + def make_array( + self, + shape: tuple[int | None | str, ...], + dtype: Dtype, + eager_value: np.ndarray | None = None, + ) -> ndx.Array: + return NotImplemented + + def getitem( + self, + x: ndx.Array, + index: IndexType, + ) -> ndx.Array: + return NotImplemented diff --git a/ndonnx/_core/_nullableimpl.py b/ndonnx/_core/_nullableimpl.py index 71115fc..835f4fd 100644 --- a/ndonnx/_core/_nullableimpl.py +++ b/ndonnx/_core/_nullableimpl.py @@ -1,16 +1,37 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +from typing import TYPE_CHECKING, Union import ndonnx as ndx -from ._interface import OperationsBlock +from ._shapeimpl import UniformShapeOperations from ._utils import validate_core +if TYPE_CHECKING: + from ndonnx._array import Array + from ndonnx._data_types import CoreType, StructType + + Dtype = Union[CoreType, StructType] -class NullableOperationsImpl(OperationsBlock): + +class NullableOperationsImpl(UniformShapeOperations): @validate_core - def fill_null(self, x, value): + def fill_null(self, x: Array, value) -> Array: value = ndx.asarray(value) if value.dtype != x.values.dtype: value = value.astype(x.values.dtype) return ndx.where(x.null, value, x.values) + + @validate_core + def make_nullable(self, x: Array, null: Array) -> Array: + return NotImplemented + + @validate_core + def where(self, condition, x, y): + if x.dtype != y.dtype: + target_dtype = ndx.result_type(x, y) + x = ndx.astype(x, target_dtype) + y = ndx.astype(y, target_dtype) + return super().where(condition, x, y) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index 0b8641f..a3c92e7 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -16,10 +16,12 @@ import ndonnx as ndx import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx +import ndonnx.additional as nda from ndonnx._utility import promote +from ._coreimpl import CoreOperationsImpl +from ._interface import OperationsBlock from ._nullableimpl import NullableOperationsImpl -from ._shapeimpl import UniformShapeOperations from ._utils import ( binary_op, from_corearray, @@ -35,7 +37,7 @@ from ndonnx._corearray import _CoreArray -class NumericOperationsImpl(UniformShapeOperations): +class _NumericOperationsImpl(OperationsBlock): # elementwise.py @validate_core @@ -198,7 +200,7 @@ def isfinite(self, x): def isinf(self, x): if isinstance(x.dtype, (dtypes.Floating, dtypes.NullableFloating)): return unary_op(x, opx.isinf) - return ndx.full(x.shape, fill_value=False) + return ndx.full(nda.shape(x), fill_value=False) @validate_core def isnan(self, x): @@ -400,9 +402,7 @@ def nonzero(self, x) -> tuple[Array, ...]: return (ndx.arange(0, x != 0, dtype=dtypes.int64),) ret_full_flattened = ndx.reshape( - from_corearray( - opx.ndindex(ndx.asarray(x.shape, dtype=dtypes.int64)._core()) - )[x != 0], + from_corearray(opx.ndindex(nda.shape(x)._core()))[x != 0], [-1], ) @@ -413,7 +413,7 @@ def nonzero(self, x) -> tuple[Array, ...]: ret_full_flattened._core(), ndx.arange( i, - ret_full_flattened.shape[0], + nda.shape(ret_full_flattened)[0], x.ndim, dtype=dtypes.int64, )._core(), @@ -454,20 +454,21 @@ def searchsorted( 
from_corearray, opx.get_indices(x1._core(), x2._core(), positions._core()) ) - how_many = ndx.zeros(ndx.asarray(combined.shape) + 1, dtype=dtypes.int64) + combined_shape = nda.shape(combined) + how_many = ndx.zeros(combined_shape + 1, dtype=dtypes.int64) how_many[ - ndx.where(indices_x1 + 1 <= combined.shape[0], indices_x1 + 1, indices_x1) + ndx.where(indices_x1 + 1 <= combined_shape[0], indices_x1 + 1, indices_x1) ] = counts - how_many = ndx.cumulative_sum(how_many, include_initial=True) + how_many = ndx.cumulative_sum(how_many, include_initial=False, axis=None) - ret = ndx.zeros(x2.shape, dtype=dtypes.int64) + ret = ndx.zeros(nda.shape(x2), dtype=dtypes.int64) if side == "left": ret = how_many[indices_x2] - ret[nan_mask] = ndx.asarray(x1.shape, dtype=dtypes.int64) - 1 + ret[nan_mask] = nda.shape(x1) - 1 else: ret = how_many[indices_x2 + 1] - ret[nan_mask] = ndx.asarray(x1.shape, dtype=dtypes.int64) + ret[nan_mask] = nda.shape(x1) return ret @@ -494,7 +495,7 @@ def unique_all(self, x): # FIXME: I think we can simply use arange/ones+cumsum or something for the indices # maybe: indices = opx.cumsum(ones_like(flattened, dtype=dtypes.i64), axis=ndx.asarray(0)) indices = opx.squeeze( - opx.ndindex(ndx.asarray(flattened.shape, dtype=dtypes.int64)._core()), + opx.ndindex(nda.shape(flattened)._core()), opx.const([1], dtype=dtypes.int64), ) @@ -502,7 +503,7 @@ def unique_all(self, x): values = from_corearray(ret_opd[0]) indices = from_corearray(indices[ret_opd[1]]) - inverse_indices = ndx.reshape(from_corearray(ret_opd[2]), x.shape) + inverse_indices = ndx.reshape(from_corearray(ret_opd[2]), nda.shape(x)) counts = from_corearray(ret_opd[3]) return ret( @@ -535,7 +536,7 @@ def argsort(self, x, *, axis=-1, descending=False, stable=True): if axis < 0: axis += x.ndim - _len = ndx.asarray(x.shape[axis : axis + 1], dtype=dtypes.int64)._core() + _len = ndx.asarray(nda.shape(x)[axis : axis + 1], dtype=dtypes.int64)._core() return _via_i64_f64( lambda x: opx.top_k(x, _len, largest=descending, axis=axis)[1], [x] ) @@ -544,7 +545,7 @@ def argsort(self, x, *, axis=-1, descending=False, stable=True): def sort(self, x, *, axis=-1, descending=False, stable=True): if axis < 0: axis += x.ndim - _len = ndx.asarray(x.shape[axis : axis + 1], dtype=dtypes.int64)._core() + _len = ndx.asarray(nda.shape(x)[axis : axis + 1], dtype=dtypes.int64)._core() return _via_i64_f64( lambda x: opx.top_k(x, _len, largest=descending, axis=axis)[0], [x] ) @@ -566,13 +567,47 @@ def cumulative_sum( axis = 0 else: raise ValueError("axis must be specified for multi-dimensional arrays") + + if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)): + if ndx.iinfo(x.dtype).bits < 64: + out = x.astype(dtypes.int64) + else: + return NotImplemented + elif dtype == dtypes.uint64 or dtype == dtypes.nuint64: + raise ndx.UnsupportedOperationError( + f"Unsupported dtype parameter for cumulative_sum {dtype} due to missing kernel support" + ) + else: + out = x.astype(_determine_reduce_op_dtype(x, None, dtypes.uint64)) + out = from_corearray( opx.cumsum( - x._core(), axis=opx.const(axis), exclusive=int(not include_initial) + out._core(), + axis=opx.const(axis), + exclusive=0, ) ) - if dtype is not None: + + if dtype is None: + if isinstance(x.dtype, dtypes.Unsigned): + out = out.astype(ndx.uint64) + elif isinstance(x.dtype, dtypes.NullableUnsigned): + out = out.astype(ndx.nuint64) + else: out = out.astype(dtype) + + # Exclude axis and create zeros of that shape + if include_initial: + out_shape = nda.shape(out) + out_shape[axis] = 1 + out = 
ndx.concat( + [ + ndx.zeros(out_shape, dtype=out.dtype), + out, + ], + axis=axis, + ) + return out @validate_core @@ -707,7 +742,7 @@ def clip( and isinstance(x.dtype, dtypes.Numerical) ): x, min, max = promote(x, min, max) - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): out_null = x.null x_values = x.values._core() clipped = from_corearray(opx.clip(x_values, min._core(), max._core())) @@ -805,17 +840,6 @@ def var( - correction ) - @validate_core - def make_nullable(self, x, null): - if null.dtype != dtypes.bool: - raise TypeError("'null' must be a boolean array") - - return ndx.Array._from_fields( - dtypes.into_nullable(x.dtype), - values=x.copy(), - null=ndx.reshape(null, x.shape), - ) - @validate_core def can_cast(self, from_, to) -> bool: if isinstance(from_, dtypes.CoreType) and isinstance(to, ndx.CoreType): @@ -824,7 +848,7 @@ def can_cast(self, from_, to) -> bool: @validate_core def all(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) @@ -834,7 +858,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): @validate_core def any(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) @@ -866,7 +890,7 @@ def arange(self, start, stop=None, step=None, dtype=None, device=None) -> ndx.Ar @validate_core def tril(self, x, k=0) -> ndx.Array: - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): # NumPy appears to just ignore the mask so we do the same x = x.values return x._transmute( @@ -877,7 +901,7 @@ def tril(self, x, k=0) -> ndx.Array: @validate_core def triu(self, x, k=0) -> ndx.Array: - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): # NumPy appears to just ignore the mask so we do the same x = x.values return x._transmute( @@ -948,9 +972,10 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array: return ndx.full_like(x, 0, dtype=dtype) -class NullableNumericOperationsImpl(NumericOperationsImpl, NullableOperationsImpl): - def make_nullable(self, x, null): - return NotImplemented +class NumericOperationsImpl(CoreOperationsImpl, _NumericOperationsImpl): ... + + +class NullableNumericOperationsImpl(NullableOperationsImpl, _NumericOperationsImpl): ... 
def _via_i64_f64( diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index 8674d6d..67da73e 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -1,7 +1,10 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING import numpy as np @@ -13,6 +16,10 @@ from ._interface import OperationsBlock from ._utils import from_corearray +if TYPE_CHECKING: + from ndonnx._array import Array, IndexType + from ndonnx._data_types import Dtype + class UniformShapeOperations(OperationsBlock): """Provides implementation for shape/indexing operations that are generic across all @@ -99,7 +106,7 @@ def roll(self, x, shift, axis): if not isinstance(shift, Sequence): shift = [shift] - old_shape = x.shape + old_shape = nda.shape(x) if axis is None: x = ndx.reshape(x, [-1]) @@ -112,9 +119,7 @@ def roll(self, x, shift, axis): raise ValueError("shift and axis must have the same length") for sh, ax in zip(shift, axis): - len_single = opx.gather( - ndx.asarray(x.shape, dtype=dtypes.int64)._core(), opx.const(ax) - ) + len_single = opx.gather(nda.shape(x)._core(), opx.const(ax)) shift_single = opx.add(opx.const(-sh, dtype=dtypes.int64), len_single) # Find the needed element index and then gather from it range = opx.cast( @@ -157,9 +162,7 @@ def full_like(self, x, fill_value, dtype=None, device=None): def where(self, condition, x, y): if x.dtype != y.dtype: - target_dtype = ndx.result_type(x, y) - x = ndx.astype(x, target_dtype) - y = ndx.astype(y, target_dtype) + return NotImplemented if isinstance(condition.dtype, dtypes.Nullable) and not isinstance( x.dtype, (dtypes.Nullable, dtypes.CoreType) ): @@ -246,7 +249,58 @@ def where_dtype_agnostic(a: ndx.Array, b: ndx.Array): return output def zeros_like(self, x, dtype=None, device=None): - return ndx.zeros(x.shape, dtype=dtype or x.dtype, device=device) + return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device) def ones_like(self, x, dtype=None, device=None): return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device) + + def make_array( + self, + shape: tuple[int | None | str, ...], + dtype: Dtype, + eager_value: np.ndarray | None = None, + ) -> Array: + if isinstance(dtype, dtypes.CoreType): + return NotImplemented + + fields: dict[str, ndx.Array] = {} + + eager_values = None if eager_value is None else dtype._parse_input(eager_value) + for name, field_dtype in dtype._fields().items(): + if eager_values is None: + field_value = None + else: + field_value = _assemble_output_recurse(field_dtype, eager_values[name]) + fields[name] = field_dtype._ops.make_array( + shape, + field_dtype, + field_value, + ) + return ndx.Array._from_fields( + dtype, + **fields, + ) + + def getitem(self, x: Array, index: IndexType) -> Array: + if isinstance(index, ndx.Array) and not ( + isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool + ): + raise TypeError( + f"Index must be an integral or boolean 'Array', not `{index.dtype}`" + ) + + if isinstance(index, ndx.Array): + index = index._core() + + return x._transmute(lambda corearray: corearray[index]) + + +def _assemble_output_recurse(dtype: Dtype, values: dict) -> np.ndarray: + if isinstance(dtype, dtypes.CoreType): + return dtype._assemble_output(values) + else: + fields = { + name: _assemble_output_recurse(field_dtype, values[name]) + for name, field_dtype in dtype._fields().items() + } + return dtype._assemble_output(fields) diff 
--git a/ndonnx/_core/_stringimpl.py b/ndonnx/_core/_stringimpl.py index 0218a4a..d414782 100644 --- a/ndonnx/_core/_stringimpl.py +++ b/ndonnx/_core/_stringimpl.py @@ -11,15 +11,16 @@ import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx +from ._coreimpl import CoreOperationsImpl +from ._interface import OperationsBlock from ._nullableimpl import NullableOperationsImpl -from ._shapeimpl import UniformShapeOperations from ._utils import binary_op, validate_core if TYPE_CHECKING: from ndonnx import Array -class StringOperationsImpl(UniformShapeOperations): +class _StringOperationsImpl(OperationsBlock): @validate_core def add(self, x, y) -> Array: return binary_op(x, y, opx.string_concat) @@ -52,7 +53,7 @@ def zeros_like( self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None ): if dtype is not None and not isinstance( - dtype, (dtypes.CoreType, dtypes._NullableCore) + dtype, (dtypes.CoreType, dtypes.NullableCore) ): raise TypeError("'dtype' must be a CoreType or NullableCoreType") if dtype in (None, dtypes.utf8, dtypes.nutf8): @@ -68,18 +69,8 @@ def empty(self, shape, dtype=None, device=None) -> ndx.Array: def empty_like(self, x, dtype=None, device=None) -> ndx.Array: return ndx.zeros_like(x, dtype=dtype, device=device) - @validate_core - def make_nullable(self, x, null): - if null.dtype != dtypes.bool: - raise TypeError("'null' must be a boolean array") - return ndx.Array._from_fields( - dtypes.into_nullable(x.dtype), - values=x.copy(), - null=ndx.reshape(null, x.shape), - ) +class StringOperationsImpl(CoreOperationsImpl, _StringOperationsImpl): ... -class NullableStringOperationsImpl(StringOperationsImpl, NullableOperationsImpl): - def make_nullable(self, x, null): - return NotImplemented +class NullableStringOperationsImpl(NullableOperationsImpl, _StringOperationsImpl): ... 
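With ``make_nullable`` now implemented once in ``CoreOperationsImpl`` (see ``_coreimpl.py`` above), the null mask is broadcast to the array's shape instead of reshaped. A short sketch of the resulting behaviour, assuming this patch is applied (values are illustrative):

    import ndonnx as ndx
    import ndonnx.additional as nda

    a = ndx.asarray([1, 2, 3], dtype=ndx.int64)
    # A scalar mask is broadcast to a's shape, producing a nullable array
    # whose null field is all False.
    masked = nda.make_nullable(a, ndx.asarray(False))
    print(masked.to_numpy())  # masked array with data [1, 2, 3] and an all-False mask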
diff --git a/ndonnx/_core/_utils.py b/ndonnx/_core/_utils.py index ec84943..29e1b74 100644 --- a/ndonnx/_core/_utils.py +++ b/ndonnx/_core/_utils.py @@ -38,7 +38,7 @@ def variadic_op( ): args = promote(*args) out_dtype = args[0].dtype - if not isinstance(out_dtype, (dtypes.CoreType, dtypes._NullableCore)): + if not isinstance(out_dtype, (dtypes.CoreType, dtypes.NullableCore)): raise TypeError( f"Expected ndx.Array with CoreType or NullableCoreType, got {args[0].dtype}" ) @@ -100,7 +100,7 @@ def _via_dtype( promoted = promote(*arrays) out_dtype = promoted[0].dtype - if isinstance(out_dtype, dtypes._NullableCore) and out_dtype.values == dtype: + if isinstance(out_dtype, dtypes.NullableCore) and out_dtype.values == dtype: dtype = out_dtype values, nulls = split_nulls_and_values( @@ -203,7 +203,7 @@ def validate_core(func): def wrapper(*args, **kwargs): for arg in itertools.chain(args, kwargs.values()): if isinstance(arg, ndx.Array) and not isinstance( - arg.dtype, (dtypes.CoreType, dtypes._NullableCore) + arg.dtype, (dtypes.CoreType, dtypes.NullableCore) ): return NotImplemented return func(*args, **kwargs) diff --git a/ndonnx/_data_types/__init__.py b/ndonnx/_data_types/__init__.py index 12d580d..392abe0 100644 --- a/ndonnx/_data_types/__init__.py +++ b/ndonnx/_data_types/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations from ndonnx._utility import deprecated - +from typing import Union from .aliases import ( bool, float32, @@ -40,7 +40,7 @@ NullableUnsigned, Numerical, Unsigned, - _NullableCore, + NullableCore, from_numpy_dtype, get_finfo, get_iinfo, @@ -51,7 +51,7 @@ from .structtype import StructType -def into_nullable(dtype: StructType | CoreType) -> _NullableCore: +def into_nullable(dtype: StructType | CoreType) -> NullableCore: """Return nullable counterpart, if present. Parameters @@ -61,7 +61,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: Returns ------- - out : _NullableCore + out : NullableCore The nullable counterpart of the input type. Raises @@ -93,24 +93,27 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: return nuint64 elif dtype == utf8: return nutf8 - elif isinstance(dtype, _NullableCore): + elif isinstance(dtype, NullableCore): return dtype else: raise ValueError(f"Cannot promote {dtype} to nullable") +Dtype = Union[CoreType, StructType] + + @deprecated( "Function 'ndonnx.promote_nullable' will be deprecated in ndonnx 0.7. " "To create nullable array, use 'ndonnx.additional.make_nullable' instead." 
) -def promote_nullable(dtype: StructType | CoreType) -> _NullableCore: +def promote_nullable(dtype: StructType | CoreType) -> NullableCore: return into_nullable(dtype) __all__ = [ "CoreType", "StructType", - "_NullableCore", + "NullableCore", "NullableFloating", "NullableIntegral", "NullableUnsigned", @@ -151,4 +154,5 @@ def promote_nullable(dtype: StructType | CoreType) -> _NullableCore: "Schema", "CastMixin", "CastError", + "Dtype", ] diff --git a/ndonnx/_data_types/classes.py b/ndonnx/_data_types/classes.py index 661ef20..c8acd7e 100644 --- a/ndonnx/_data_types/classes.py +++ b/ndonnx/_data_types/classes.py @@ -189,7 +189,7 @@ def _fields(self) -> dict[str, StructType | CoreType]: } -class _NullableCore(Nullable[CoreType], CastMixin): +class NullableCore(Nullable[CoreType], CastMixin): def copy(self) -> Self: return self @@ -213,7 +213,7 @@ def _schema(self) -> Schema: return Schema(type_name=type(self).__name__, author="ndonnx") def _cast_to(self, array: Array, dtype: CoreType | StructType) -> Array: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): return ndx.Array._from_fields( dtype, values=self.values._cast_to(array.values, dtype.values), @@ -230,7 +230,7 @@ def _cast_from(self, array: Array) -> Array: values=self.values._cast_from(array), null=ndx.zeros_like(array, dtype=Boolean()), ) - elif isinstance(array.dtype, _NullableCore): + elif isinstance(array.dtype, NullableCore): return ndx.Array._from_fields( self, values=self.values._cast_from(array.values), @@ -240,7 +240,7 @@ def _cast_from(self, array: Array) -> Array: raise CastError(f"Cannot cast from {array.dtype} to {self}") -class NullableNumerical(_NullableCore): +class NullableNumerical(NullableCore): """Base class for nullable numerical data types.""" _ops: OperationsBlock = NullableNumericOperationsImpl() @@ -312,14 +312,14 @@ class NFloat64(NullableFloating): null = Boolean() -class NBoolean(_NullableCore): +class NBoolean(NullableCore): values = Boolean() null = Boolean() _ops: OperationsBlock = NullableBooleanOperationsImpl() -class NUtf8(_NullableCore): +class NUtf8(NullableCore): values = Utf8() null = Boolean() @@ -405,18 +405,18 @@ def _from_dtype(cls, dtype: CoreType) -> Finfo: ) -def get_finfo(dtype: _NullableCore | CoreType) -> Finfo: +def get_finfo(dtype: NullableCore | CoreType) -> Finfo: try: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): dtype = dtype.values return Finfo._from_dtype(dtype) except KeyError: raise TypeError(f"'{dtype}' is not a floating point data type.") -def get_iinfo(dtype: _NullableCore | CoreType) -> Iinfo: +def get_iinfo(dtype: NullableCore | CoreType) -> Iinfo: try: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): dtype = dtype.values return Iinfo._from_dtype(dtype) except KeyError: diff --git a/ndonnx/_data_types/coretype.py b/ndonnx/_data_types/coretype.py index c82ee72..34fcdf5 100644 --- a/ndonnx/_data_types/coretype.py +++ b/ndonnx/_data_types/coretype.py @@ -69,7 +69,11 @@ def _parse_input(self, data: np.ndarray) -> dict[str, np.ndarray]: def _cast_from(self, array: Array) -> Array: if isinstance(array.dtype, CoreType): - return ndx.Array._from_fields(self, data=opx.cast(array._core(), to=self)) + return ( + ndx.Array._from_fields(self, data=opx.cast(array._core(), to=self)) + if array.dtype != self + else array.copy() + ) else: raise CastError(f"Cannot cast from {array.dtype} to {self}") diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index 1088a28..d15dd16 100644 --- a/ndonnx/_funcs.py +++ 
b/ndonnx/_funcs.py @@ -11,8 +11,9 @@ import numpy.typing as npt import ndonnx._data_types as dtypes -from ndonnx._data_types import CastError, CastMixin, CoreType, _NullableCore +from ndonnx._data_types import CastError, CastMixin, CoreType, NullableCore from ndonnx._data_types.structtype import StructType +from ndonnx.additional import shape from . import _opset_extensions as opx from ._array import Array, _from_corearray @@ -60,20 +61,26 @@ def asarray( device=None, ) -> Array: if not isinstance(x, Array): - arr = np.asanyarray( + eager_value = np.asanyarray( x, dtype=( dtype.to_numpy_dtype() if isinstance(dtype, dtypes.CoreType) else None ), ) if dtype is None: - dtype = dtypes.from_numpy_dtype(arr.dtype) - if isinstance(arr, np.ma.masked_array): + dtype = dtypes.from_numpy_dtype(eager_value.dtype) + if isinstance(eager_value, np.ma.masked_array): dtype = dtypes.into_nullable(dtype) - ret = Array._construct( - shape=arr.shape, dtype=dtype, eager_values=dtype._parse_input(arr) + ret = dtype._ops.make_array( + shape=eager_value.shape, + dtype=dtype, + eager_value=eager_value, ) + if ret is NotImplemented: + raise UnsupportedOperationError( + f"Unsupported operand type for asarray: '{dtype}'" + ) else: ret = x.copy() if copy is True else x @@ -290,7 +297,7 @@ def result_type( np_dtypes = [] for dtype in observed_dtypes: if isinstance(dtype, dtypes.StructType): - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): nullable = True np_dtypes.append(dtype.values.to_numpy_dtype()) else: @@ -574,18 +581,22 @@ def numeric_like(x): ret = numeric_like(next(it)) while (x := next(it, None)) is not None: ret = ret + numeric_like(x) - target_shape = ret.shape + target_shape = shape(ret) return [broadcast_to(x, target_shape) for x in arrays] def broadcast_to(x, shape): - return x.dtype._ops.broadcast_to(x, shape) + if (out := x.dtype._ops.broadcast_to(x, shape)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for broadcast_to: '{x.dtype}'" + ) # TODO: onnxruntime doesn't work for 2 empty arrays of integer type # TODO: what is the appropriate strategy to dispatch? 
(iterate over the inputs and keep trying is reasonable but it can # change the outcome based on order if poorly implemented) -def concat(arrays, axis=None): +def concat(arrays, /, *, axis: int | None = 0): if axis is None: arrays = [reshape(x, [-1]) for x in arrays] axis = 0 @@ -598,27 +609,47 @@ def concat(arrays, axis=None): def expand_dims(x, axis=0): - return x.dtype._ops.expand_dims(x, axis) + if (out := x.dtype._ops.expand_dims(x, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for expand_dims: '{x.dtype}'" + ) def flip(x, axis=None): - return x.dtype._ops.flip(x, axis=axis) + if (out := x.dtype._ops.flip(x, axis=axis)) is not NotImplemented: + return out + raise UnsupportedOperationError(f"Unsupported operand type for flip: '{x.dtype}'") def permute_dims(x, axes): - return x.dtype._ops.permute_dims(x, axes) + if (out := x.dtype._ops.permute_dims(x, axes)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for permute_dims: '{x.dtype}'" + ) def reshape(x, shape, *, copy=None): - return x.dtype._ops.reshape(x, shape, copy=copy) + if (out := x.dtype._ops.reshape(x, shape, copy=copy)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for reshape: '{x.dtype}'" + ) def roll(x, shift, axis=None): - return x.dtype._ops.roll(x, shift, axis) + if (out := x.dtype._ops.roll(x, shift, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError(f"Unsupported operand type for roll: '{x.dtype}'") def squeeze(x, axis): - return x.dtype._ops.squeeze(x, axis) + if (out := x.dtype._ops.squeeze(x, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for squeeze: '{x.dtype}'" + ) def stack(arrays, axis=0): diff --git a/ndonnx/_opset_extensions.py b/ndonnx/_opset_extensions.py index 2cb8660..1ac10df 100644 --- a/ndonnx/_opset_extensions.py +++ b/ndonnx/_opset_extensions.py @@ -707,24 +707,31 @@ def reshape_like(x: _CoreArray, y: _CoreArray) -> _CoreArray: def static_map( input: _CoreArray, mapping: Mapping[KeyType, ValueType], default: ValueType | None ) -> _CoreArray: - keys = np.array(tuple(mapping.keys())) - if keys.dtype == np.int32: - keys = keys.astype(np.int64) - values = np.array(tuple(mapping.values())) + keys = np.asarray(tuple(mapping.keys())) + values = np.asarray(tuple(mapping.values())) + if isinstance(input.dtype, dtypes.Integral): + input = input.astype(dtypes.int64) + # Should only be a relevant path in Windows NumPy 1.x + if keys.dtype != np.dtype("int64") and keys.dtype.kind == "i": + keys = keys.astype(input.dtype.to_numpy_dtype()) + elif isinstance(input.dtype, dtypes.Floating): + input = input.astype(dtypes.float64) + value_dtype = values.dtype if default is None: if value_dtype.kind == "U" or ( value_dtype.kind == "O" and all(isinstance(x, str) for x in values.flat) ): - default_tensor = np.array(["MISSING"]) + default_tensor = np.asarray(["MISSING"]) else: - default_tensor = np.array([0], dtype=value_dtype) + default_tensor = np.asarray([0], dtype=value_dtype) + elif value_dtype.kind == "U": + default_tensor = np.asarray([default], dtype=np.str_) else: - default_tensor = np.array([default], dtype=value_dtype) - - if keys.dtype == np.float64 and isinstance(input.dtype, dtypes.Integral): + default_tensor = np.asarray([default], dtype=value_dtype) + if keys.dtype.kind == "f" and isinstance(input.dtype, dtypes.Integral): input = cast(input, 
dtypes.from_numpy_dtype(keys.dtype)) - elif keys.dtype == np.int64 and isinstance(input.dtype, dtypes.Floating): + elif keys.dtype.kind == "i" and isinstance(input.dtype, dtypes.Floating): keys = keys.astype(input.dtype.to_numpy_dtype()) return _CoreArray( ml.label_encoder( diff --git a/ndonnx/additional/_additional.py b/ndonnx/additional/_additional.py index 5ec6442..a9f1fc4 100644 --- a/ndonnx/additional/_additional.py +++ b/ndonnx/additional/_additional.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from ndonnx import Array + from ndonnx._array import IndexType Scalar = TypeVar("Scalar", int, float, str) @@ -88,7 +89,9 @@ def static_map( A new Array with the values mapped according to the mapping. """ if not isinstance(x.dtype, ndx.CoreType): - raise TypeError("static_map accepts only non-nullable arrays") + raise ndx.UnsupportedOperationError( + "'static_map' accepts only non-nullable arrays" + ) data = opx.static_map(x._core(), mapping, default) return ndx.Array._from_fields(data.dtype, data=data) @@ -147,6 +150,15 @@ def make_nullable(x: Array, null: Array) -> Array: return out +def _getitem(x: Array, index: IndexType) -> ndx.Array: + out = x.dtype._ops.getitem(x, index) + if out is NotImplemented: + raise ndx.UnsupportedOperationError( + f"'getitem' not implemented for `{x.dtype}`" + ) + return out + + def _static_shape(x: Array) -> tuple[int | None, ...]: """Return shape of the array as a tuple. Typical implementations will make use of ONNX shape inference, with `None` entries denoting unknown or symbolic dimensions. diff --git a/tests/test_additional.py b/tests/test_additional.py index d314dbc..869eabe 100644 --- a/tests/test_additional.py +++ b/tests/test_additional.py @@ -67,13 +67,12 @@ def test_searchsorted_raises(): @pytest.mark.skipif( - sys.platform.startswith("win"), + sys.platform.startswith("win") and np.__version__ < "2", reason="ORT 1.18 not registering LabelEncoder(4) only on Windows.", ) -def test_static_map(): +def test_static_map_lazy(): a = ndx.array(shape=(3,), dtype=ndx.int64) b = nda.static_map(a, {1: 2, 2: 3}) - model = ndx.build({"a": a}, {"b": b}) assert_array_equal([0, 2, 3], run(model, {"a": np.array([0, 1, 2])})["b"]) @@ -87,9 +86,72 @@ def test_static_map(): run(model, {"a": np.array([0.0, 2.0, 3.0, np.nan])})["b"], ) - a = ndx.asarray(["hello", "world", "!"]) - b = nda.static_map(a, {"hello": "hi", "world": "earth"}) - np.testing.assert_equal(["hi", "earth", "MISSING"], b.to_numpy()) + +@pytest.mark.skipif( + sys.platform.startswith("win") and np.__version__ < "2", + reason="ORT 1.18 not registering LabelEncoder(4) only on Windows.", +) +@pytest.mark.parametrize( + "x, mapping, default, expected", + [ + ( + ndx.asarray(["hello", "world", "!"]), + {"hello": "hi", "world": "earth"}, + None, + ["hi", "earth", "MISSING"], + ), + ( + ndx.asarray(["hello", "world", "!"]), + {"hello": "hi", "world": "earth"}, + "DIFFERENT", + ["hi", "earth", "DIFFERENT"], + ), + (ndx.asarray([0, 1, 2], dtype=ndx.int64), {0: -1, 1: -2}, None, [-1, -2, 0]), + (ndx.asarray([0, 1, 2], dtype=ndx.int64), {0: -1, 1: -2}, 42, [-1, -2, 42]), + ( + ndx.asarray([[0], [1], [2]], dtype=ndx.int64), + {0: -1, 1: -2}, + 42, + [[-1], [-2], [42]], + ), + ( + ndx.asarray([[0], [1], [2]], dtype=ndx.int32), + {0: -1, 1: -2}, + 42, + [[-1], [-2], [42]], + ), + ( + ndx.asarray([[0], [1], [2]], dtype=ndx.int8), + {0: -1, 1: -2}, + 42, + [[-1], [-2], [42]], + ), + ( + ndx.asarray([[0], [1], [2]], dtype=ndx.uint8), + {0: -1, 1: -2}, + 42, + [[-1], [-2], [42]], + ), + ( + ndx.asarray([[0], [1], [np.nan]], 
dtype=ndx.float32), + {0: -1, 1: -2, np.nan: 3.142}, + 42, + [[-1], [-2], [3.142]], + ), + ], +) +def test_static_map(x, mapping, default, expected): + actual = nda.static_map(x, mapping, default=default) + assert_array_equal(actual.to_numpy(), expected) + + +def test_static_map_unimplemented_for_nullable(): + a = ndx.asarray([1, 2, 3], dtype=ndx.int64) + m = ndx.asarray([True, False, True]) + a = nda.make_nullable(a, m) + + with pytest.raises(ndx.UnsupportedOperationError): + nda.static_map(a, {1: 2, 2: 3}) @pytest.mark.skipif( @@ -117,3 +179,29 @@ def test_isin(): a = ndx.asarray(["hello", "world"]) assert_array_equal([True, False], nda.isin(a, ["hello"]).to_numpy()) + + +@pytest.mark.parametrize( + "dtype", + [ + ndx.int64, + ndx.utf8, + ndx.bool, + ], +) +@pytest.mark.parametrize( + "mask", + [ + [True, False, True], + True, + False, + [True], + ], +) +def test_make_nullable(dtype, mask): + a = ndx.asarray([1, 2, 3], dtype=dtype) + m = ndx.asarray(mask) + + result = nda.make_nullable(a, m) + expected = np.ma.masked_array([1, 2, 3], mask, dtype.to_numpy_dtype()) + assert_array_equal(result.to_numpy(), expected) diff --git a/tests/test_core.py b/tests/test_core.py index 7ec974f..b0eaae3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -926,3 +926,59 @@ def test_lazy_array_shape(x, expected_shape): def test_dynamic_reshape_has_no_static_shape(x, shape): with pytest.raises(ValueError, match="Could not determine static shape"): ndx.reshape(x, shape).shape + + +@pytest.mark.skipif( + not np.__version__.startswith("2"), reason="NumPy >= 2 used for test assertions" +) +@pytest.mark.parametrize("include_initial", [True, False]) +@pytest.mark.parametrize( + "array_dtype", + [ndx.int32, ndx.int64, ndx.float32, ndx.float64, ndx.uint8, ndx.uint16, ndx.uint32], +) +@pytest.mark.parametrize( + "array, axis", + [ + ([1, 2, 3], None), + ([100, 100], None), + ([1, 2, 3], 0), + ([[1, 2], [3, 4]], 0), + ([[1, 2], [3, 4]], 1), + ([[1, 2, 50], [3, 4, 5]], 1), + ([[[[1]]], [[[3]]]], 0), + ([[[[1]]], [[[3]]]], 1), + ], +) +@pytest.mark.parametrize( + "cumsum_dtype", + [None, ndx.int32, ndx.float32, ndx.float64, ndx.uint8, ndx.int8], +) +def test_cumulative_sum(array, axis, include_initial, array_dtype, cumsum_dtype): + a = ndx.asarray(array, dtype=array_dtype) + assert_array_equal( + ndx.cumulative_sum( + a, include_initial=include_initial, axis=axis, dtype=cumsum_dtype + ).to_numpy(), + np.cumulative_sum( + np.asarray(array, a.dtype.to_numpy_dtype()), + include_initial=include_initial, + axis=axis, + dtype=cumsum_dtype.to_numpy_dtype() if cumsum_dtype is not None else None, + ), + ) + + +def test_no_unsafe_cumulative_sum_cast(): + with pytest.raises( + ndx.UnsupportedOperationError, + match="Unsupported operand type for cumulative_sum", + ): + a = ndx.asarray([1, 2, 3], ndx.uint64) + ndx.cumulative_sum(a) + + with pytest.raises( + ndx.UnsupportedOperationError, + match="Unsupported dtype parameter for cumulative_sum", + ): + a = ndx.asarray([1, 2, 3], ndx.int32) + ndx.cumulative_sum(a, dtype=ndx.uint64) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 048978d..cddc3b0 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -3,6 +3,7 @@ from __future__ import annotations +import functools import re import numpy as np @@ -10,12 +11,19 @@ from typing_extensions import Self import ndonnx as ndx +import ndonnx.additional as nda from ndonnx import ( Array, CastError, CoreType, ) -from ndonnx._experimental import CastMixin, Schema, StructType, UniformShapeOperations +from 
ndonnx._experimental import ( + CastMixin, + OperationsBlock, + Schema, + StructType, + UniformShapeOperations, +) from .utils import assert_array_equal @@ -87,6 +95,11 @@ def add(self, x, y) -> Array: return x + y.astype(Unsigned96()) return NotImplemented + def where(self, condition, x, y): + x = x.astype(Unsigned96()) + y = y.astype(Unsigned96()) + return super().where(condition, x, y) + class Unsigned96(StructType, CastMixin): def _fields(self) -> dict[str, StructType | CoreType]: @@ -132,6 +145,94 @@ def _cast_from(self, array: Array) -> Array: _ops = Unsigned96Impl() +class ListImpl(OperationsBlock): + def make_array( + self, + shape: tuple[int | str | None, ...], + dtype: CoreType | StructType, + eager_value: np.ndarray | None = None, + ) -> Array: + if eager_value is None: + return Array._from_fields( + dtype, + endpoints=ndx.array(shape=shape + (2,), dtype=ndx.int64), + items=ndx.array(shape=(None,), dtype=ndx.utf8), + ) + else: + fields = dtype._parse_input(eager_value) + return Array._from_fields( + dtype, **{name: ndx.asarray(field) for name, field in fields.items()} + ) + + def getitem( + self, + x: Array, + index, + ) -> Array: + if isinstance(index, int): + index = slice(index, index + 1), ... + + return Array._from_fields( + dtype=x.dtype, + endpoints=x.endpoints[index], + items=x.items.copy(), + ) + + def shape(self, x) -> Array: + return nda.shape(x.endpoints)[:-1] + + def static_shape(self, x) -> tuple[int | None, ...]: + return x.endpoints.shape[:-1] + + +class List(StructType): + # The fields here have different shapes + def _fields(self) -> dict[str, StructType | CoreType]: + return { + "endpoints": ndx.int64, + "items": ndx.utf8, + } + + def _parse_input(self, x: np.ndarray) -> dict: + assert x.dtype == object + assert all(isinstance(x, list) for x in x.flat) + + endpoints = np.empty(x.shape + (2,), dtype=np.int64) + items = np.empty( + functools.reduce(lambda acc, elem: acc + len(elem), x.flat, 0), dtype=object + ) + + cur_items_idx = 0 + for idx in np.ndindex(x.shape): + endpoints[idx, :] = [cur_items_idx, cur_items_idx + len(x[idx])] + for elem in x[idx]: + items[cur_items_idx] = elem + cur_items_idx += 1 + + return { + "endpoints": endpoints, + "items": items.astype(np.str_), + } + + def _assemble_output(self, fields: dict[str, np.ndarray]) -> np.ndarray: + endpoints = fields["endpoints"] + items = fields["items"] + + out = np.empty(endpoints.shape[:-1], dtype=object) + for idx in np.ndindex(endpoints.shape[:-1]): + start, end = endpoints[idx] + out[idx] = items[start:end].tolist() + return out + + def copy(self) -> Self: + return self + + def _schema(self) -> Schema: + return Schema(type_name="List", author="value from data!") + + _ops = ListImpl() + + def custom_equal(x: Array, y: Array) -> Array: if x.dtype != Unsigned96() or y.dtype != Unsigned96(): raise ValueError("Can only compare Unsigned96 arrays") @@ -275,3 +376,61 @@ def test_custom_dtype_capable_creation_functions(): assert_array_equal( ndx.ones_like(x, dtype=ndx.int32).to_numpy(), np.ones_like(arr, dtype=np.int32) ) + + +def test_custom_where(u96): + x = ndx.asarray([1, 2, 3], u96) + y = ndx.asarray([4, 5, 6], ndx.uint32) + cond = ndx.asarray([True, False, True]) + + result1 = ndx.where(cond, x, y) + assert_array_equal(result1, ndx.asarray([1, 5, 3], u96)) + + result2 = ndx.where(cond, y, x) + assert_array_equal(result2, ndx.asarray([4, 2, 6], u96)) + + result3 = ndx.where(cond, x, ndx.asarray(0, ndx.uint32)) + assert_array_equal(result3, ndx.asarray([1, 0, 3], u96)) + + +def 
test_create_dtype_mismatched_shape_fields_eager(): + array = np.empty(shape=(2,), dtype=object) + array[0] = ["a", "bcd", "e"] + array[1] = ["f", "gh"] + x = ndx.asarray(array, dtype=List()) + assert_array_equal(x.to_numpy(), array) + assert x[0].to_numpy().item() == ["a", "bcd", "e"] + assert_array_equal(nda.shape(x).to_numpy(), np.array([2], dtype=np.int64)) + assert x.shape == (2,) + + +def test_create_dtype_mismatched_shape_fields_lazy(): + x = ndx.array(shape=("N", "M", 2), dtype=List()) + assert x.shape == (None, None, 2) + out = x[1:2, 0, ...] + + ndx.build({"x": x}, {"out": out}) + + +def test_recursive_construction(): + class MyNInt64(StructType): + def _fields(self) -> dict[str, StructType | CoreType]: + return {"x": ndx.nint64} + + def _parse_input(self, x: np.ndarray) -> dict: + return {"x": ndx.nint64._parse_input(x)} + + def _assemble_output(self, fields: dict[str, np.ndarray]) -> np.ndarray: + return fields["x"] + + def copy(self) -> Self: + return self + + def _schema(self) -> Schema: + return Schema(type_name="my_nint64", author="me") + + _ops = UniformShapeOperations() + + my_nint64 = MyNInt64() + a = ndx.asarray(np.ma.masked_array([1, 2, 3], [1, 0, 1], np.int64), my_nint64) + assert_array_equal(a.to_numpy(), np.ma.masked_array([1, 2, 3], [1, 0, 1], np.int64))
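For reference, the corrected ``include_initial`` handling and the unsigned-integer workaround exercised by ``test_cumulative_sum`` and ``test_no_unsafe_cumulative_sum_cast`` above can be sketched as follows; this assumes the patched ndonnx and NumPy >= 2 (for ``np.cumulative_sum``), and the example values are illustrative:

    import numpy as np
    import ndonnx as ndx

    x = ndx.asarray([1, 2, 3], dtype=ndx.uint8)
    # Unsigned inputs narrower than 64 bits are accumulated via int64 to work
    # around missing onnxruntime kernels, then cast back to an unsigned dtype.
    out = ndx.cumulative_sum(x, include_initial=True)
    print(out.to_numpy())  # [0 1 3 6]
    # Reference semantics from NumPy >= 2:
    print(np.cumulative_sum(np.asarray([1, 2, 3], np.uint8), include_initial=True))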