I think the function interpolate in grid.py is not safe against out-of-bounds grid indices.
For instance, domain_lo or domain_hi can fall outside the range of the grid shape.
I made the following changes and they fixed the bugs, but I am not sure whether this is the ideal change.
def interpolate(self, values, state):
    """Interpolates `values` (possibly multidimensional per node) defined over the grid at the given `state`."""
    # Check whether state is in the domain.
    out_of_range_status = jnp.logical_or(state < self.domain.lo, state > self.domain.hi)
    # Raise if any dimension is out of range and that dimension is not periodic.
    if jnp.any(jnp.logical_and(out_of_range_status, jnp.logical_not(self._is_periodic_dim))):
        raise ValueError("state is out of the domain")
    position = (state - self.domain.lo) / jnp.array(self.spacings)
    index_lo = jnp.floor(position).astype(jnp.int32)
    # Where index_lo is at or beyond shape - 1, set it to shape - 2 so that index_hi stays on the grid.
    index_lo = jnp.where(index_lo >= np.array(self.shape) - 1, np.array(self.shape) - 2, index_lo)
    index_hi = index_lo + 1
    weight_hi = position - index_lo
    weight_lo = 1 - weight_hi
    index_lo, index_hi = tuple(
        jnp.where(self._is_periodic_dim, index % np.array(self.shape), jnp.clip(index, 0, np.array(self.shape)))
        for index in (index_lo, index_hi))
    weight = functools.reduce(lambda x, y: x * y, jnp.ix_(*jnp.stack([weight_lo, weight_hi], -1)))
    # TODO: Double-check numerical stability here and/or switch to `tuple`s and `itertools.product` for clarity.
    return jnp.sum(
        weight[(...,) + (np.newaxis,) * (values.ndim - self.ndim)] *
        values[jnp.ix_(*jnp.stack([index_lo, index_hi], -1))], list(range(self.ndim)))
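For comparison, here is a minimal pure-Python sketch of the same index and weight logic along a single axis. It is not the grid.py implementation: the helper names (axis_indices_and_weights, interpolate_1d) and the scalar arguments are hypothetical, and it assumes at least two nodes placed at lo + i * spacing. Instead of raising, it clamps out-of-domain states to the nearest boundary on non-periodic axes and wraps indices on periodic ones, so the index_lo adjustment never touches a periodic axis. Clamping also avoids a Python `if` on array values, which would fail at trace time if interpolate were called under jax.jit.

```python
import math


def axis_indices_and_weights(x, lo, spacing, n, periodic):
    """Return (index_lo, index_hi, weight_lo, weight_hi) for one grid axis.

    Hypothetical helper for illustration only; not part of grid.py.
    Assumes n >= 2 nodes located at lo + i * spacing for i in range(n).
    """
    position = (x - lo) / spacing
    if periodic:
        # Periodic axis: node n is identified with node 0, so wrap both indices.
        index_lo = math.floor(position)
        weight_hi = position - index_lo
        return index_lo % n, (index_lo + 1) % n, 1.0 - weight_hi, weight_hi
    # Non-periodic axis: clamp the position onto [0, n - 1] so that an
    # out-of-domain state takes the nearest boundary value.
    position = min(max(position, 0.0), float(n - 1))
    # Keep index_lo <= n - 2 so that index_hi = index_lo + 1 <= n - 1.
    index_lo = min(math.floor(position), n - 2)
    weight_hi = position - index_lo
    return index_lo, index_lo + 1, 1.0 - weight_hi, weight_hi


def interpolate_1d(values, x, lo, spacing, periodic=False):
    """Linearly interpolate a list of node values at x along one axis."""
    i_lo, i_hi, w_lo, w_hi = axis_indices_and_weights(
        x, lo, spacing, len(values), periodic)
    return w_lo * values[i_lo] + w_hi * values[i_hi]
```

With values = [0.0, 10.0, 20.0, 30.0], lo = 0.0 and spacing = 1.0, a state exactly on the upper boundary (x = 3.0) maps to indices (2, 3) with weights (0.0, 1.0), so no index past the last node is ever used. Whether silently clamping or raising an error is the better behavior for out-of-domain states is a design choice for the maintainers.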