Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permit "same_kind" casting for element-wise in-place operators #2170

Merged
merged 12 commits into from
Jan 11, 2025
Merged
18 changes: 14 additions & 4 deletions dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,20 @@ def __call__(
"as an argument, but both were provided."
)

x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

if (
isinstance(x1, dpnp_array)
and x1 is out
and order == "K"
and dtype is None
):
# in-place operation
super()._inplace_op(x1_usm, x2_usm)
return x1

if order is None:
order = "K"
elif order in "afkcAFKC":
Expand All @@ -344,9 +358,6 @@ def __call__(
"order must be one of 'C', 'F', 'A', or 'K' (got '{order}')"
)

x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)

if dtype is not None:
if dpnp.isscalar(x1):
x1_usm = dpt.asarray(
Expand All @@ -368,7 +379,6 @@ def __call__(
x1_usm = dpt.astype(x1_usm, dtype, copy=False)
x2_usm = dpt.astype(x2_usm, dtype, copy=False)

out_usm = None if out is None else dpnp.get_usm_ndarray(out)
res_usm = super().__call__(x1_usm, x2_usm, out=out_usm, order=order)

if out is not None and isinstance(out, dpnp_array):
Expand Down
2 changes: 1 addition & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def __imatmul__(self, other):
axes = [(-2, -1), (-2, -1), (-2, -1)]

try:
dpnp.matmul(self, other, out=self, axes=axes)
dpnp.matmul(self, other, out=self, dtype=self.dtype, axes=axes)
except AxisError:
# AxisError should indicate that the axes argument didn't work out
# which should mean the second operand not being 2 dimensional.
Expand Down
5 changes: 5 additions & 0 deletions dpnp/dpnp_iface_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def binary_repr(num, width=None):
ti._bitwise_and_result_type,
ti._bitwise_and,
_BITWISE_AND_DOCSTRING,
binary_inplace_fn=ti._bitwise_and_inplace,
)


Expand Down Expand Up @@ -285,6 +286,7 @@ def binary_repr(num, width=None):
ti._bitwise_or_result_type,
ti._bitwise_or,
_BITWISE_OR_DOCSTRING,
binary_inplace_fn=ti._bitwise_or_inplace,
)


Expand Down Expand Up @@ -366,6 +368,7 @@ def binary_repr(num, width=None):
ti._bitwise_xor_result_type,
ti._bitwise_xor,
_BITWISE_XOR_DOCSTRING,
binary_inplace_fn=ti._bitwise_xor_inplace,
)


Expand Down Expand Up @@ -518,6 +521,7 @@ def binary_repr(num, width=None):
ti._bitwise_left_shift_result_type,
ti._bitwise_left_shift,
_LEFT_SHIFT_DOCSTRING,
binary_inplace_fn=ti._bitwise_left_shift_inplace,
)

bitwise_left_shift = left_shift # bitwise_left_shift is an alias for left_shift
Expand Down Expand Up @@ -595,6 +599,7 @@ def binary_repr(num, width=None):
ti._bitwise_right_shift_result_type,
ti._bitwise_right_shift,
_RIGHT_SHIFT_DOCSTRING,
binary_inplace_fn=ti._bitwise_right_shift_inplace,
)

# bitwise_right_shift is an alias for right_shift
Expand Down
4 changes: 1 addition & 3 deletions dpnp/dpnp_iface_trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2450,7 +2450,5 @@ def unwrap(p, discont=None, axis=-1, *, period=2 * dpnp.pi):

up = dpnp.astype(p, dtype=dt, copy=True)
up[slice1] = p[slice1]
# TODO: replace, once dpctl-1757 resolved
# up[slice1] += ph_correct.cumsum(axis=axis)
up[slice1] += ph_correct.cumsum(axis=axis, dtype=dt)
up[slice1] += ph_correct.cumsum(axis=axis)
return up
Loading
Loading