diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..5f454fdf --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + groups: + github-actions: + patterns: + - '*' diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index 5c02a741..a5a7389b 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -168,7 +168,7 @@ def reducedfluxfractionFunc(z, nu, norm): @jax.jit def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0): - """Return radius R enclosing flux fraction alpha in unit of the scale radius r0 + """Return radius R enclosing flux fraction alpha in unit of the scale radius r0 Method: Solve F(R/r0=z)/Flux - alpha = 0 using bisection algorithm @@ -267,34 +267,34 @@ def scale_radius(self): def _r0(self): return self.scale_radius - @property + @lazy_property def _inv_r0(self): return 1.0 / self._r0 - @property + @lazy_property def _r0_sq(self): return self._r0 * self._r0 - @property + @lazy_property def _inv_r0_sq(self): return self._inv_r0 * self._inv_r0 - @property + @lazy_property @implements(_galsim.spergel.Spergel.half_light_radius) def half_light_radius(self): return self._r0 * calculateFluxRadius(0.5, self.nu) - @property + @lazy_property def _shootxnorm(self): """Normalization for photon shooting""" return 1.0 / (2.0 * jnp.pi * jnp.power(2.0, self.nu) * _gammap1(self.nu)) - @property + @lazy_property def _xnorm(self): """Normalization of xValue""" return self._shootxnorm * self.flux * self._inv_r0_sq - @property + @lazy_property def _xnorm0(self): """return z^nu K_nu(z) for z=0""" return jax.lax.select( @@ -303,11 +303,11 @@ def _xnorm0(self): @implements(_galsim.spergel.Spergel.calculateFluxRadius) def calculateFluxRadius(self, f): - return calculateFluxRadius(f, self.nu) + return self._r0 * calculateFluxRadius(f, self.nu) @implements(_galsim.spergel.Spergel.calculateIntegratedFlux) def calculateIntegratedFlux(self, r): - return fluxfractionFunc(r, self.nu, 0.0) + return fluxfractionFunc(r / self._r0, self.nu, 0.0) def __hash__(self): return hash( @@ -338,13 +338,13 @@ def __str__(self): s += ")" return s - @property + @lazy_property def _maxk(self): """(1+ (k r0)^2)^(-1-nu) = maxk_threshold""" res = jnp.power(self.gsparams.maxk_threshold, -1.0 / (1.0 + self.nu)) - 1.0 return jnp.sqrt(res) / self._r0 - @property + @lazy_property def _stepk(self): R = calculateFluxRadius(1.0 - self.gsparams.folding_threshold, self.nu) R *= self._r0 @@ -352,7 +352,7 @@ def _stepk(self): R = jnp.maximum(R, self.gsparams.stepk_minimum_hlr * self.half_light_radius) return jnp.pi / R - @property + @lazy_property def _max_sb(self): # from SBSpergelImpl.h return jnp.abs(self._xnorm) * self._xnorm0 diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 7fe51487..7ef21332 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -159,3 +159,56 @@ def _run(): dt = _run_benchmarks(benchmark, kind, _run) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_spergel_bench_conv(gsmod): + obj = gsmod.Spergel(nu=-0.6, scale_radius=5) + psf = gsmod.Gaussian(fwhm=0.9) + obj = gsmod.Convolve( + [obj, psf], + gsparams=gsmod.GSParams(minimum_fft_size=2048, maximum_fft_size=2048), + ) + return obj.drawImage(nx=50, ny=50, scale=0.2).array + + +_run_spergel_bench_conv_jit = jax.jit(partial(_run_spergel_bench_conv, jgs)) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_conv(benchmark, kind): + dt = _run_benchmarks( + benchmark, kind, lambda: _run_spergel_bench_conv_jit().block_until_ready() + ) + print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_spergel_bench_xvalue(gsmod): + obj = gsmod.Spergel(nu=-0.6, scale_radius=5) + return obj.drawImage(nx=1024, ny=1204, scale=0.05, method="no_pixel").array + + +_run_spergel_bench_xvalue_jit = jax.jit(partial(_run_spergel_bench_xvalue, jgs)) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_xvalue(benchmark, kind): + dt = _run_benchmarks( + benchmark, kind, lambda: _run_spergel_bench_xvalue_jit().block_until_ready() + ) + print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_spergel_bench_kvalue(gsmod): + obj = gsmod.Spergel(nu=-0.6, scale_radius=5) + return obj.drawKImage(nx=1024, ny=1204, scale=0.05).array + + +_run_spergel_bench_kvalue_jit = jax.jit(partial(_run_spergel_bench_kvalue, jgs)) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_kvalue(benchmark, kind): + dt = _run_benchmarks( + benchmark, kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready() + ) + print(f"time: {dt:0.4g} ms", end=" ") diff --git a/tests/jax/test_spergel_comp_galsim.py b/tests/jax/test_spergel_comp_galsim.py new file mode 100644 index 00000000..0c0d4f63 --- /dev/null +++ b/tests/jax/test_spergel_comp_galsim.py @@ -0,0 +1,157 @@ +import galsim as _galsim +import galsim as gs +import jax +import numpy as np +import pytest +from test_benchmarks import ( + _run_spergel_bench_conv, + _run_spergel_bench_conv_jit, + _run_spergel_bench_kvalue, + _run_spergel_bench_kvalue_jit, + _run_spergel_bench_xvalue, + _run_spergel_bench_xvalue_jit, +) + +import jax_galsim as jgs +from jax_galsim.core.testing import time_code_block + + +@pytest.mark.parametrize( + "attr", + [ + "nu", + "scale_radius", + "maxk", + pytest.param( + "stepk", + marks=pytest.mark.xfail( + reason="GalSim has a bug in its stepk routine. See https://github.com/GalSim-developers/GalSim/issues/1324" + ), + ), + "half_light_radius", + ], +) +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_properties(nu, scale_radius, attr): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + assert s_jgs.gsparams.folding_threshold == s_gs.gsparams.folding_threshold + assert s_jgs.gsparams.stepk_minimum_hlr == s_gs.gsparams.stepk_minimum_hlr + + np.testing.assert_allclose(getattr(s_jgs, attr), getattr(s_gs, attr), rtol=1e-5) + + +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_flux_radius(nu, scale_radius): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + np.testing.assert_allclose( + s_jgs.calculateFluxRadius(0.8), + s_gs.calculateFluxRadius(0.8), + rtol=1e-5, + ) + + +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_integ_flux(nu, scale_radius): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + np.testing.assert_allclose( + s_jgs.calculateIntegratedFlux(0.8), + s_gs.calculateIntegratedFlux(0.8), + rtol=1e-5, + ) + + +@pytest.mark.parametrize("kx", np.linspace(-1, 1, 13).tolist()) +@pytest.mark.parametrize("ky", np.linspace(-1, 1, 13).tolist()) +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_kvalue(nu, scale_radius, kx, ky): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + np.testing.assert_allclose(s_jgs.kValue(kx, ky), s_gs.kValue(kx, ky), rtol=1e-5) + + +@pytest.mark.parametrize("x", np.linspace(-1, 1, 13).tolist()) +@pytest.mark.parametrize("y", np.linspace(-1, 1, 13).tolist()) +@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7]) +@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5]) +def test_spergel_comp_galsim_xvalue(nu, scale_radius, x, y): + s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius) + s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius) + + np.testing.assert_allclose(s_jgs.xValue(x, y), s_gs.xValue(x, y), rtol=1e-5) + + +def _run_time_test(kind, func): + if kind == "compile": + + def _run(): + jax.clear_caches() + func() + + elif kind == "run": + # run once to compile + func() + + def _run(): + func() + + else: + raise ValueError(f"kind={kind} not recognized") + + tot_time = 0 + for _ in range(3): + with time_code_block(quiet=True) as tr: + _run() + tot_time += tr.dt + + return tot_time / 3 + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_spergel_comp_galsim_perf_conv(benchmark, kind): + dt = _run_time_test(kind, lambda: _run_spergel_bench_conv_jit().block_until_ready()) + print(f"\njax-galsim time: {dt:0.4g} ms") + + dt = _run_time_test( + kind, + lambda: _run_spergel_bench_conv(_galsim), + ) + print(f" galsim time: {dt:0.4g} ms") + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_spergel_comp_galsim_perf_kvalue(benchmark, kind): + dt = _run_time_test( + kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready() + ) + print(f"\njax-galsim time: {dt:0.4g} ms") + + dt = _run_time_test( + kind, + lambda: _run_spergel_bench_kvalue(_galsim), + ) + print(f" galsim time: {dt:0.4g} ms") + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_spergel_comp_galsim_perf_xvalue(benchmark, kind): + dt = _run_time_test( + kind, lambda: _run_spergel_bench_xvalue_jit().block_until_ready() + ) + print(f"\njax-galsim time: {dt:0.4g} ms") + + dt = _run_time_test( + kind, + lambda: _run_spergel_bench_xvalue(_galsim), + ) + print(f" galsim time: {dt:0.4g} ms")