Skip to content

Commit

Permalink
Implement where for nullable data types
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Sep 2, 2024
1 parent e4232c0 commit bd23b01
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 32 deletions.
50 changes: 41 additions & 9 deletions ndonnx/_logic_in_data/_typed_array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Callable, Sequence
from types import NotImplementedType
from typing import TYPE_CHECKING, Any, TypeGuard, TypeVar

import numpy as np
Expand Down Expand Up @@ -95,7 +96,9 @@ def as_core_dtype(self, dtype: CoreDTypes) -> _ArrayCoreType:
def _astype(self, dtype: DType) -> _TypedArray:
return NotImplemented

def where(self, cond: BoolData, y: _TypedArray) -> _TypedArray:
def _where(
self, cond: BoolData, y: _TypedArray
) -> _TypedArray | NotImplementedType:
if isinstance(y, _ArrayCoreType):
x, y = promote(self, y)
var = op.where(cond.var, x.var, y.var)
Expand Down Expand Up @@ -136,18 +139,29 @@ class BoolData(_ArrayCoreType[dtypes.Bool]):
dtype = dtypes.bool_

def __or__(self, rhs: _TypedArray) -> _TypedArray:
from .utils import promote
if self.dtype != rhs.dtype:
a, b = promote(self, rhs)
return a | b

if isinstance(rhs, _ArrayCoreType):
if self.dtype != rhs.dtype:
a, b = promote(self, rhs)
return a | b

# Data is core & bool
if isinstance(rhs, BoolData):
var = op.or_(self.var, rhs.var)
return ascoredata(var)
return NotImplemented

def __and__(self, rhs: _TypedArray) -> _TypedArray:
    """Element-wise logical AND with ``rhs``.

    Operands of differing data types are first passed through ``promote``
    and the operation is re-dispatched on the promoted pair. Returns
    ``NotImplemented`` when ``rhs`` shares this dtype but is not a
    ``BoolData``.
    """
    if self.dtype != rhs.dtype:
        lhs_promoted, rhs_promoted = promote(self, rhs)
        return lhs_promoted & rhs_promoted

    if not isinstance(rhs, BoolData):
        return NotImplemented

    return ascoredata(op.and_(self.var, rhs.var))

def __invert__(self) -> BoolData:
    """Return the element-wise logical negation, wrapped in the same class as ``self``."""
    return type(self)(op.not_(self.var))


class Int8Data(_ArrayCoreInteger[dtypes.Int8]):
dtype = dtypes.int8
Expand Down Expand Up @@ -203,3 +217,21 @@ def is_sequence_of_core_data(
seq: Sequence[_TypedArray],
) -> TypeGuard[Sequence[_ArrayCoreType]]:
return all(isinstance(d, _ArrayCoreType) for d in seq)


def _promote_and_apply_op(
    lhs: _ArrayCoreType,
    rhs: _TypedArray,
    arr_op: Callable[[_ArrayCoreType, _ArrayCoreType], _ArrayCoreType],
    spox_op: Callable[[Var, Var], Var],
) -> _ArrayCoreType | NotImplementedType:
    """Promote and apply an operation by passing it through to the data member.

    Parameters
    ----------
    lhs
        Left-hand core array.
    rhs
        Right-hand operand; only ``_ArrayCoreType`` instances are handled.
    arr_op
        Array-level operation that is re-applied to the promoted operands
        when the data types of ``lhs`` and ``rhs`` differ.
    spox_op
        Operation applied directly to the underlying ``var`` members when
        both operands already share a data type.

    Returns
    -------
    The resulting core array, or ``NotImplemented`` if ``rhs`` is not an
    ``_ArrayCoreType`` so that the caller may try a reflected operation.
    """
    if not isinstance(rhs, _ArrayCoreType):
        return NotImplemented

    if lhs.dtype != rhs.dtype:
        a, b = promote(lhs, rhs)
        return arr_op(a, b)

    # Both operands are core arrays of the same data type.
    var = spox_op(lhs.var, rhs.var)
    return ascoredata(var)
17 changes: 17 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,23 @@
from .typed_array import _TypedArray


def typed_where(cond: _TypedArray, x: _TypedArray, y: _TypedArray) -> _TypedArray:
    """Select from ``x`` or ``y`` according to ``cond``.

    ``x._where(cond, y)`` is tried first. If it returns ``NotImplemented``,
    the reflected ``y._rwhere(cond, x)`` is tried.

    Raises
    ------
    TypeError
        If ``cond`` is not a ``BoolData`` instance, or if neither operand
        implements the operation for this combination of data types.
    """
    from .core import BoolData

    # TODO: Masked condition
    if not isinstance(cond, BoolData):
        raise TypeError("'cond' must be a boolean data type.")

    # ``NotImplemented`` is a singleton and must be compared by identity:
    # ``==`` would dispatch to the ``__eq__`` of the returned object, which
    # an array type is free to overload.
    ret = x._where(cond, y)
    if ret is NotImplemented:
        ret = y._rwhere(cond, x)
    if ret is NotImplemented:
        raise TypeError(
            f"Unsupported operand data types for 'where': `{x.dtype}` and `{y.dtype}`"
        )
    return ret


def astypedarray(
val: int | float | np.ndarray | _TypedArray | Var,
dtype: None | DType = None,
Expand Down
42 changes: 37 additions & 5 deletions ndonnx/_logic_in_data/_typed_array/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import operator
from collections.abc import Callable
from dataclasses import dataclass
from types import NotImplementedType
from typing import TYPE_CHECKING, Any, TypeVar

import spox.opset.ai.onnx.v21 as op
Expand Down Expand Up @@ -111,9 +112,40 @@ def _astype(self, dtype: DType) -> _TypedArray:
dtype._tyarr_class(data=new_data, mask=self.mask)
return NotImplemented

def where(self, cond: BoolData, y: _TypedArray) -> _TypedArray:
# TODO
raise NotImplementedError
def _where(
    self, cond: BoolData, y: _TypedArray
) -> _TypedArray | NotImplementedType:
    """Implement ``where`` with ``self`` as the value chosen where ``cond`` holds.

    A core-array ``y`` is wrapped via ``asncoredata(y, None)`` (no mask) and
    the call is retried. For a masked core ``y`` the data is selected on the
    unmasked operands, and the resulting mask is taken from ``self`` where
    ``cond`` is true and from ``y`` where it is false. Any other ``y`` yields
    ``NotImplemented``.
    """
    if isinstance(y, _ArrayCoreType):
        return self._where(cond, asncoredata(y, None))
    if not isinstance(y, _ArrayMaCoreType):
        return NotImplemented

    lhs_core = unmask_core(self)
    rhs_core = unmask_core(y)
    new_data = lhs_core._where(cond, rhs_core)

    if self.mask is None and y.mask is None:
        new_mask = None
    elif y.mask is None:
        new_mask = cond & self.mask
    elif self.mask is None:
        new_mask = ~cond & y.mask
    else:
        new_mask = (cond & self.mask) | (~cond & y.mask)

    # Should never trigger. More precise overloads on the BoolData dunder
    # methods would let the type checker prove this statically.
    if not (new_mask is None or isinstance(new_mask, BoolData)):
        raise TypeError(f"expected boolean mask, found `{new_mask.dtype}`")

    return asncoredata(new_data, new_mask)

def _rwhere(
    self, cond: BoolData, x: _TypedArray
) -> _TypedArray | NotImplementedType:
    """Reflected ``where``: ``x`` is chosen where ``cond`` holds, ``self`` otherwise.

    Only a core-array ``x`` is handled; it is wrapped via
    ``asncoredata(x, None)`` and the work is delegated to its ``_where``.
    """
    if not isinstance(x, _ArrayCoreType):
        return NotImplemented
    wrapped = asncoredata(x, None)
    return wrapped._where(cond, self)


class NBoolData(_ArrayMaCoreType[dtypes.NBool]):
Expand Down Expand Up @@ -229,10 +261,10 @@ def _apply_op(
) -> _ArrayMaCoreType:
"""Apply an operation by passing it through to the data member."""
if isinstance(rhs, _ArrayCoreType):
data = lhs.data + rhs
data = op(lhs.data, rhs)
mask = lhs.mask
elif isinstance(rhs, _ArrayMaCoreType):
data = lhs.data + rhs.data
data = op(lhs.data, rhs.data)
mask = _merge_masks(lhs.mask, rhs.mask)
else:
return NotImplemented
Expand Down
17 changes: 15 additions & 2 deletions ndonnx/_logic_in_data/_typed_array/typed_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,24 @@ def _astype(self, dtype: DType) -> _TypedArray | NotImplementedType:
"""
return NotImplemented

@abstractmethod
def where(self, cond: BoolData, y: _TypedArray) -> _TypedArray: ...
def _where(
    self, cond: BoolData, y: _TypedArray
) -> _TypedArray | NotImplementedType:
    """Default ``where`` hook with ``self`` as the first value operand.

    Always returns ``NotImplemented``; subclasses override this to support
    specific operand combinations.
    """
    return NotImplemented

def _rwhere(
    self, cond: BoolData, y: _TypedArray
) -> _TypedArray | NotImplementedType:
    """Default reflected ``where`` hook.

    Always returns ``NotImplemented``; subclasses override this to support
    specific operand combinations.
    """
    return NotImplemented

def __add__(self, other: _TypedArray) -> _TypedArray:
    """Default addition: not supported unless overridden by a subclass."""
    return NotImplemented

def __and__(self, rhs: _TypedArray) -> _TypedArray:
    """Default ``&``: not supported unless overridden by a subclass."""
    return NotImplemented

def __invert__(self) -> _TypedArray:
    """Default ``~``: not supported unless overridden by a subclass."""
    return NotImplemented

def __or__(self, rhs: _TypedArray) -> _TypedArray:
    """Default ``|``: not supported unless overridden by a subclass."""
    return NotImplemented
16 changes: 4 additions & 12 deletions ndonnx/_logic_in_data/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def __init__(self, shape=None, dtype=None, value=None, var=None):

@classmethod
def _from_data(cls, data: _TypedArray) -> Array:
    """Wrap an existing typed array in a new instance without calling ``__init__``.

    Raises ``TypeError`` if ``data`` is not a ``_TypedArray``.
    """
    if isinstance(data, _TypedArray):
        wrapped = cls.__new__(cls)
        wrapped._data = data
        return wrapped
    raise TypeError(f"expected '_TypedArray', found `{type(data)}`")
Expand Down Expand Up @@ -117,19 +119,9 @@ def asarray(obj: int | float | bool | str | Array) -> Array:


def where(cond: Array, a: Array, b: Array) -> Array:
from ._typed_array import BoolData
from .dtypes import bool_, nbool
from ._typed_array.funcs import typed_where

if cond.dtype not in [bool_, nbool]:
raise ValueError

# TODO: NBoolData
if not isinstance(cond._data, BoolData):
raise ValueError(
f"condition must be of a boolean data type; found `{cond.dtype}`"
)

data = a._data.where(cond._data, b._data)
data = typed_where(cond._data, a._data, b._data)
return Array._from_data(data)


Expand Down
16 changes: 12 additions & 4 deletions tests/test_logic_in_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,22 @@ def test__getitem__():
assert arr[0].shape == (None,)


def test_where():
@pytest.mark.parametrize(
"x_ty, y_ty, res_ty",
[
(dtypes.int16, dtypes.int32, dtypes.int32),
(dtypes.nint16, dtypes.int32, dtypes.nint32),
(dtypes.int32, dtypes.nint16, dtypes.nint32),
],
)
def test_where(x_ty, y_ty, res_ty):
shape = ("N", "M")
cond = Array(shape, dtypes.bool_)
x = Array(shape, dtypes.int16)
y = Array(shape, dtypes.int32)
x = Array(shape, x_ty)
y = Array(shape, y_ty)

res = where(cond, x, y)

assert res.dtype == dtypes.int32
assert res.dtype == res_ty
assert res._data.shape == shape
assert res.shape == (None, None)

0 comments on commit bd23b01

Please sign in to comment.