Skip to content

Commit

Permalink
Add dpnp.broadcast_shapes implementation (#2153)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekomarova authored Nov 7, 2024
1 parent eb9a6c0 commit 9086f45
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
36 changes: 36 additions & 0 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"atleast_2d",
"atleast_3d",
"broadcast_arrays",
"broadcast_shapes",
"broadcast_to",
"can_cast",
"column_stack",
Expand Down Expand Up @@ -967,6 +968,41 @@ def broadcast_arrays(*args, subok=False):
return [dpnp_array._create_from_usm_ndarray(a) for a in usm_arrays]


def broadcast_shapes(*args):
"""
Broadcast the input shapes into a single shape.
For full documentation refer to :obj:`numpy.broadcast_shapes`.
Parameters
----------
*args : tuples of ints, or ints
The shapes to be broadcast against each other.
Returns
-------
tuple
Broadcasted shape.
See Also
--------
:obj:`dpnp.broadcast_arrays` : Broadcast any number of arrays against
each other.
:obj:`dpnp.broadcast_to` : Broadcast an array to a new shape.
Examples
--------
>>> import dpnp as np
>>> np.broadcast_shapes((1, 2), (3, 1), (3, 2))
(3, 2)
>>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
(5, 6, 7)
"""

return numpy.broadcast_shapes(*args)


# pylint: disable=redefined-outer-name
def broadcast_to(array, /, shape, subok=False):
"""
Expand Down
3 changes: 1 addition & 2 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,8 +994,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
a_shape = a.shape
b_shape = b.shape

# TODO: replace with dpnp.broadcast_shapes once implemented
res_shape = numpy.broadcast_shapes(a_shape[:-1], b_shape[:-1])
res_shape = dpnp.broadcast_shapes(a_shape[:-1], b_shape[:-1])
if a_shape[:-1] != res_shape:
a = dpnp.broadcast_to(a, res_shape + (a_shape[-1],))
a_shape = a.shape
Expand Down
25 changes: 25 additions & 0 deletions tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,31 @@ def test_no_copy(self):
assert_array_equal(b, a)


class TestBroadcast:
@pytest.mark.parametrize(
"shape",
[
[(1,), (3,)],
[(1, 3), (3, 3)],
[(3, 1), (3, 3)],
[(1, 3), (3, 1)],
[(1, 1), (3, 3)],
[(1, 1), (1, 3)],
[(1, 1), (3, 1)],
[(1, 0), (0, 0)],
[(0, 1), (0, 0)],
[(1, 0), (0, 1)],
[(1, 1), (0, 0)],
[(1, 1), (1, 0)],
[(1, 1), (0, 1)],
],
)
def test_broadcast_shapes(self, shape):
expected = numpy.broadcast_shapes(*shape)
result = dpnp.broadcast_shapes(*shape)
assert_equal(result, expected)


class TestDelete:
@pytest.mark.parametrize(
"obj", [slice(0, 4, 2), 3, [2, 3]], ids=["slice", "int", "list"]
Expand Down

0 comments on commit 9086f45

Please sign in to comment.