matmul raises IndexError exception with input shapes (2, 1, 2) and (2,) #2264

Closed
antonwolfy opened this issue Jan 15, 2025 · 0 comments · Fixed by #2278

The example below raises an `IndexError`:

```python
import dpnp, numpy

dpnp.__version__
# Out: '0.17.0dev4+3.g498e705d848.dirty'

a = dpnp.ones((2, 1, 2))
b = dpnp.ones((2,))

a @ b
```
```
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[10], line 1
----> 1 a @ b

File ~/code/dpnp/dpnp/dpnp_array.py:489, in dpnp_array.__matmul__(self, other)
    487 def __matmul__(self, other):
    488     """Return ``self@value``."""
--> 489     return dpnp.matmul(self, other)

File ~/code/dpnp/dpnp/dpnp_iface_linearalgebra.py:851, in matmul(x1, x2, out, casting, order, dtype, subok, signature, axes, axis)
    846 if axis is not None:
    847     raise NotImplementedError(
    848         "axis keyword argument is only supported by its default value."
    849     )
--> 851 return dpnp_matmul(
    852     x1,
    853     x2,
    854     out=out,
    855     casting=casting,
    856     order=order,
    857     dtype=dtype,
    858     axes=axes,
    859 )

File ~/code/dpnp/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py:951, in dpnp_matmul(x1, x2, out, casting, order, dtype, axes)
    949         else:  # call_flag == "gemm_batch"
    950             assert call_flag == "gemm_batch"
--> 951             result = _gemm_batch_matmul(
    952                 exec_q,
    953                 x1,
    954                 x2,
    955                 result,
    956             )
    958 if NumPy_special_behavior:
    959     result = dpnp.tile(result, out.shape)

File ~/code/dpnp/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py:373, in _gemm_batch_matmul(exec_q, x1, x2, res)
    371 x2_shape = x2.shape
    372 x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1]))
--> 373 x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1]))
    374 orig_shape = res.shape
    375 res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1]))

IndexError: tuple index out of range
```
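The last frame points at the cause: `_gemm_batch_matmul` reshapes both operands to 3-D using `x2_shape[-2]` and `x2_shape[-1]`, so `x2` is apparently still 1-D (shape `(2,)`) when it reaches that function, and `x2_shape[-2]` does not exist. The sketch below reproduces that failure in plain Python and computes the result shape that NumPy's `matmul` rules give for these inputs. `expected_matmul_shape` is a hypothetical helper written for illustration; it is not part of dpnp or NumPy.

```python
import numpy as np

# The failing line indexes the second-to-last dimension of x2's shape,
# which a 1-D shape does not have.
x2_shape = (2,)
try:
    x2_shape[-2]
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: tuple index out of range


def expected_matmul_shape(shape1, shape2):
    """Result shape of matmul under NumPy's rules (hypothetical helper)."""
    if len(shape1) == 0 or len(shape2) == 0:
        raise ValueError("matmul does not accept scalar operands")
    # A 1-D first operand is promoted to a row vector (1, n).
    s1 = (1,) + tuple(shape1) if len(shape1) == 1 else tuple(shape1)
    # A 1-D second operand is promoted to a column vector (n, 1).
    s2 = tuple(shape2) + (1,) if len(shape2) == 1 else tuple(shape2)
    if s1[-1] != s2[-2]:
        raise ValueError(f"core dimension mismatch: {s1[-1]} != {s2[-2]}")
    # Leading (batch) dimensions broadcast against each other.
    batch = np.broadcast_shapes(s1[:-2], s2[:-2])
    shape = batch + (s1[-2], s2[-1])
    # The dimensions added during promotion are removed again.
    if len(shape1) == 1:
        shape = shape[:-2] + shape[-1:]
    if len(shape2) == 1:
        shape = shape[:-1]
    return shape


print(expected_matmul_shape((2, 1, 2), (2,)))  # (2, 1)
```

For the shapes in this issue the expected result is `(2, 1)`, which matches the NumPy output shown below.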

```python
# The same operation works with NumPy:
na, nb = a.asnumpy(), b.asnumpy()
na @ nb
# Out:
# array([[2.],
#        [2.]])
```
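One way to avoid the out-of-range index is to promote a 1-D operand to 2-D before the batched reshape, following NumPy's definition of `matmul` for 1-D inputs, and to drop the added axis afterwards. The sketch below assumes NumPy as a stand-in for dpnp and uses `np.einsum` in place of the batched GEMM call; `batched_matmul_sketch` is a hypothetical name, and this is not the change merged in #2278.

```python
import math

import numpy as np


def batched_matmul_sketch(x1, x2):
    """matmul via a flattened batch of 2-D products (illustrative only)."""
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    x1_is_1d = x1.ndim == 1
    x2_is_1d = x2.ndim == 1
    # Promote 1-D operands so that shape[-2] and shape[-1] always exist.
    if x1_is_1d:
        x1 = x1[np.newaxis, :]  # (n,) -> (1, n)
    if x2_is_1d:
        x2 = x2[:, np.newaxis]  # (n,) -> (n, 1)

    # Broadcast the batch dimensions so both operands share one batch shape.
    batch = np.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
    x1 = np.broadcast_to(x1, batch + x1.shape[-2:])
    x2 = np.broadcast_to(x2, batch + x2.shape[-2:])

    # Flatten to 3-D stacks, like the reshape in _gemm_batch_matmul,
    # but with an explicit batch size instead of -1.
    nbatch = math.prod(batch)
    a = x1.reshape((nbatch,) + x1.shape[-2:])
    b = x2.reshape((nbatch,) + x2.shape[-2:])
    # One 2-D product per batch entry (stand-in for a batched GEMM).
    res = np.einsum("bij,bjk->bik", a, b)
    res = res.reshape(batch + res.shape[-2:])

    # Drop the axes that were added during promotion.
    if x1_is_1d and x2_is_1d:
        res = res[..., 0, 0]  # (..., 1, 1) -> (...)
    elif x1_is_1d:
        res = res[..., 0, :]  # (..., 1, p) -> (..., p)
    elif x2_is_1d:
        res = res[..., :, 0]  # (..., m, 1) -> (..., m)
    return res


a = np.ones((2, 1, 2))
b = np.ones((2,))
print(batched_matmul_sketch(a, b).shape)  # (2, 1)
```

Promoting before the reshape keeps the batched code path unchanged: every operand it sees has at least two dimensions, so indexing `shape[-2]` is always valid.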
antonwolfy changed the title from "matmul raises unexpected IndexError exception with input shapes (2, 1, 2) and (2,)" to "matmul raises IndexError exception with input shapes (2, 1, 2) and (2,)" on Jan 15, 2025
vtavana added a commit that referenced this issue Jan 24, 2025
vtavana mentioned this issue Jan 24, 2025
github-actions bot added a commit that referenced this issue Jan 24, 2025
github-actions bot added a commit that referenced this issue Jan 26, 2025
github-actions bot added a commit that referenced this issue Jan 28, 2025
github-actions bot added a commit that referenced this issue Jan 29, 2025