Skip to content

Commit

Permalink
Make sure n_steps_default >= 3
Browse files Browse the repository at this point in the history
  • Loading branch information
NoraLoose committed Nov 12, 2024
1 parent d3d52c9 commit ab564d3
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions gcm_filters/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@ def _taper_target(target_spec: TargetSpec):
}


def _compute_n_steps_default(
ndim, filter_shape, filter_scale, dx_min, transition_width
):
"""Compute the default number of steps for 1D or 2D filters based on provided parameters."""

n_steps_factor = filter_params[filter_shape][ndim]["offset"] + filter_params[
filter_shape
][ndim]["factor"] * (
(np.pi / transition_width) ** filter_params[filter_shape][ndim]["exponent"]
)

filter_factor = filter_scale / dx_min

n_steps_default = max(np.ceil(n_steps_factor * filter_factor).astype(int), 3)

return n_steps_default


class FilterSpec(NamedTuple):
n_steps: int
s_max: float
Expand Down Expand Up @@ -332,20 +350,19 @@ def __post_init__(self):
raise ValueError(f"Transition width must be > 1.")

# Get default number of steps
filter_factor = self.filter_scale / self.dx_min
if self.ndim > 2:
if self.n_steps < 3:
raise ValueError(f"When ndim > 2, you must set n_steps manually")
else:
n_steps_default = self.n_steps # For ndim>2 we don't have a default
else:
n_steps_factor = filter_params[self.filter_shape][self.ndim][
"offset"
] + filter_params[self.filter_shape][self.ndim]["factor"] * (
(np.pi / self.transition_width)
** filter_params[self.filter_shape][self.ndim]["exponent"]
n_steps_default = _compute_n_steps_default(
self.ndim,
self.filter_shape,
self.filter_scale,
self.dx_min,
self.transition_width,
)
n_steps_default = np.ceil(n_steps_factor * filter_factor).astype(int)

# Set n_steps if needed and issue n_step warning, if needed
if self.n_steps < 3:
Expand Down

0 comments on commit ab564d3

Please sign in to comment.