Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User driven getitem and construction #60

Merged
merged 8 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ Changelog
0.9.0 (unreleased)
------------------

**New feature**
**New features**

- User defined data types can now define how arrays with that dtype are constructed by implementing the ``make_array`` function.
- User defined data types can now define how they are indexed (via ``__getitem__``) by implementing the ``getitem`` function.
- :class:`ndonnx.NullableCore` is now public, encapsulating nullable variants of `CoreType`s exported by ndonnx.

**Bug fixes**
Expand Down
19 changes: 7 additions & 12 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
from ndonnx.additional import shape
from ndonnx.additional._additional import _getitem as getitem
from ndonnx.additional._additional import _static_shape as static_shape

from ._corearray import _CoreArray
Expand Down Expand Up @@ -47,7 +48,11 @@ def array(
out : Array
The new array. This represents an ONNX model input.
"""
return Array._construct(shape=shape, dtype=dtype)
if (out := dtype._ops.make_array(shape, dtype)) is not NotImplemented:
return out
raise ndx.UnsupportedOperationError(
f"No implementation of `make_array` for {dtype}"
)


def from_spox_var(
Expand Down Expand Up @@ -154,17 +159,7 @@ def astype(self, to: CoreType | StructType) -> Array:
return ndx.astype(self, to)

def __getitem__(self, index: IndexType) -> Array:
if isinstance(index, Array) and not (
isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool
):
raise TypeError(
f"Index must be an integral or boolean 'Array', not `{index.dtype}`"
)

if isinstance(index, Array):
index = index._core()

return self._transmute(lambda corearray: corearray[index])
return getitem(self, index)

def __setitem__(
self, index: IndexType | Self, updates: int | bool | float | Array
Expand Down
24 changes: 10 additions & 14 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
import ndonnx.additional as nda

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import binary_op, unary_op, validate_core
Expand All @@ -22,7 +23,7 @@
from ndonnx import Array


class BooleanOperationsImpl(UniformShapeOperations):
class _BooleanOperationsImpl(OperationsBlock):
@validate_core
def equal(self, x, y) -> Array:
return binary_op(x, y, opx.equal)
Expand Down Expand Up @@ -163,17 +164,12 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
def nonzero(self, x) -> tuple[Array, ...]:
return ndx.nonzero(x.astype(ndx.int8))

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")
return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)

class BooleanOperationsImpl(
CoreOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...

class NullableBooleanOperationsImpl(BooleanOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented

class NullableBooleanOperationsImpl(
NullableOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...
50 changes: 50 additions & 0 deletions ndonnx/_core/_coreimpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from spox import Tensor, argument

import ndonnx as ndx
import ndonnx._data_types as dtypes
import ndonnx.additional as nda
from ndonnx._corearray import _CoreArray

from ._interface import OperationsBlock
from ._utils import validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
from ndonnx._data_types import Dtype


class CoreOperationsImpl(OperationsBlock):
def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> Array:
if not isinstance(dtype, dtypes.CoreType):
return NotImplemented
return ndx.Array._from_fields(
dtype,
data=_CoreArray(
dtype._parse_input(eager_value)["data"]
if eager_value is not None
else argument(Tensor(dtype.to_numpy_dtype(), shape))
),
)

@validate_core
def make_nullable(self, x: Array, null: Array) -> Array:
if null.dtype != ndx.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)
39 changes: 29 additions & 10 deletions ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@

from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np

import ndonnx as ndx
import ndonnx._data_types as dtypes

if TYPE_CHECKING:
from ndonnx._array import IndexType
from ndonnx._data_types import Dtype


class OperationsBlock:
"""Interface for data types to implement top-level functions exported by ndonnx."""
Expand Down Expand Up @@ -251,7 +257,7 @@ def cumulative_sum(
x,
*,
axis: int | None = None,
dtype: ndx.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
include_initial: bool = False,
):
return NotImplemented
Expand All @@ -270,7 +276,7 @@ def prod(
x,
*,
axis=None,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
keepdims: bool = False,
) -> ndx.Array:
return NotImplemented
Expand All @@ -293,7 +299,7 @@ def sum(
x,
*,
axis=None,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
keepdims: bool = False,
) -> ndx.Array:
return NotImplemented
Expand All @@ -305,7 +311,7 @@ def var(
axis=None,
keepdims: bool = False,
correction=0.0,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
) -> ndx.Array:
return NotImplemented

Expand Down Expand Up @@ -352,7 +358,7 @@ def full_like(self, x, fill_value, dtype=None, device=None) -> ndx.Array:
def ones(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
device=None,
):
return NotImplemented
Expand All @@ -365,14 +371,12 @@ def ones_like(
def zeros(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
device=None,
):
return NotImplemented

def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
def zeros_like(self, x, dtype: Dtype | None = None, device=None):
return NotImplemented

def empty(self, shape, dtype=None, device=None) -> ndx.Array:
Expand Down Expand Up @@ -413,3 +417,18 @@ def can_cast(self, from_, to) -> bool:

def static_shape(self, x) -> tuple[int | None, ...]:
return NotImplemented

def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> ndx.Array:
return NotImplemented

def getitem(
self,
x: ndx.Array,
index: IndexType,
) -> ndx.Array:
return NotImplemented
15 changes: 14 additions & 1 deletion ndonnx/_core/_nullableimpl.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING, Union

import ndonnx as ndx

from ._interface import OperationsBlock
from ._utils import validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
from ndonnx._data_types import CoreType, StructType

Dtype = Union[CoreType, StructType]


class NullableOperationsImpl(OperationsBlock):
@validate_core
def fill_null(self, x, value):
def fill_null(self, x: Array, value) -> Array:
value = ndx.asarray(value)
if value.dtype != x.values.dtype:
value = value.astype(x.values.dtype)
return ndx.where(x.null, value, x.values)

@validate_core
def make_nullable(self, x: Array, null: Array) -> Array:
return NotImplemented
26 changes: 11 additions & 15 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import ndonnx.additional as nda
from ndonnx._utility import promote

from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import (
Expand All @@ -36,7 +38,7 @@
from ndonnx._corearray import _CoreArray


class NumericOperationsImpl(UniformShapeOperations):
class _NumericOperationsImpl(OperationsBlock):
# elementwise.py

@validate_core
Expand Down Expand Up @@ -837,17 +839,6 @@ def var(
- correction
)

@validate_core
def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("'null' must be a boolean array")

return ndx.Array._from_fields(
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)

@validate_core
def can_cast(self, from_, to) -> bool:
if isinstance(from_, dtypes.CoreType) and isinstance(to, ndx.CoreType):
Expand Down Expand Up @@ -980,9 +971,14 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.full_like(x, 0, dtype=dtype)


class NullableNumericOperationsImpl(NumericOperationsImpl, NullableOperationsImpl):
def make_nullable(self, x, null):
return NotImplemented
class NumericOperationsImpl(
CoreOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...


class NullableNumericOperationsImpl(
NullableOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...


def _via_i64_f64(
Expand Down
Loading