Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add benchmarks for spergel profile #125

Merged
merged 6 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"
groups:
github-actions:
patterns:
- '*'
26 changes: 13 additions & 13 deletions jax_galsim/spergel.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def reducedfluxfractionFunc(z, nu, norm):

@jax.jit
def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0):
"""Return radius R enclosing flux fraction alpha in unit of the scale radius r0
"""Return radius R enclosing flux fraction alpha in unit of the scale radius r0

Method: Solve F(R/r0=z)/Flux - alpha = 0 using bisection algorithm

Expand Down Expand Up @@ -267,34 +267,34 @@ def scale_radius(self):
def _r0(self):
return self.scale_radius

@property
@lazy_property
def _inv_r0(self):
return 1.0 / self._r0

@property
@lazy_property
def _r0_sq(self):
return self._r0 * self._r0

@property
@lazy_property
def _inv_r0_sq(self):
return self._inv_r0 * self._inv_r0

@property
@lazy_property
@implements(_galsim.spergel.Spergel.half_light_radius)
def half_light_radius(self):
return self._r0 * calculateFluxRadius(0.5, self.nu)

@property
@lazy_property
def _shootxnorm(self):
"""Normalization for photon shooting"""
return 1.0 / (2.0 * jnp.pi * jnp.power(2.0, self.nu) * _gammap1(self.nu))

@property
@lazy_property
def _xnorm(self):
"""Normalization of xValue"""
return self._shootxnorm * self.flux * self._inv_r0_sq

@property
@lazy_property
def _xnorm0(self):
"""return z^nu K_nu(z) for z=0"""
return jax.lax.select(
Expand All @@ -303,11 +303,11 @@ def _xnorm0(self):

@implements(_galsim.spergel.Spergel.calculateFluxRadius)
def calculateFluxRadius(self, f):
return calculateFluxRadius(f, self.nu)
return self._r0 * calculateFluxRadius(f, self.nu)

@implements(_galsim.spergel.Spergel.calculateIntegratedFlux)
def calculateIntegratedFlux(self, r):
return fluxfractionFunc(r, self.nu, 0.0)
return fluxfractionFunc(r / self._r0, self.nu, 0.0)

def __hash__(self):
return hash(
Expand Down Expand Up @@ -338,21 +338,21 @@ def __str__(self):
s += ")"
return s

@property
@lazy_property
def _maxk(self):
"""(1+ (k r0)^2)^(-1-nu) = maxk_threshold"""
res = jnp.power(self.gsparams.maxk_threshold, -1.0 / (1.0 + self.nu)) - 1.0
return jnp.sqrt(res) / self._r0

@property
@lazy_property
def _stepk(self):
R = calculateFluxRadius(1.0 - self.gsparams.folding_threshold, self.nu)
R *= self._r0
# Go to at least 5*hlr
R = jnp.maximum(R, self.gsparams.stepk_minimum_hlr * self.half_light_radius)
return jnp.pi / R

@property
@lazy_property
def _max_sb(self):
# from SBSpergelImpl.h
return jnp.abs(self._xnorm) * self._xnorm0
Expand Down
53 changes: 53 additions & 0 deletions tests/jax/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,56 @@ def _run():

dt = _run_benchmarks(benchmark, kind, _run)
print(f"time: {dt:0.4g} ms", end=" ")


def _run_spergel_bench_conv(gsmod):
obj = gsmod.Spergel(nu=-0.6, scale_radius=5)
psf = gsmod.Gaussian(fwhm=0.9)
obj = gsmod.Convolve(
[obj, psf],
gsparams=gsmod.GSParams(minimum_fft_size=2048, maximum_fft_size=2048),
)
return obj.drawImage(nx=50, ny=50, scale=0.2).array


_run_spergel_bench_conv_jit = jax.jit(partial(_run_spergel_bench_conv, jgs))


@pytest.mark.parametrize("kind", ["compile", "run"])
def test_benchmark_spergel_conv(benchmark, kind):
dt = _run_benchmarks(
benchmark, kind, lambda: _run_spergel_bench_conv_jit().block_until_ready()
)
print(f"time: {dt:0.4g} ms", end=" ")


def _run_spergel_bench_xvalue(gsmod):
obj = gsmod.Spergel(nu=-0.6, scale_radius=5)
return obj.drawImage(nx=1024, ny=1204, scale=0.05, method="no_pixel").array


_run_spergel_bench_xvalue_jit = jax.jit(partial(_run_spergel_bench_xvalue, jgs))


@pytest.mark.parametrize("kind", ["compile", "run"])
def test_benchmark_spergel_xvalue(benchmark, kind):
dt = _run_benchmarks(
benchmark, kind, lambda: _run_spergel_bench_xvalue_jit().block_until_ready()
)
print(f"time: {dt:0.4g} ms", end=" ")


def _run_spergel_bench_kvalue(gsmod):
obj = gsmod.Spergel(nu=-0.6, scale_radius=5)
return obj.drawKImage(nx=1024, ny=1204, scale=0.05).array


_run_spergel_bench_kvalue_jit = jax.jit(partial(_run_spergel_bench_kvalue, jgs))


@pytest.mark.parametrize("kind", ["compile", "run"])
def test_benchmark_spergel_kvalue(benchmark, kind):
dt = _run_benchmarks(
benchmark, kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready()
)
print(f"time: {dt:0.4g} ms", end=" ")
157 changes: 157 additions & 0 deletions tests/jax/test_spergel_comp_galsim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import galsim as _galsim
import galsim as gs
import jax
import numpy as np
import pytest
from test_benchmarks import (
_run_spergel_bench_conv,
_run_spergel_bench_conv_jit,
_run_spergel_bench_kvalue,
_run_spergel_bench_kvalue_jit,
_run_spergel_bench_xvalue,
_run_spergel_bench_xvalue_jit,
)

import jax_galsim as jgs
from jax_galsim.core.testing import time_code_block


@pytest.mark.parametrize(
"attr",
[
"nu",
"scale_radius",
"maxk",
pytest.param(
"stepk",
marks=pytest.mark.xfail(
reason="GalSim has a bug in its stepk routine. See https://github.com/GalSim-developers/GalSim/issues/1324"
),
),
"half_light_radius",
],
)
@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7])
@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5])
def test_spergel_comp_galsim_properties(nu, scale_radius, attr):
s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius)
s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius)

assert s_jgs.gsparams.folding_threshold == s_gs.gsparams.folding_threshold
assert s_jgs.gsparams.stepk_minimum_hlr == s_gs.gsparams.stepk_minimum_hlr

np.testing.assert_allclose(getattr(s_jgs, attr), getattr(s_gs, attr), rtol=1e-5)


@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7])
@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5])
def test_spergel_comp_galsim_flux_radius(nu, scale_radius):
s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius)
s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius)

np.testing.assert_allclose(
s_jgs.calculateFluxRadius(0.8),
s_gs.calculateFluxRadius(0.8),
rtol=1e-5,
)


@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7])
@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5])
def test_spergel_comp_galsim_integ_flux(nu, scale_radius):
s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius)
s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius)

np.testing.assert_allclose(
s_jgs.calculateIntegratedFlux(0.8),
s_gs.calculateIntegratedFlux(0.8),
rtol=1e-5,
)


@pytest.mark.parametrize("kx", np.linspace(-1, 1, 13).tolist())
@pytest.mark.parametrize("ky", np.linspace(-1, 1, 13).tolist())
@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7])
@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5])
def test_spergel_comp_galsim_kvalue(nu, scale_radius, kx, ky):
s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius)
s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius)

np.testing.assert_allclose(s_jgs.kValue(kx, ky), s_gs.kValue(kx, ky), rtol=1e-5)


@pytest.mark.parametrize("x", np.linspace(-1, 1, 13).tolist())
@pytest.mark.parametrize("y", np.linspace(-1, 1, 13).tolist())
@pytest.mark.parametrize("nu", [-0.6, -0.5, 0.0, 0.1, 0.5, 1.0, 1.1, 1.5, 2, 2.7])
@pytest.mark.parametrize("scale_radius", [0.5, 1.0, 1.5, 2.0, 2.5])
def test_spergel_comp_galsim_xvalue(nu, scale_radius, x, y):
s_jgs = jgs.Spergel(nu=nu, scale_radius=scale_radius)
s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius)

np.testing.assert_allclose(s_jgs.xValue(x, y), s_gs.xValue(x, y), rtol=1e-5)


def _run_time_test(kind, func):
if kind == "compile":

def _run():
jax.clear_caches()
func()

elif kind == "run":
# run once to compile
func()

def _run():
func()

else:
raise ValueError(f"kind={kind} not recognized")

tot_time = 0
for _ in range(3):
with time_code_block(quiet=True) as tr:
_run()
tot_time += tr.dt

return tot_time / 3


@pytest.mark.parametrize("kind", ["compile", "run"])
def test_spergel_comp_galsim_perf_conv(benchmark, kind):
dt = _run_time_test(kind, lambda: _run_spergel_bench_conv_jit().block_until_ready())
print(f"\njax-galsim time: {dt:0.4g} ms")

dt = _run_time_test(
kind,
lambda: _run_spergel_bench_conv(_galsim),
)
print(f" galsim time: {dt:0.4g} ms")


@pytest.mark.parametrize("kind", ["compile", "run"])
def test_spergel_comp_galsim_perf_kvalue(benchmark, kind):
dt = _run_time_test(
kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready()
)
print(f"\njax-galsim time: {dt:0.4g} ms")

dt = _run_time_test(
kind,
lambda: _run_spergel_bench_kvalue(_galsim),
)
print(f" galsim time: {dt:0.4g} ms")


@pytest.mark.parametrize("kind", ["compile", "run"])
def test_spergel_comp_galsim_perf_xvalue(benchmark, kind):
dt = _run_time_test(
kind, lambda: _run_spergel_bench_xvalue_jit().block_until_ready()
)
print(f"\njax-galsim time: {dt:0.4g} ms")

dt = _run_time_test(
kind,
lambda: _run_spergel_bench_xvalue(_galsim),
)
print(f" galsim time: {dt:0.4g} ms")
Loading