Skip to content

Commit

Permalink
Fix rolling center (#406)
Browse files Browse the repository at this point in the history
* fix the np.roll shift used to center the output in the numpy engine

* fix center for step!=None

* squeeze the patch before counting its dimensions when deciding whether to use the Pandas engine

* fix test_dimension_order
  • Loading branch information
ahmadtourei authored Jul 3, 2024
1 parent a59dc35 commit 99f4468
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
11 changes: 5 additions & 6 deletions dascore/proc/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _pad_roll_array(self, data):
assert padded.shape == self.patch.data.shape
if self.center:
# roll array along axis to center
padded = np.roll(padded, -self.window // 2, axis=self.axis)
padded = np.roll(padded, -(num_nans // 2), axis=self.axis)
return padded

def apply(self, function):
Expand All @@ -105,9 +105,8 @@ def apply(self, function):
step_slice.append(slice(None, None))
# this accounts for NaNs that pad the start of the array.
start = self.get_start_index()
# start = (self.window - 1) % self.step
step_slice[self.axis] = slice(start, None, self.step)
# apply function, then pad with zeros and roll
# apply function, then pad with NaNs and roll
kwargs = self.func_kwargs
trimmed_slide_view = slide_view[tuple(step_slice)]
raw = function(trimmed_slide_view, axis=-1, **kwargs).astype(np.float64)
Expand Down Expand Up @@ -231,10 +230,10 @@ def rolling(
step
The window is evaluated at every step result, equivalent to slicing
at every step. If the step argument is not None, the result will
have a different shape than the input.
have a different shape than the input. Default None.
center
If False, set the window labels as the right edge of the window index.
If True, set the window labels as the center of the window index.
If True, set the window labels as the center of the window index. Default False.
engine
Determines how the rolling operations are applied. If None, try to
determine which will be fastest for a given step. Options are:
Expand Down Expand Up @@ -301,7 +300,7 @@ def _get_engine(step, engine, patch):
engines = {"numpy": _NumpyPatchRoller, "pandas": _PandasPatchRoller}
if cls := engines.get(engine):
return cls
if step < 10 and len(patch.dims) < 2:
if step < 10 and len(patch.squeeze().dims) < 2:
return _PandasPatchRoller
return _NumpyPatchRoller

Expand Down
2 changes: 2 additions & 0 deletions dasdasdasd
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
* fix_rolling_center
master
40 changes: 34 additions & 6 deletions tests/test_proc/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,30 @@ def test_dist_dim(self, data: tuple[int, int], range_patch):
def test_center_same(self, range_patch):
"""Ensure center values are handled the same."""
dt = range_patch.get_coord("time").step
numpy_out = range_patch.rolling(time=13 * dt, center=True).sum()
pandas_out = range_patch.rolling(time=13 * dt, center=True).sum()
numpy_out = range_patch.rolling(time=13 * dt, center=True, engine="numpy").sum()
pandas_out = range_patch.rolling(
time=13 * dt, center=True, engine="pandas"
).sum()
numpy_isnan = np.isnan(numpy_out.data)
pandas_isnan = np.isnan(pandas_out.data)
assert np.all(np.equal(numpy_isnan, pandas_isnan))
assert np.all(
np.equal(numpy_isnan, pandas_isnan)
), "The NaN indices do not match"

def test_center_same_stepped(self, range_patch):
"""Ensure center values are handled the same by both engines when a step is given."""
dt = range_patch.get_coord("time").step
numpy_out = range_patch.rolling(
time=13 * dt, step=3 * dt, center=True, engine="numpy"
).sum()
pandas_out = range_patch.rolling(
time=13 * dt, step=3 * dt, center=True, engine="pandas"
).sum()
numpy_isnan = np.isnan(numpy_out.data)
pandas_isnan = np.isnan(pandas_out.data)
assert np.all(
np.equal(numpy_isnan, pandas_isnan)
), "The NaN indices do not match"

def test_dimension_order(self, range_patch):
"""Ensure the dimension order doesn't matter."""
Expand All @@ -251,7 +270,16 @@ def test_dimension_order(self, range_patch):
coord = patch.get_coord(dim)
step = coord.step
total_len = len(coord) - 2
kwargs = {dim: step * total_len, "step": total_len * step}
pandas_out = patch.rolling(**kwargs).mean().dropna(dim)
numpy_out = patch.rolling(**kwargs).mean().dropna(dim)
kwargs_pandas = {
dim: step * total_len,
"step": total_len * step,
"engine": "pandas",
}
kwargs_numpy = {
dim: step * total_len,
"step": total_len * step,
"engine": "numpy",
}
pandas_out = patch.rolling(**kwargs_pandas).mean().dropna(dim)
numpy_out = patch.rolling(**kwargs_numpy).mean().dropna(dim)
assert pandas_out == numpy_out

0 comments on commit 99f4468

Please sign in to comment.