Skip to content

Commit

Permalink
Add inplace matrix multiplication (#2147)
Browse files Browse the repository at this point in the history
* Add support for inplace matrix multiplication

* Raise ValueError exception per axes keyword

* Add rendering documentation for reflected and inplace operations

* Exclude __reduce__ method from rendering documentation

* Align expection text for inplace matrix multiplication

* Split too long line
  • Loading branch information
antonwolfy authored Nov 5, 2024
1 parent 0517413 commit 687b8ea
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 61 deletions.
30 changes: 26 additions & 4 deletions doc/reference/ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ Comparison operators:
dpnp.ndarray.__eq__
dpnp.ndarray.__ne__

Truth value of an array (:func:`bool()`):
Truth value of an array (:class:`bool() <bool>`):

.. autosummary::
:toctree: generated/
Expand All @@ -260,11 +260,11 @@ Truth value of an array (:func:`bool()`):

Truth-value testing of an array invokes
:meth:`dpnp.ndarray.__bool__`, which raises an error if the number of
elements in the array is larger than 1, because the truth value
elements in the array is not 1, because the truth value
of such arrays is ambiguous. Use :meth:`.any() <dpnp.ndarray.any>` and
:meth:`.all() <dpnp.ndarray.all>` instead to be clear about what is meant
in such cases. (If the number of elements is 0, the array evaluates
to ``False``.)
in such cases. (If you wish to check for whether an array is empty,
use for example ``.size > 0``.)


Unary operations:
Expand Down Expand Up @@ -300,6 +300,26 @@ Arithmetic:
dpnp.ndarray.__xor__


Arithmetic, reflected:

.. autosummary::
:toctree: generated/
:nosignatures:

dpnp.ndarray.__radd__
dpnp.ndarray.__rsub__
dpnp.ndarray.__rmul__
dpnp.ndarray.__rtruediv__
dpnp.ndarray.__rfloordiv__
dpnp.ndarray.__rmod__
dpnp.ndarray.__rpow__
dpnp.ndarray.__rlshift__
dpnp.ndarray.__rrshift__
dpnp.ndarray.__rand__
dpnp.ndarray.__ror__
dpnp.ndarray.__rxor__


Arithmetic, in-place:

.. autosummary::
Expand All @@ -326,6 +346,8 @@ Matrix Multiplication:
:toctree: generated/

dpnp.ndarray.__matmul__
dpnp.ndarray.__rmatmul__
dpnp.ndarray.__imatmul__


Special methods
Expand Down
43 changes: 42 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# *****************************************************************************

import dpctl.tensor as dpt
from dpctl.tensor._numpy_helper import AxisError

import dpnp

Expand Down Expand Up @@ -205,6 +206,7 @@ def __bool__(self):
return self._array_obj.__bool__()

# '__class__',
# `__class_getitem__`,

def __complex__(self):
return self._array_obj.__complex__()
Expand Down Expand Up @@ -335,6 +337,8 @@ def __getitem__(self, key):
res._array_obj = item
return res

# '__getstate__',

def __gt__(self, other):
"""Return ``self>value``."""
return dpnp.greater(self, other)
Expand All @@ -361,7 +365,31 @@ def __ilshift__(self, other):
dpnp.left_shift(self, other, out=self)
return self

# '__imatmul__',
def __imatmul__(self, other):
"""Return ``self@=value``."""

"""
Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast
if the result without `out` would have less dimensions than `a`.
Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the
case exactly when the second operand has both core dimensions.
We have to enforce this check by passing the correct `axes=`.
"""
if self.ndim == 1:
axes = [(-1,), (-2, -1), (-1,)]
else:
axes = [(-2, -1), (-2, -1), (-2, -1)]

try:
dpnp.matmul(self, other, out=self, axes=axes)
except AxisError:
# AxisError should indicate that the axes argument didn't work out
# which should mean the second operand not being 2 dimensional.
raise ValueError(
"inplace matrix multiplication requires the first operand to "
"have at least one and the second at least two dimensions."
)
return self

def __imod__(self, other):
"""Return ``self%=value``."""
Expand Down Expand Up @@ -469,9 +497,11 @@ def __pow__(self, other):
return dpnp.power(self, other)

def __radd__(self, other):
"""Return ``value+self``."""
return dpnp.add(other, self)

def __rand__(self, other):
"""Return ``value&self``."""
return dpnp.bitwise_and(other, self)

# '__rdivmod__',
Expand All @@ -483,40 +513,51 @@ def __repr__(self):
return dpt.usm_ndarray_repr(self._array_obj, prefix="array")

def __rfloordiv__(self, other):
"""Return ``value//self``."""
return dpnp.floor_divide(self, other)

def __rlshift__(self, other):
"""Return ``value<<self``."""
return dpnp.left_shift(other, self)

def __rmatmul__(self, other):
"""Return ``value@self``."""
return dpnp.matmul(other, self)

def __rmod__(self, other):
"""Return ``value%self``."""
return dpnp.remainder(other, self)

def __rmul__(self, other):
"""Return ``value*self``."""
return dpnp.multiply(other, self)

def __ror__(self, other):
"""Return ``value|self``."""
return dpnp.bitwise_or(other, self)

def __rpow__(self, other):
"""Return ``value**self``."""
return dpnp.power(other, self)

def __rrshift__(self, other):
"""Return ``value>>self``."""
return dpnp.right_shift(other, self)

def __rshift__(self, other):
"""Return ``self>>value``."""
return dpnp.right_shift(self, other)

def __rsub__(self, other):
"""Return ``value-self``."""
return dpnp.subtract(other, self)

def __rtruediv__(self, other):
"""Return ``value/self``."""
return dpnp.true_divide(other, self)

def __rxor__(self, other):
"""Return ``value^self``."""
return dpnp.bitwise_xor(other, self)

# '__setattr__',
Expand Down
8 changes: 4 additions & 4 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dpu
import numpy
from dpctl.tensor._numpy_helper import normalize_axis_tuple
from dpctl.tensor._numpy_helper import AxisError, normalize_axis_tuple
from dpctl.utils import ExecutionPlacementError

import dpnp
Expand Down Expand Up @@ -525,15 +525,15 @@ def _validate_internal(axes, i, ndim):
)

if len(axes) != 1:
raise ValueError(
raise AxisError(
f"Axes item {i} should be a tuple with a single element, or an integer."
)
else:
iter = 2
if not isinstance(axes, tuple):
raise TypeError(f"Axes item {i} should be a tuple.")
if len(axes) != 2:
raise ValueError(
raise AxisError(
f"Axes item {i} should be a tuple with 2 elements."
)

Expand Down Expand Up @@ -563,7 +563,7 @@ def _validate_internal(axes, i, ndim):

if x1_ndim == 1 and x2_ndim == 1:
if axes[2] != ():
raise TypeError("Axes item 2 should be an empty tuple.")
raise AxisError("Axes item 2 should be an empty tuple.")
elif x1_ndim == 1 or x2_ndim == 1:
axes[2] = _validate_internal(axes[2], 2, 1)
else:
Expand Down
Loading

0 comments on commit 687b8ea

Please sign in to comment.