Skip to content

Commit

Permalink
Drop pyproject.toml entries, start windows fix
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jun 10, 2024
1 parent c4fa1fd commit 311649c
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 21 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/array-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ concurrency:

jobs:
array-api-tests:
# Run if the commit message contains 'run array-api tests' or if the job is triggered on schedule
if: >-
contains(github.event.head_commit.message, 'run array-api tests') ||
github.event_name == 'schedule'
name: Array API test
timeout-minutes: 90
runs-on: ubuntu-latest-8core
Expand Down
7 changes: 6 additions & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,12 @@ def roll(x, shift, axis=None):
shift_single = opx.add(opx.const(-sh), len_single)
# Find the needed element index and then gather from it
range = opx.cast(
opx.range(opx.const(0), len_single, opx.const(1)), to=dtypes.int64
opx.range(
opx.const(0, dtype=len_single.dtype),
len_single,
opx.const(1, dtype=len_single.dtype),
),
to=dtypes.int64,
)
new_indices = opx.mod(opx.add(range, shift_single), len_single)
x = take(x, _from_corearray(new_indices), axis=ax)
Expand Down
24 changes: 17 additions & 7 deletions ndonnx/_opset_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ def concat(inputs: list[_CoreArray], axis: int) -> _CoreArray:

@eager_propagate
def unsqueeze(data: _CoreArray, axes: _CoreArray) -> _CoreArray:
if axes.dtype != dtypes.int64:
raise TypeError(f"axes must be int64, got {axes.dtype}")
return _CoreArray(op.unsqueeze(data.var, axes.var))


Expand Down Expand Up @@ -583,7 +585,7 @@ def getitem(
index_filtered = [x for x in index if isinstance(x, (type(None), slice))]
axis_new_axes = [ind for ind, x in enumerate(index_filtered) if x is None]
if len(axis_new_axes) != 0:
var = op.unsqueeze(var, axes=op.const(axis_new_axes))
var = op.unsqueeze(var, axes=op.const(axis_new_axes, dtype=np.int64))

return _CoreArray(var)

Expand Down Expand Up @@ -640,14 +642,19 @@ def ndindex(shape: _CoreArray, to_reverse=None, axes_permutation=None) -> _CoreA
axes_indices = [axes_permutation.index(i) for i in builtins.range(rank)]

shape_var = shape.var
dtype = shape_var.unwrap_tensor().dtype
ranges = [
(
op.range(op.const(0), op.gather(shape_var, op.const(i)), op.const(1))
op.range(
op.const(0, dtype=dtype),
op.gather(shape_var, op.const(i)),
op.const(1, dtype=dtype),
)
if i not in to_reverse
else op.range(
op.sub(op.gather(shape_var, op.const(i)), op.const(1)),
op.const(-1),
op.const(-1),
op.sub(op.gather(shape_var, op.const(i)), op.const(1, dtype=dtype)),
op.const(-1, dtype=dtype),
op.const(-1, dtype=dtype),
)
)
for i in builtins.range(rank)
Expand All @@ -657,7 +664,8 @@ def ndindex(shape: _CoreArray, to_reverse=None, axes_permutation=None) -> _CoreA
op.unsqueeze(
r,
op.const(
[j for j in builtins.range(rank) if axes_indices[i] != j], dtype=np.int_
[j for j in builtins.range(rank) if axes_indices[i] != j],
dtype=np.int64,
),
)
for i, r in enumerate(ranges)
Expand All @@ -669,7 +677,7 @@ def ndindex(shape: _CoreArray, to_reverse=None, axes_permutation=None) -> _CoreA
expanded_ranges = [op.expand(r, shape_var) for r in fit_ranges]

ret = op.concat(
[op.unsqueeze(r, op.const([-1])) for r in expanded_ranges],
[op.unsqueeze(r, op.const([-1], dtype=np.int64)) for r in expanded_ranges],
axis=-1,
)

Expand All @@ -692,6 +700,8 @@ def static_map(
input: _CoreArray, mapping: Mapping[KeyType, ValueType], default: ValueType | None
) -> _CoreArray:
keys = np.array(tuple(mapping.keys()))
if keys.dtype == np.int32:
keys = keys.astype(np.int64)
values = np.array(tuple(mapping.values()))
value_dtype = values.dtype
if default is None:
Expand Down
9 changes: 0 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,3 @@ exclude = ["docs/"]

[tool.typos.default]
extend-ignore-identifiers-re = ["scatter_nd", "arange"]

[tool.pixi.project]
channels = ["conda-forge"]
platforms = ["osx-arm64"]

[tool.pixi.pypi-dependencies]
ndonnx = { path = ".", editable = true }

[tool.pixi.tasks]
4 changes: 2 additions & 2 deletions tests/ndonnx/test_constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def dynamic_masking_model(mode: Literal["lazy", "constant"]):

def constant_indexing_model(mode: Literal["lazy", "constant"]):
if mode == "constant":
a = ndx.asarray([0, 1, 2, 3])
a = ndx.asarray([0, 1, 2, 3], dtype=ndx.int64)
else:
a = ndx.array(
shape=("N",),
dtype=ndx.int64,
)
b = ndx.asarray([5, 7, 8, 8, 9, 9, 234])
b = ndx.asarray([5, 7, 8, 8, 9, 9, 234], dtype=ndx.int64)
idx = ndx.asarray([1, 3, 5, 0])
result = a * b[idx]
return ndx.build({"a": a} if mode == "lazy" else {}, {"y": result})
Expand Down
6 changes: 4 additions & 2 deletions tests/ndonnx/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ def test_null_promotion():


def test_asarray():
a = ndx.asarray([1, 2, 3])
a = ndx.asarray([1, 2, 3], dtype=ndx.int64)
assert a.dtype == ndx.int64
np.testing.assert_array_equal(np.array([1, 2, 3]), a.to_numpy(), strict=True)
np.testing.assert_array_equal(
np.array([1, 2, 3], np.int64), a.to_numpy(), strict=True
)


def test_asarray_masked():
Expand Down

0 comments on commit 311649c

Please sign in to comment.