Skip to content

Commit

Permalink
Implement unstack and fix repeat test
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Nov 19, 2024
1 parent 0b16051 commit 38257ac
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
11 changes: 11 additions & 0 deletions ndonnx/_logic_in_data/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,17 @@ def __dlpack__(
def __dlpack_device__(self) -> tuple[Enum, int]:
raise ValueError("ONNX provides no control over the used device")

def __iter__(self):
try:
n, *_ = self.shape
except IndexError:
raise ValueError("iteration over 0-d array")
if isinstance(n, int):
return (self[i, ...] for i in range(n))
raise ValueError(
"iteration requires dimension of static length, but dimension 0 is dynamic."
)

def __getitem__(self, key: GetitemIndex, /) -> Array:
idx = normalize_getitem_key(key)
data = self._data[idx]
Expand Down
7 changes: 6 additions & 1 deletion ndonnx/_logic_in_data/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,12 @@ def unique_values(x: Array, /) -> Array:


def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
raise NotImplementedError
# Only possible for statically known dimensions
if not isinstance(x.shape[axis], int):
raise ValueError(
f"'unstack' can only be applied to statically known dimensions, but axis `{axis}` has dynamic length."
)
return tuple(el for el in moveaxis(x, axis, 0))


def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_logic_in_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,13 @@ def test_indexing_value_prop_tuple_index():
np.testing.assert_equal(el.unwrap_numpy(), np_arr[idx])


def test_iteration():
np_arr = np.asarray([1, 2])
arr = ndx.asarray(np_arr)
for npa, nda in zip(np_arr, arr):
np.testing.assert_array_equal(npa, nda.unwrap_numpy())


@pytest.mark.parametrize("idx", [(0, 1), (-1, ...), (..., 1), (-1, ..., 1)])
@pytest.mark.parametrize(
"np_array",
Expand Down

0 comments on commit 38257ac

Please sign in to comment.