Skip to content

Commit

Permalink
Lazy rolling_window
Browse files Browse the repository at this point in the history
  • Loading branch information
bouweandela committed Feb 22, 2024
1 parent 2b024aa commit d924d1a
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions lib/iris/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,34 +322,29 @@ def rolling_window(a, window=1, step=1, axis=-1):
See more at :doc:`/userguide/real_and_lazy_data`.
"""
# NOTE: The implementation of this function originates from
# https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011
if window < 1:
raise ValueError("`window` must be at least 1.")
if window > a.shape[axis]:
raise ValueError("`window` is too long.")
if step < 1:
raise ValueError("`step` must be at least 1.")
axis = axis % a.ndim
num_windows = (a.shape[axis] - window + step) // step
shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :]
strides = (
a.strides[:axis]
+ (step * a.strides[axis], a.strides[axis])
+ a.strides[axis + 1 :]
array_module = da if isinstance(a, da.Array) else np
steps = tuple(
slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim)
)
rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
if ma.isMaskedArray(a):
mask = ma.getmaskarray(a)
strides = (
mask.strides[:axis]
+ (step * mask.strides[axis], mask.strides[axis])
+ mask.strides[axis + 1 :]
)
rw = ma.array(
rw,
mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides),
)
rw = array_module.lib.stride_tricks.sliding_window_view(
a,
window_shape=window,
axis=axis,
)[steps]
if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray):
mask = array_module.lib.stride_tricks.sliding_window_view(
array_module.ma.getmaskarray(a),
window_shape=window,
axis=axis,
)[steps]
rw = array_module.ma.masked_array(rw, mask)
return rw


Expand Down

0 comments on commit d924d1a

Please sign in to comment.