Grid bound handling in interpolate #13

Open
ChoiJangho opened this issue May 13, 2024 · 0 comments

@ChoiJangho

Hi Ed,

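I think the `interpolate` function in `grid.py` does not guard against out-of-bounds grid indices.
For instance, `index_lo` or `index_hi` can end up outside the grid's valid index range (e.g. `index_hi` equals the number of nodes along a dimension for a state on `domain.hi`).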
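A minimal 1-D sketch of this, assuming a hypothetical non-periodic dimension with 5 nodes on [0, 1] (nodes at both endpoints, so the spacing is 0.25). It copies only the index arithmetic from `interpolate` without any boundary shift (floor, add one, then `jnp.clip(index, 0, shape)`), not the full grid class; `unpatched_indices` is an illustrative helper, not part of `grid.py`.

```python
import jax.numpy as jnp

# Hypothetical non-periodic dimension: 5 nodes on [0, 1], so the spacing is 0.25.
lo, n = 0.0, 5
spacing = 0.25


def unpatched_indices(state):
    """Mirrors the index computation in `interpolate` for one non-periodic dimension, without a boundary shift."""
    position = (state - lo) / spacing
    index_lo = jnp.floor(position).astype(jnp.int32)
    index_hi = index_lo + 1
    # The upper clip bound is `n` rather than `n - 1`, so the nonexistent index `n` survives clipping.
    return jnp.clip(index_lo, 0, n), jnp.clip(index_hi, 0, n)


lo_idx, hi_idx = unpatched_indices(jnp.float32(1.0))  # a state exactly on domain.hi
print(int(lo_idx), int(hi_idx))  # 4 5: valid node indices are only 0..4
```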
I made the following changes, which fixed the errors I was seeing, but I'm not sure this is the ideal fix.

```python
# Replacement for `interpolate` in grid.py (a method of the grid class); relies on that
# module's existing imports: functools, numpy as np, jax.numpy as jnp.
def interpolate(self, values, state):
    """Interpolates `values` (possibly multidimensional per node) defined over the grid at the given `state`."""
    shape = np.array(self.shape)
    # Check whether `state` lies inside the domain; only non-periodic dimensions can be out of range.
    # Note: a Python `if` on array values needs concrete inputs, so this check fails under `jax.jit`.
    out_of_range_status = jnp.logical_or(state < self.domain.lo, state > self.domain.hi)
    if jnp.any(jnp.logical_and(out_of_range_status, jnp.logical_not(self._is_periodic_dim))):
        raise ValueError("state is out of the domain")

    position = (state - self.domain.lo) / jnp.array(self.spacings)
    index_lo = jnp.floor(position).astype(jnp.int32)
    # On non-periodic dimensions, a state on the upper boundary gives index_lo == shape - 1;
    # move it back to shape - 2 so that index_hi == shape - 1 stays in range.
    # Periodic dimensions are left as-is so that they wrap around below.
    index_lo = jnp.where(
        jnp.logical_and(jnp.logical_not(self._is_periodic_dim), index_lo >= shape - 1), shape - 2, index_lo)
    index_hi = index_lo + 1

    weight_hi = position - index_lo
    weight_lo = 1 - weight_hi
    # Wrap indices on periodic dimensions; clip to the valid range [0, shape - 1] elsewhere.
    index_lo, index_hi = tuple(
        jnp.where(self._is_periodic_dim, index % shape, jnp.clip(index, 0, shape - 1))
        for index in (index_lo, index_hi))
    weight = functools.reduce(lambda x, y: x * y, jnp.ix_(*jnp.stack([weight_lo, weight_hi], -1)))
    # TODO: Double-check numerical stability here and/or switch to `tuple`s and `itertools.product` for clarity.
    return jnp.sum(
        weight[(...,) + (np.newaxis,) * (values.ndim - self.ndim)] *
        values[jnp.ix_(*jnp.stack([index_lo, index_hi], -1))], list(range(self.ndim)))
```
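To check the boundary logic outside the grid class, here is a standalone sketch. `interpolate_sketch` is a hypothetical helper (not part of `grid.py`); its `lo`, `spacings`, `shape`, and `is_periodic` arguments stand in for `self.domain.lo`, `self.spacings`, `self.shape`, and `self._is_periodic_dim`, and the domain check is omitted. The shift back from the upper boundary is applied only to non-periodic dimensions. The test grid is an assumption: dimension 0 is non-periodic with 5 nodes on [0, 1] (spacing 0.25), and dimension 1 is periodic with 4 nodes on [0, 4) (spacing 1.0).

```python
import functools

import jax.numpy as jnp
import numpy as np


def interpolate_sketch(values, state, lo, spacings, shape, is_periodic):
    """Hypothetical standalone version of the boundary-safe `interpolate` logic (no domain check)."""
    ndim = len(shape)
    shape = np.array(shape)
    position = (state - lo) / spacings
    index_lo = jnp.floor(position).astype(jnp.int32)
    # Shift the lower index back from the upper boundary on non-periodic dimensions only.
    index_lo = jnp.where(
        jnp.logical_and(jnp.logical_not(is_periodic), index_lo >= shape - 1), shape - 2, index_lo)
    index_hi = index_lo + 1
    weight_hi = position - index_lo
    weight_lo = 1 - weight_hi
    # Wrap periodic dimensions; clip the others to the valid range [0, shape - 1].
    index_lo, index_hi = tuple(
        jnp.where(is_periodic, index % shape, jnp.clip(index, 0, shape - 1))
        for index in (index_lo, index_hi))
    # Multilinear weights: outer product of the per-dimension (weight_lo, weight_hi) pairs.
    weight = functools.reduce(lambda x, y: x * y, jnp.ix_(*jnp.stack([weight_lo, weight_hi], -1)))
    corners = values[jnp.ix_(*jnp.stack([index_lo, index_hi], -1))]
    return jnp.sum(weight[(...,) + (np.newaxis,) * (values.ndim - ndim)] * corners, tuple(range(ndim)))


# Hypothetical grid: dim 0 non-periodic (5 nodes on [0, 1]), dim 1 periodic (4 nodes on [0, 4)).
values = jnp.arange(20.0).reshape(5, 4)  # node (i, j) holds 4 * i + j
lo = jnp.array([0.0, 0.0])
spacings = jnp.array([0.25, 1.0])
shape = (5, 4)
is_periodic = np.array([False, True])

print(float(interpolate_sketch(values, jnp.array([1.0, 0.0]), lo, spacings, shape, is_periodic)))  # 16.0
print(float(interpolate_sketch(values, jnp.array([1.0, 3.5]), lo, spacings, shape, is_periodic)))  # 17.5
```

The first state sits on the non-periodic upper boundary and returns the last node's value (16) exactly. The second also sits between the last and first nodes of the periodic dimension and averages them (19 and 16); that only works because the boundary shift is skipped on periodic dimensions, where shifting would turn the wrap-around into an extrapolation.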