From 6c23227daea4f66f9f39be6437cabcfea297c1c4 Mon Sep 17 00:00:00 2001 From: Yucheng Zhang Date: Mon, 7 Aug 2023 07:01:58 -0400 Subject: [PATCH] WIP adding multiple snapshots observable --- pmwd/configuration.py | 3 ++ pmwd/nbody.py | 97 ++++++++++++++++++++++++++++++++------ pmwd/obs_util.py | 107 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 193 insertions(+), 14 deletions(-) create mode 100644 pmwd/obs_util.py diff --git a/pmwd/configuration.py b/pmwd/configuration.py index 5509fd27..6128a761 100644 --- a/pmwd/configuration.py +++ b/pmwd/configuration.py @@ -145,6 +145,9 @@ class Configuration: chunk_size: int = 2**24 + # observables + a_snapshots: Optional[Tuple[float]] = None + def __post_init__(self): if self._is_transforming(): return diff --git a/pmwd/nbody.py b/pmwd/nbody.py index a0bc2f9d..143477cb 100644 --- a/pmwd/nbody.py +++ b/pmwd/nbody.py @@ -3,10 +3,13 @@ from jax import value_and_grad, jit, vjp, custom_vjp import jax.numpy as jnp from jax.tree_util import tree_map +from jax.lax import cond from pmwd.boltzmann import growth from pmwd.cosmology import E2, H_deriv from pmwd.gravity import gravity +from pmwd.obs_util import interptcl, itp_prev_adj, itp_next_adj +from pmwd.particles import Particles def _G_D(a, cosmo, conf): @@ -183,11 +186,69 @@ def coevolve_init(a, ptcl, cosmo, conf): def observe(a_prev, a_next, ptcl, obsvbl, cosmo, conf): - pass + def itp(a, obsvbl): + snap = interptcl(obsvbl['ptcl_prev'], ptcl, a_prev, a_next, a, cosmo) + obsvbl['snapshots'][a] = snap + return obsvbl + + if conf.a_snapshots is not None: + for a in conf.a_snapshots: + obsvbl = cond(jnp.logical_and(a_prev < a, a <= a_next), + partial(itp, a), lambda *args: obsvbl, obsvbl) + + obsvbl['ptcl_prev'] = ptcl + + return obsvbl def observe_init(a, ptcl, obsvbl, cosmo, conf): - pass + # a dict to carry all observables and related useful information + obsvbl = {} + + # to carry the prev ptcl, starting with lpt ptcl + obsvbl['ptcl_prev'] = ptcl + + if conf.a_snapshots is not None: + # all output snapshots + obsvbl['snapshots'] = { + a_snap: Particles(ptcl.conf, ptcl.pmid, jnp.zeros_like(ptcl.disp), + vel=jnp.zeros_like(ptcl.vel)) + for a_snap in conf.a_snapshots + } + # the nbody a step of output snapshots, (,] + idx = jnp.searchsorted(conf.a_nbody, jnp.asarray(conf.a_snapshots), side='left') + obsvbl['snap_a_step'] = jnp.array((conf.a_nbody[idx-1], conf.a_nbody[idx])).T + + return obsvbl + + +def observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo_cot, conf): + + if conf.a_snapshots is not None: + for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']): + ptcl_cot, cosmo_cot = cond(a_step[1] == a_next, itp_next_adj, + lambda *args: (ptcl_cot, cosmo_cot), + ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap], + ptcl, a_step[0], a_step[1], a_snap, cosmo) + ptcl_cot, cosmo_cot = cond(a_step[1] == a_prev, itp_prev_adj, + lambda *args: (ptcl_cot, cosmo_cot), + ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap], + ptcl, a_step[0], a_step[1], a_snap, cosmo) + + return ptcl_cot, cosmo_cot + + +def observe_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo_cot, conf): + + if conf.a_snapshots is not None: + # check if the last ptcl is used in interpolation + for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']): + ptcl_cot, cosmo_cot = cond(a_step[1] == a, itp_next_adj, + lambda *args: (ptcl_cot, cosmo_cot), + ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap], + ptcl, a_step[0], a_step[1], a_snap, cosmo) + + return ptcl_cot, cosmo_cot @jit @@ -224,8 +285,7 @@ def nbody(ptcl, obsvbl, cosmo, conf, reverse=False): @jit -def nbody_adj_init(a, ptcl, ptcl_cot, obsvbl_cot, cosmo, conf): - #ptcl_cot = observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo) +def nbody_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf): #ptcl, ptcl_cot = coevolve_adj(a_prev, a_next, ptcl, ptcl_cot, cosmo) @@ -233,44 +293,53 @@ def nbody_adj_init(a, ptcl, ptcl_cot, obsvbl_cot, cosmo, conf): cosmo_cot = tree_map(jnp.zeros_like, cosmo) + ptcl_cot, cosmo_cot = observe_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot, + cosmo, cosmo_cot, conf) + return ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force @jit -def nbody_adj_step(a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, cosmo_cot, cosmo_cot_force, conf): - #ptcl_cot = observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, conf) +def nbody_adj_step(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot, + cosmo, cosmo_cot, cosmo_cot_force, conf): #ptcl, ptcl_cot = coevolve_adj(a_prev, a_next, ptcl, ptcl_cot, cosmo, conf) ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = integrate_adj( a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, cosmo_cot, cosmo_cot_force, conf) + ptcl_cot, cosmo_cot = observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot, + cosmo, cosmo_cot, conf) + return ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force -def nbody_adj(ptcl, ptcl_cot, obsvbl_cot, cosmo, conf, reverse=False): +def nbody_adj(ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf, reverse=False): """N-body time integration with adjoint equation.""" a_nbody = conf.a_nbody[::-1] if reverse else conf.a_nbody ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = nbody_adj_init( - a_nbody[-1], ptcl, ptcl_cot, obsvbl_cot, cosmo, conf) + a_nbody[-1], ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf) + for a_prev, a_next in zip(a_nbody[:0:-1], a_nbody[-2::-1]): ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = nbody_adj_step( - a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, cosmo_cot, cosmo_cot_force, conf) + a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot, + cosmo, cosmo_cot, cosmo_cot_force, conf) + return ptcl, ptcl_cot, cosmo_cot def nbody_fwd(ptcl, obsvbl, cosmo, conf, reverse): ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf, reverse) - return (ptcl, obsvbl), (ptcl, cosmo, conf) + return (ptcl, obsvbl), (ptcl, obsvbl, cosmo, conf) def nbody_bwd(reverse, res, cotangents): - ptcl, cosmo, conf = res + ptcl, obsvbl, cosmo, conf = res ptcl_cot, obsvbl_cot = cotangents - ptcl, ptcl_cot, cosmo_cot = nbody_adj(ptcl, ptcl_cot, obsvbl_cot, cosmo, conf, - reverse=reverse) + ptcl, ptcl_cot, cosmo_cot = nbody_adj( + ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf, reverse=reverse) - return ptcl_cot, obsvbl_cot, cosmo_cot, None + return ptcl_cot, None, cosmo_cot, None nbody.defvjp(nbody_fwd, nbody_bwd) diff --git a/pmwd/obs_util.py b/pmwd/obs_util.py new file mode 100644 index 00000000..a74d87e3 --- /dev/null +++ b/pmwd/obs_util.py @@ -0,0 +1,107 @@ +from jax import jit, vjp +import jax.numpy as jnp + +from pmwd.particles import Particles +from pmwd.cosmology import E2 + + +def itp_prev(ptcl0, a0, a1, a, cosmo): + """Cubic Hermite interpolation is a linear combination of two ptcls, this + function returns the disp and vel from the first ptcl at a0.""" + Da = a1 - a0 + t = (a - a0) / Da + a3E0 = a0**3 * jnp.sqrt(E2(a0, cosmo)) + # displacement + h00 = 2 * t**3 - 3 * t**2 + 1 + h10 = t**3 - 2 * t**2 + t + disp = h00 * ptcl0.disp + h10 * Da / a3E0 * ptcl0.vel + # velocity + # derivatives of the Hermite basis functions + h00 = 6 * t**2 - 6 * t + h10 = 3 * t**2 - 4 * t + 1 + vel = h00 / Da * ptcl0.disp + h10 / a3E0 * ptcl0.vel + vel *= a**3 * jnp.sqrt(E2(a, cosmo)) + + dtype = ptcl0.conf.float_dtype + return disp.astype(dtype), vel.astype(dtype) + + +def itp_prev_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl0, a0, a1, a, cosmo): + # iptcl_cot is the cotangent of the interpolated ptcl + (disp, vel), itp_prev_vjp = vjp(itp_prev, ptcl0, a0, a1, a, cosmo) + ptcl0_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = itp_prev_vjp( + (iptcl_cot.disp, iptcl_cot.vel)) + + disp_cot = ptcl_cot.disp + ptcl0_cot.disp + vel_cot = ptcl_cot.vel + ptcl0_cot.vel + ptcl_cot = ptcl_cot.replace(disp=disp_cot, vel=vel_cot) + cosmo_cot += cosmo_cot_itp + return ptcl_cot, cosmo_cot + + +def itp_next(ptcl1, a0, a1, a, cosmo): + """Cubic Hermite interpolation is a linear combination of two ptcls, this + function returns the disp and vel from the second ptcl at a1.""" + Da = a1 - a0 + t = (a - a0) / Da + a3E1 = a1**3 * jnp.sqrt(E2(a1, cosmo)) + # displacement + h01 = - 2 * t**3 + 3 * t**2 + h11 = t**3 - t**2 + disp = h01 * ptcl1.disp + h11 * Da / a3E1 * ptcl1.vel + # velocity + # derivatives of the Hermite basis functions + h01 = - 6 * t**2 + 6 * t + h11 = 3 * t**2 - 2 * t + vel = h01 / Da * ptcl1.disp + h11 / a3E1 * ptcl1.vel + vel *= a**3 * jnp.sqrt(E2(a, cosmo)) + + dtype = ptcl1.conf.float_dtype + return disp.astype(dtype), vel.astype(dtype) + + +def itp_next_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl1, a0, a1, a, cosmo): + # iptcl_cot is the cotangent of the interpolated ptcl + (disp, vel), itp_next_vjp = vjp(itp_next, ptcl1, a0, a1, a, cosmo) + ptcl1_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = itp_next_vjp( + (iptcl_cot.disp, iptcl_cot.vel)) + + disp_cot = ptcl_cot.disp + ptcl1_cot.disp + vel_cot = ptcl_cot.vel + ptcl1_cot.vel + ptcl_cot = ptcl_cot.replace(disp=disp_cot, vel=vel_cot) + cosmo_cot += cosmo_cot_itp + return ptcl_cot, cosmo_cot + + +def interptcl(ptcl0, ptcl1, a0, a1, a, cosmo): + """Given two ptcl snapshots, get the interpolated one at a given time using + cubic Hermite interpolation.""" + Da = a1 - a0 + t = (a - a0) / Da + a3E0 = a0**3 * jnp.sqrt(E2(a0, cosmo)) + a3E1 = a1**3 * jnp.sqrt(E2(a1, cosmo)) + # displacement + h00 = 2 * t**3 - 3 * t**2 + 1 + h10 = t**3 - 2 * t**2 + t + h01 = - 2 * t**3 + 3 * t**2 + h11 = t**3 - t**2 + disp = (h00 * ptcl0.disp + h10 * Da / a3E0 * ptcl0.vel + + h01 * ptcl1.disp + h11 * Da / a3E1 * ptcl1.vel) + # velocity + # derivatives of the Hermite basis functions + h00 = 6 * t**2 - 6 * t + h10 = 3 * t**2 - 4 * t + 1 + h01 = - 6 * t**2 + 6 * t + h11 = 3 * t**2 - 2 * t + vel = (h00 / Da * ptcl0.disp + h10 / a3E0 * ptcl0.vel + + h01 / Da * ptcl1.disp + h11 / a3E1 * ptcl1.vel) + vel *= a**3 * jnp.sqrt(E2(a, cosmo)) + + iptcl = Particles(ptcl0.conf, ptcl0.pmid, disp, vel=vel) + return iptcl + + +def interptcl_adj(iptcl_cot, ptcl0, ptcl1, a0, a1, a, cosmo): + iptcl, interptcl_vjp = vjp(interptcl, ptcl0, ptcl1, a0, a1, a, cosmo) + ptcl0_cot, ptcl1_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = interptcl_vjp(iptcl_cot) + return ptcl0_cot, ptcl1_cot, cosmo_cot_itp