Skip to content

Commit

Permalink
implement dpnp.apply_over_axes (#2174)
Browse files Browse the repository at this point in the history
* implement dpnp.apply_over_axes

* fix issue with a test
  • Loading branch information
vtavana authored Nov 16, 2024
1 parent 264c6d8 commit 2987585
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 3 deletions.
87 changes: 85 additions & 2 deletions dpnp/dpnp_iface_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@


import numpy
from dpctl.tensor._numpy_helper import normalize_axis_index
from dpctl.tensor._numpy_helper import (
normalize_axis_index,
normalize_axis_tuple,
)

import dpnp

__all__ = ["apply_along_axis"]
__all__ = ["apply_along_axis", "apply_over_axes"]


def apply_along_axis(func1d, axis, arr, *args, **kwargs):
Expand Down Expand Up @@ -185,3 +188,83 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
buff = dpnp.moveaxis(buff, -1, axis)

return buff


def apply_over_axes(func, a, axes):
"""
Apply a function repeatedly over multiple axes.
`func` is called as ``res = func(a, axis)``, where `axis` is the first
element of `axes`. The result `res` of the function call must have
either the same dimensions as `a` or one less dimension. If `res`
has one less dimension than `a`, a dimension is inserted before
`axis`. The call to `func` is then repeated for each axis in `axes`,
with `res` as the first argument.
For full documentation refer to :obj:`numpy.apply_over_axes`.
Parameters
----------
func : function
This function must take two arguments, ``func(a, axis)``.
a : {dpnp.ndarray, usm_ndarray}
Input array.
axes : {int, sequence of ints}
Axes over which `func` is applied.
Returns
-------
out : dpnp.ndarray
The output array. The number of dimensions is the same as `a`,
but the shape can be different. This depends on whether `func`
changes the shape of its output with respect to its input.
See Also
--------
:obj:`dpnp.apply_along_axis` : Apply a function to 1-D slices of an array
along the given axis.
Examples
--------
>>> import dpnp as np
>>> a = np.arange(24).reshape(2, 3, 4)
>>> a
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
Sum over axes 0 and 2. The result has same number of dimensions
as the original array:
>>> np.apply_over_axes(np.sum, a, [0, 2])
array([[[ 60],
[ 92],
[124]]])
Tuple axis arguments to ufuncs are equivalent:
>>> np.sum(a, axis=(0, 2), keepdims=True)
array([[[ 60],
[ 92],
[124]]])
"""

dpnp.check_supported_arrays_type(a)
if isinstance(axes, int):
axes = (axes,)
axes = normalize_axis_tuple(axes, a.ndim)

for axis in axes:
res = func(a, axis)
if res.ndim != a.ndim:
res = dpnp.expand_dims(res, axis)
if res.ndim != a.ndim:
raise ValueError(
"function is not returning an array of the correct shape"
)
a = res
return res
21 changes: 20 additions & 1 deletion dpnp/tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy
import pytest
from numpy.testing import assert_array_equal
from numpy.testing import assert_array_equal, assert_raises

import dpnp

Expand Down Expand Up @@ -46,3 +46,22 @@ def test_args(self, dtype):
# positional args: axis, dtype, out, keepdims
result = dpnp.apply_along_axis(dpnp.mean, 0, ia, 0, dtype, None, True)
assert_array_equal(result, expected)


class TestApplyOverAxes:
@pytest.mark.parametrize("func", ["sum", "cumsum"])
@pytest.mark.parametrize("axes", [1, [0, 2], (-1, -2)])
def test_basic(self, func, axes):
a = numpy.arange(24).reshape(2, 3, 4)
ia = dpnp.array(a)

expected = numpy.apply_over_axes(getattr(numpy, func), a, axes)
result = dpnp.apply_over_axes(getattr(dpnp, func), ia, axes)
assert_array_equal(result, expected)

def test_custom_func(self):
def custom_func(x, axis):
return dpnp.expand_dims(dpnp.expand_dims(x, axis), axis)

ia = dpnp.arange(24).reshape(2, 3, 4)
assert_raises(ValueError, dpnp.apply_over_axes, custom_func, ia, 1)
12 changes: 12 additions & 0 deletions dpnp/tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2215,6 +2215,18 @@ def test_apply_along_axis(device):
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_apply_over_axes(device):
x = dpnp.arange(18, device=device).reshape(2, 3, 3)
result = dpnp.apply_over_axes(dpnp.sum, x, [0, 1])

assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)


@pytest.mark.parametrize(
"device_x",
valid_devices,
Expand Down
8 changes: 8 additions & 0 deletions dpnp/tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,14 @@ def test_apply_along_axis(usm_type):
assert x.usm_type == y.usm_type


@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_apply_over_axes(usm_type):
x = dp.arange(18, usm_type=usm_type).reshape(2, 3, 3)
y = dp.apply_over_axes(dp.sum, x, [0, 1])

assert x.usm_type == y.usm_type


@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_broadcast_to(usm_type):
x = dp.ones(7, usm_type=usm_type)
Expand Down

0 comments on commit 2987585

Please sign in to comment.