diff --git a/lib/iris/util.py b/lib/iris/util.py index 020b67783a2..f121cbcfee6 100644 --- a/lib/iris/util.py +++ b/lib/iris/util.py @@ -322,8 +322,6 @@ def rolling_window(a, window=1, step=1, axis=-1): See more at :doc:`/userguide/real_and_lazy_data`. """ - # NOTE: The implementation of this function originates from - # https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011 if window < 1: raise ValueError("`window` must be at least 1.") if window > a.shape[axis]: @@ -331,25 +329,22 @@ def rolling_window(a, window=1, step=1, axis=-1): if step < 1: raise ValueError("`step` must be at least 1.") axis = axis % a.ndim - num_windows = (a.shape[axis] - window + step) // step - shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :] - strides = ( - a.strides[:axis] - + (step * a.strides[axis], a.strides[axis]) - + a.strides[axis + 1 :] + array_module = da if isinstance(a, da.Array) else np + steps = tuple( + slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim) ) - rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) - if ma.isMaskedArray(a): - mask = ma.getmaskarray(a) - strides = ( - mask.strides[:axis] - + (step * mask.strides[axis], mask.strides[axis]) - + mask.strides[axis + 1 :] - ) - rw = ma.array( - rw, - mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides), - ) + rw = array_module.lib.stride_tricks.sliding_window_view( + a, + window_shape=window, + axis=axis, + )[steps] + if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray): + mask = array_module.lib.stride_tricks.sliding_window_view( + array_module.ma.getmaskarray(a), + window_shape=window, + axis=axis, + )[steps] + rw = array_module.ma.masked_array(rw, mask) return rw