diff --git a/.github/workflows/array-api.yml b/.github/workflows/array-api.yml index 7d0a0da..1cd6665 100644 --- a/.github/workflows/array-api.yml +++ b/.github/workflows/array-api.yml @@ -13,6 +13,10 @@ concurrency: jobs: array-api-tests: + # Run if the commit message contains 'run array-api tests' or if the job is triggered on schedule + if: >- + contains(github.event.head_commit.message, 'run array-api tests') || + github.event_name == 'schedule' name: Array API test timeout-minutes: 90 runs-on: ubuntu-latest-8core diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index f69c9a5..1378cbd 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -694,7 +694,12 @@ def roll(x, shift, axis=None): shift_single = opx.add(opx.const(-sh), len_single) # Find the needed element index and then gather from it range = opx.cast( - opx.range(opx.const(0), len_single, opx.const(1)), to=dtypes.int64 + opx.range( + opx.const(0, dtype=len_single.dtype), + len_single, + opx.const(1, dtype=len_single.dtype), + ), + to=dtypes.int64, ) new_indices = opx.mod(opx.add(range, shift_single), len_single) x = take(x, _from_corearray(new_indices), axis=ax) diff --git a/ndonnx/_opset_extensions.py b/ndonnx/_opset_extensions.py index 815f28d..a2b072b 100644 --- a/ndonnx/_opset_extensions.py +++ b/ndonnx/_opset_extensions.py @@ -364,6 +364,8 @@ def concat(inputs: list[_CoreArray], axis: int) -> _CoreArray: @eager_propagate def unsqueeze(data: _CoreArray, axes: _CoreArray) -> _CoreArray: + if axes.dtype != dtypes.int64: + raise TypeError(f"axes must be int64, got {axes.dtype}") return _CoreArray(op.unsqueeze(data.var, axes.var)) @@ -583,7 +585,7 @@ def getitem( index_filtered = [x for x in index if isinstance(x, (type(None), slice))] axis_new_axes = [ind for ind, x in enumerate(index_filtered) if x is None] if len(axis_new_axes) != 0: - var = op.unsqueeze(var, axes=op.const(axis_new_axes)) + var = op.unsqueeze(var, axes=op.const(axis_new_axes, dtype=np.int64)) return _CoreArray(var) @@ -640,14 +642,19 @@ def ndindex(shape: _CoreArray, to_reverse=None, axes_permutation=None) -> _CoreA axes_indices = [axes_permutation.index(i) for i in builtins.range(rank)] shape_var = shape.var + dtype = shape_var.unwrap_tensor().dtype ranges = [ ( - op.range(op.const(0), op.gather(shape_var, op.const(i)), op.const(1)) + op.range( + op.const(0, dtype=dtype), + op.gather(shape_var, op.const(i)), + op.const(1, dtype=dtype), + ) if i not in to_reverse else op.range( - op.sub(op.gather(shape_var, op.const(i)), op.const(1)), - op.const(-1), - op.const(-1), + op.sub(op.gather(shape_var, op.const(i)), op.const(1, dtype=dtype)), + op.const(-1, dtype=dtype), + op.const(-1, dtype=dtype), ) ) for i in builtins.range(rank) @@ -657,7 +664,8 @@ def ndindex(shape: _CoreArray, to_reverse=None, axes_permutation=None) -> _CoreA op.unsqueeze( r, op.const( - [j for j in builtins.range(rank) if axes_indices[i] != j], dtype=np.int_ + [j for j in builtins.range(rank) if axes_indices[i] != j], + dtype=np.int64, ), ) for i, r in enumerate(ranges) @@ -669,7 +677,7 @@ def ndindex(shape: _CoreArray, to_reverse=None, axes_permutation=None) -> _CoreA expanded_ranges = [op.expand(r, shape_var) for r in fit_ranges] ret = op.concat( - [op.unsqueeze(r, op.const([-1])) for r in expanded_ranges], + [op.unsqueeze(r, op.const([-1], dtype=np.int64)) for r in expanded_ranges], axis=-1, ) @@ -692,6 +700,8 @@ def static_map( input: _CoreArray, mapping: Mapping[KeyType, ValueType], default: ValueType | None ) -> _CoreArray: keys = np.array(tuple(mapping.keys())) + if keys.dtype == np.int32: + keys = keys.astype(np.int64) values = np.array(tuple(mapping.values())) value_dtype = values.dtype if default is None: diff --git a/pyproject.toml b/pyproject.toml index e9c9680..0962b75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,12 +82,3 @@ exclude = ["docs/"] [tool.typos.default] extend-ignore-identifiers-re = ["scatter_nd", "arange"] - -[tool.pixi.project] -channels = ["conda-forge"] -platforms = ["osx-arm64"] - -[tool.pixi.pypi-dependencies] -ndonnx = { path = ".", editable = true } - -[tool.pixi.tasks] diff --git a/tests/ndonnx/test_constant_propagation.py b/tests/ndonnx/test_constant_propagation.py index 5be3ed7..f98dd31 100644 --- a/tests/ndonnx/test_constant_propagation.py +++ b/tests/ndonnx/test_constant_propagation.py @@ -67,13 +67,13 @@ def dynamic_masking_model(mode: Literal["lazy", "constant"]): def constant_indexing_model(mode: Literal["lazy", "constant"]): if mode == "constant": - a = ndx.asarray([0, 1, 2, 3]) + a = ndx.asarray([0, 1, 2, 3], dtype=ndx.int64) else: a = ndx.array( shape=("N",), dtype=ndx.int64, ) - b = ndx.asarray([5, 7, 8, 8, 9, 9, 234]) + b = ndx.asarray([5, 7, 8, 8, 9, 9, 234], dtype=ndx.int64) idx = ndx.asarray([1, 3, 5, 0]) result = a * b[idx] return ndx.build({"a": a} if mode == "lazy" else {}, {"y": result}) diff --git a/tests/ndonnx/test_core.py b/tests/ndonnx/test_core.py index 6845ec7..55edb03 100644 --- a/tests/ndonnx/test_core.py +++ b/tests/ndonnx/test_core.py @@ -81,9 +81,11 @@ def test_null_promotion(): def test_asarray(): - a = ndx.asarray([1, 2, 3]) + a = ndx.asarray([1, 2, 3], dtype=ndx.int64) assert a.dtype == ndx.int64 - np.testing.assert_array_equal(np.array([1, 2, 3]), a.to_numpy(), strict=True) + np.testing.assert_array_equal( + np.array([1, 2, 3], np.int64), a.to_numpy(), strict=True + ) def test_asarray_masked():