Skip to content

Commit

Permalink
simplify dpnp.average implementation (#2189)
Browse files Browse the repository at this point in the history
* simplify dpnp.average implementation

* address comments
  • Loading branch information
vtavana authored Nov 21, 2024
1 parent ffd3829 commit b839c40
Showing 1 changed file with 29 additions and 35 deletions.
64 changes: 29 additions & 35 deletions dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,11 @@
import dpnp

# pylint: disable=no-name-in-module
from .dpnp_algo import (
dpnp_correlate,
)
from .dpnp_algo import dpnp_correlate
from .dpnp_array import dpnp_array
from .dpnp_utils import (
call_origin,
get_usm_allocations,
)
from .dpnp_utils import call_origin, get_usm_allocations
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
from .dpnp_utils.dpnp_utils_statistics import (
dpnp_cov,
)
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov

__all__ = [
"amax",
Expand Down Expand Up @@ -276,60 +269,61 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=False):
"""

dpnp.check_supported_arrays_type(a)
usm_type, exec_q = get_usm_allocations([a, weights])

if weights is None:
avg = dpnp.mean(a, axis=axis, keepdims=keepdims)
scl = dpnp.asanyarray(
avg.dtype.type(a.size / avg.size),
usm_type=a.usm_type,
sycl_queue=a.sycl_queue,
usm_type=usm_type,
sycl_queue=exec_q,
)
else:
if not isinstance(weights, (dpnp_array, dpt.usm_ndarray)):
wgt = dpnp.asanyarray(
weights, usm_type=a.usm_type, sycl_queue=a.sycl_queue
if not dpnp.is_supported_array_type(weights):
weights = dpnp.asarray(
weights, usm_type=usm_type, sycl_queue=exec_q
)
else:
get_usm_allocations([a, weights])
wgt = weights

if not dpnp.issubdtype(a.dtype, dpnp.inexact):
a_dtype = a.dtype
if not dpnp.issubdtype(a_dtype, dpnp.inexact):
default_dtype = dpnp.default_float_type(a.device)
result_dtype = dpnp.result_type(a.dtype, wgt.dtype, default_dtype)
res_dtype = dpnp.result_type(a_dtype, weights.dtype, default_dtype)
else:
result_dtype = dpnp.result_type(a.dtype, wgt.dtype)
res_dtype = dpnp.result_type(a_dtype, weights.dtype)

# Sanity checks
if a.shape != wgt.shape:
wgt_shape = weights.shape
a_shape = a.shape
if a_shape != wgt_shape:
if axis is None:
raise TypeError(
"Axis must be specified when shapes of input array and "
"weights differ."
)
if wgt.ndim != 1:
if weights.ndim != 1:
raise TypeError(
"1D weights expected when shapes of input array and "
"weights differ."
)
if wgt.shape[0] != a.shape[axis]:
if wgt_shape[0] != a_shape[axis]:
raise ValueError(
"Length of weights not compatible with specified axis."
)

# setup wgt to broadcast along axis
wgt = dpnp.broadcast_to(wgt, (a.ndim - 1) * (1,) + wgt.shape)
wgt = wgt.swapaxes(-1, axis)
# setup weights to broadcast along axis
weights = dpnp.broadcast_to(
weights, (a.ndim - 1) * (1,) + wgt_shape
)
weights = weights.swapaxes(-1, axis)

scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims)
scl = weights.sum(axis=axis, dtype=res_dtype, keepdims=keepdims)
if dpnp.any(scl == 0.0):
raise ZeroDivisionError("Weights sum to zero, can't be normalized")

# result_datatype
avg = (
dpnp.multiply(a, wgt).sum(
axis=axis, dtype=result_dtype, keepdims=keepdims
)
/ scl
avg = dpnp.multiply(a, weights).sum(
axis=axis, dtype=res_dtype, keepdims=keepdims
)
avg /= scl

if returned:
if scl.shape != avg.shape:
Expand Down Expand Up @@ -556,7 +550,7 @@ def cov(
"""

if not isinstance(m, (dpnp_array, dpt.usm_ndarray)):
if not dpnp.is_supported_array_type(m):
pass
elif m.ndim > 2:
pass
Expand Down

0 comments on commit b839c40

Please sign in to comment.