From b839c404e1a7d3f126ef014f0ad5d5b9ed22bf61 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad <120411540+vtavana@users.noreply.github.com> Date: Thu, 21 Nov 2024 09:19:59 -0600 Subject: [PATCH] simplify `dpnp.average` implementation (#2189) * simplify dpnp.average implementation * address comments --- dpnp/dpnp_iface_statistics.py | 64 ++++++++++++++++------------------- 1 file changed, 29 insertions(+), 35 deletions(-) diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index 83473e11a5b..c266f7c397e 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -47,18 +47,11 @@ import dpnp # pylint: disable=no-name-in-module -from .dpnp_algo import ( - dpnp_correlate, -) +from .dpnp_algo import dpnp_correlate from .dpnp_array import dpnp_array -from .dpnp_utils import ( - call_origin, - get_usm_allocations, -) +from .dpnp_utils import call_origin, get_usm_allocations from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call -from .dpnp_utils.dpnp_utils_statistics import ( - dpnp_cov, -) +from .dpnp_utils.dpnp_utils_statistics import dpnp_cov __all__ = [ "amax", @@ -276,60 +269,61 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=False): """ dpnp.check_supported_arrays_type(a) + usm_type, exec_q = get_usm_allocations([a, weights]) + if weights is None: avg = dpnp.mean(a, axis=axis, keepdims=keepdims) scl = dpnp.asanyarray( avg.dtype.type(a.size / avg.size), - usm_type=a.usm_type, - sycl_queue=a.sycl_queue, + usm_type=usm_type, + sycl_queue=exec_q, ) else: - if not isinstance(weights, (dpnp_array, dpt.usm_ndarray)): - wgt = dpnp.asanyarray( - weights, usm_type=a.usm_type, sycl_queue=a.sycl_queue + if not dpnp.is_supported_array_type(weights): + weights = dpnp.asarray( + weights, usm_type=usm_type, sycl_queue=exec_q ) - else: - get_usm_allocations([a, weights]) - wgt = weights - if not dpnp.issubdtype(a.dtype, dpnp.inexact): + a_dtype = a.dtype + if not dpnp.issubdtype(a_dtype, dpnp.inexact): default_dtype = dpnp.default_float_type(a.device) - result_dtype = dpnp.result_type(a.dtype, wgt.dtype, default_dtype) + res_dtype = dpnp.result_type(a_dtype, weights.dtype, default_dtype) else: - result_dtype = dpnp.result_type(a.dtype, wgt.dtype) + res_dtype = dpnp.result_type(a_dtype, weights.dtype) # Sanity checks - if a.shape != wgt.shape: + wgt_shape = weights.shape + a_shape = a.shape + if a_shape != wgt_shape: if axis is None: raise TypeError( "Axis must be specified when shapes of input array and " "weights differ." ) - if wgt.ndim != 1: + if weights.ndim != 1: raise TypeError( "1D weights expected when shapes of input array and " "weights differ." ) - if wgt.shape[0] != a.shape[axis]: + if wgt_shape[0] != a_shape[axis]: raise ValueError( "Length of weights not compatible with specified axis." ) - # setup wgt to broadcast along axis - wgt = dpnp.broadcast_to(wgt, (a.ndim - 1) * (1,) + wgt.shape) - wgt = wgt.swapaxes(-1, axis) + # setup weights to broadcast along axis + weights = dpnp.broadcast_to( + weights, (a.ndim - 1) * (1,) + wgt_shape + ) + weights = weights.swapaxes(-1, axis) - scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims) + scl = weights.sum(axis=axis, dtype=res_dtype, keepdims=keepdims) if dpnp.any(scl == 0.0): raise ZeroDivisionError("Weights sum to zero, can't be normalized") - # result_datatype - avg = ( - dpnp.multiply(a, wgt).sum( - axis=axis, dtype=result_dtype, keepdims=keepdims - ) - / scl + avg = dpnp.multiply(a, weights).sum( + axis=axis, dtype=res_dtype, keepdims=keepdims ) + avg /= scl if returned: if scl.shape != avg.shape: @@ -556,7 +550,7 @@ def cov( """ - if not isinstance(m, (dpnp_array, dpt.usm_ndarray)): + if not dpnp.is_supported_array_type(m): pass elif m.ndim > 2: pass