diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 684de18d991..44ea9d3024f 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -1,4 +1,4 @@ -from math import sin, floor +from math import sin, floor, prod import numpy as np import pytest @@ -53,6 +53,18 @@ def time_points(grid, ranges, npoints, name='points', nt=10): return points +def time_grid_points(grid, name='points', nt=10): + """Create a SparseTimeFunction field with coordinates + filled in by all grid points""" + npoints = prod(grid.shape) + a = SparseTimeFunction(name=name, grid=grid, npoint=npoints, nt=nt) + dims = tuple([np.linspace(0., 1., d) for d in grid.shape]) + for i in range(len(grid.shape)): + a.coordinates.data[:,i] = np.meshgrid(*dims)[i].flatten() + + return a + + def a(shape=(11, 11)): grid = Grid(shape=shape) a = Function(name='a', grid=grid) @@ -417,6 +429,31 @@ def test_inject_time_shift(shape, coords, result, npoints=19): assert np.allclose(a.data[indices], result, rtol=1.e-5) +@pytest.mark.parametrize('shape, result, increment', [ + ((10, 10), 1., False), + ((10, 10), 5., True), + ((10, 10, 10), 1., False), + ((10, 10, 10), 5., True) +]) +def test_inject_time_increment(shape, result, increment): + """Test the increment option in the SparseTimeFunction's + injection method. The increment=False option is + expected to work only at points located on the grid, + where no interpolation needed. + """ + a = unit_box_time(shape=shape) + a.data[:] = 0. + p = time_grid_points(a.grid, name='points', nt=10) + + expr = p.inject(a, Float(1.), increment=increment) + + Operator(expr)(a=a) + + assert np.allclose(a.data, result*np.ones(a.grid.shape), rtol=1.e-5) + + + + @pytest.mark.parametrize('shape, coords, result', [ ((11, 11), [(.05, .95), (.45, .45)], 1.), ((11, 11, 11), [(.05, .95), (.45, .45), (.45, .45)], 0.5)