Skip to content

Commit

Permalink
WIP adding multiple snapshots observable
Browse files Browse the repository at this point in the history
  • Loading branch information
Yucheng-Zhang committed Aug 7, 2023
1 parent e830c2f commit 6c23227
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 14 deletions.
3 changes: 3 additions & 0 deletions pmwd/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ class Configuration:

chunk_size: int = 2**24

# observables
a_snapshots: Optional[Tuple[float]] = None

def __post_init__(self):
if self._is_transforming():
return
Expand Down
97 changes: 83 additions & 14 deletions pmwd/nbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from jax import value_and_grad, jit, vjp, custom_vjp
import jax.numpy as jnp
from jax.tree_util import tree_map
from jax.lax import cond

from pmwd.boltzmann import growth
from pmwd.cosmology import E2, H_deriv
from pmwd.gravity import gravity
from pmwd.obs_util import interptcl, itp_prev_adj, itp_next_adj
from pmwd.particles import Particles


def _G_D(a, cosmo, conf):
Expand Down Expand Up @@ -183,11 +186,69 @@ def coevolve_init(a, ptcl, cosmo, conf):


def observe(a_prev, a_next, ptcl, obsvbl, cosmo, conf):
    """Update the observables after the N-body step from ``a_prev`` to ``a_next``.

    For each scale factor ``a`` in ``conf.a_snapshots`` with
    ``a_prev < a <= a_next``, the snapshot stored under ``obsvbl['snapshots'][a]``
    is replaced by the particles interpolated (``interptcl``) between
    ``obsvbl['ptcl_prev']`` at ``a_prev`` and ``ptcl`` at ``a_next``. Afterwards
    ``ptcl`` becomes the new ``'ptcl_prev'``.

    Parameters
    ----------
    a_prev, a_next : float or scalar array
        Scale factors bounding the step that produced ``ptcl``.
    ptcl : Particles
        Particle state paired with ``a_next`` in the interpolation.
    obsvbl : dict
        Observables carrier as built by ``observe_init``; must contain
        ``'ptcl_prev'`` and, when snapshots are requested, ``'snapshots'``.
    cosmo : Cosmology
    conf : Configuration
        Only ``conf.a_snapshots`` is read here.

    Returns
    -------
    obsvbl : dict
        A new dict; the input dict is left untouched.

    """
    def itp(a, carry):
        # build fresh dicts instead of mutating the cond operand, so that the
        # branch stays a pure function of its input
        snap = interptcl(carry['ptcl_prev'], ptcl, a_prev, a_next, a, cosmo)
        return {**carry, 'snapshots': {**carry['snapshots'], a: snap}}

    if conf.a_snapshots is not None:
        for a in conf.a_snapshots:
            # a snapshot belongs to the step whose interval (a_prev, a_next]
            # contains it
            in_step = jnp.logical_and(a_prev < a, a <= a_next)
            obsvbl = cond(in_step, partial(itp, a), lambda carry: carry, obsvbl)

    # the current ptcl is the "previous" one for the next step
    return {**obsvbl, 'ptcl_prev': ptcl}


def observe_init(a, ptcl, obsvbl, cosmo, conf):
    """Create the dict that carries observables through the N-body steps.

    Parameters
    ----------
    a : float or scalar array
        Scale factor of ``ptcl``; not used in the current implementation.
    ptcl : Particles
        Initial particle state, stored as the first ``'ptcl_prev'``.
    obsvbl : any
        Ignored; a fresh dict is always created.
    cosmo : Cosmology
        Not used in the current implementation.
    conf : Configuration
        ``conf.a_snapshots`` and ``conf.a_nbody`` are read.

    Returns
    -------
    obsvbl : dict
        ``'ptcl_prev'``: the given ``ptcl``. When ``conf.a_snapshots`` is not
        None, also ``'snapshots'``, a dict mapping each requested scale factor
        to a zero-filled ``Particles`` placeholder, and ``'snap_a_step'``, an
        array with one ``(a_lower, a_upper)`` row per requested snapshot.

    """
    # a dict to carry all observables and related useful information,
    # starting with the initial ptcl as the "previous" particles
    obsvbl = {'ptcl_prev': ptcl}

    if conf.a_snapshots is not None:
        # placeholders for all output snapshots, to be filled in by observe
        obsvbl['snapshots'] = {
            a_snap: Particles(ptcl.conf, ptcl.pmid, jnp.zeros_like(ptcl.disp),
                              vel=jnp.zeros_like(ptcl.vel))
            for a_snap in conf.a_snapshots
        }
        # the N-body step (a_lower, a_upper] bracketing each snapshot;
        # side='left' gives the first index with a_nbody[idx] >= a_snap
        # NOTE(review): assumes conf.a_nbody is ascending and every snapshot
        # lies in (a_nbody[0], a_nbody[-1]]; otherwise idx-1 wraps around or
        # idx runs past the end -- confirm and validate upstream
        idx = jnp.searchsorted(conf.a_nbody, jnp.asarray(conf.a_snapshots), side='left')
        obsvbl['snap_a_step'] = jnp.array((conf.a_nbody[idx-1], conf.a_nbody[idx])).T

    return obsvbl


def observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo_cot, conf):
    """Add the snapshot cotangents that depend on ``ptcl`` to the running ones.

    For every requested snapshot, bracketed by the step stored in
    ``obsvbl['snap_a_step']``, ``itp_next_adj`` is applied when the upper end
    of that step equals ``a_next`` and ``itp_prev_adj`` when it equals
    ``a_prev``. Nothing is done when ``conf.a_snapshots`` is None.

    Returns
    -------
    ptcl_cot, cosmo_cot
        The accumulated particle and cosmology cotangents.

    """
    if conf.a_snapshots is None:
        return ptcl_cot, cosmo_cot

    def keep(ptcl_cot, cosmo_cot, *unused):
        # branch taken when ptcl does not enter this snapshot
        return ptcl_cot, cosmo_cot

    for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']):
        a_lower, a_upper = a_step[0], a_step[1]
        rest = (obsvbl_cot['snapshots'][a_snap], ptcl, a_lower, a_upper, a_snap, cosmo)

        # ptcl was the later particle of the interpolation
        ptcl_cot, cosmo_cot = cond(a_upper == a_next, itp_next_adj, keep,
                                   ptcl_cot, cosmo_cot, *rest)
        # ptcl was the earlier particle of the interpolation
        ptcl_cot, cosmo_cot = cond(a_upper == a_prev, itp_prev_adj, keep,
                                   ptcl_cot, cosmo_cot, *rest)

    return ptcl_cot, cosmo_cot


def observe_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo_cot, conf):
    """Start the observable adjoint with the particles at scale factor ``a``.

    For every requested snapshot whose bracketing step (a row of
    ``obsvbl['snap_a_step']``) ends at ``a``, ``itp_next_adj`` adds the
    snapshot cotangent to ``ptcl_cot`` and ``cosmo_cot``. Nothing is done when
    ``conf.a_snapshots`` is None.

    Returns
    -------
    ptcl_cot, cosmo_cot
        The accumulated particle and cosmology cotangents.

    """
    if conf.a_snapshots is None:
        return ptcl_cot, cosmo_cot

    def keep(ptcl_cot, cosmo_cot, *unused):
        # branch taken when ptcl does not enter this snapshot
        return ptcl_cot, cosmo_cot

    # check whether the last ptcl is used in the interpolation of each snapshot
    for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']):
        a_lower, a_upper = a_step[0], a_step[1]
        ptcl_cot, cosmo_cot = cond(a_upper == a, itp_next_adj, keep,
                                   ptcl_cot, cosmo_cot,
                                   obsvbl_cot['snapshots'][a_snap],
                                   ptcl, a_lower, a_upper, a_snap, cosmo)

    return ptcl_cot, cosmo_cot


@jit
Expand Down Expand Up @@ -224,53 +285,61 @@ def nbody(ptcl, obsvbl, cosmo, conf, reverse=False):


@jit
def nbody_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf):
    """Initialize the adjoint N-body integration at scale factor ``a``.

    Runs ``force_adj`` on the given state, creates a zero cosmology cotangent
    with the same tree structure as ``cosmo``, and lets ``observe_adj_init``
    add the observable cotangents that depend on ``ptcl``.

    Returns
    -------
    ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force

    """
    # TODO: coevolve adjoint is not wired in yet
    #ptcl, ptcl_cot = coevolve_adj(a_prev, a_next, ptcl, ptcl_cot, cosmo)

    ptcl, ptcl_cot, cosmo_cot_force = force_adj(a, ptcl, ptcl_cot, cosmo, conf)

    cosmo_cot = tree_map(jnp.zeros_like, cosmo)

    ptcl_cot, cosmo_cot = observe_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot,
                                           cosmo, cosmo_cot, conf)

    return ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force


@jit
def nbody_adj_step(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot,
                   cosmo, cosmo_cot, cosmo_cot_force, conf):
    """Take one adjoint N-body step from ``a_prev`` to ``a_next``.

    ``integrate_adj`` moves the particles and their cotangents to ``a_next``;
    ``observe_adj`` then adds the observable cotangents that depend on the
    resulting ``ptcl``.

    Returns
    -------
    ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force

    """
    # TODO: coevolve adjoint is not wired in yet
    #ptcl, ptcl_cot = coevolve_adj(a_prev, a_next, ptcl, ptcl_cot, cosmo, conf)

    ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = integrate_adj(
        a_prev, a_next, ptcl, ptcl_cot, obsvbl_cot, cosmo, cosmo_cot, cosmo_cot_force, conf)

    ptcl_cot, cosmo_cot = observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot,
                                      cosmo, cosmo_cot, conf)

    return ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force


def nbody_adj(ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf, reverse=False):
    """N-body time integration with adjoint equation.

    Starts from the last entry of the time-step sequence (``conf.a_nbody``,
    reversed when ``reverse`` is True) and walks it backwards, one
    ``nbody_adj_step`` per pair of neighboring scale factors.

    Returns
    -------
    ptcl, ptcl_cot, cosmo_cot

    """
    a_nbody = conf.a_nbody[::-1] if reverse else conf.a_nbody

    ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = nbody_adj_init(
        a_nbody[-1], ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf)

    # pairs (a_nbody[i], a_nbody[i-1]) from the end of the sequence to its start
    for a_prev, a_next in zip(a_nbody[:0:-1], a_nbody[-2::-1]):
        ptcl, ptcl_cot, cosmo_cot, cosmo_cot_force = nbody_adj_step(
            a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot,
            cosmo, cosmo_cot, cosmo_cot_force, conf)

    return ptcl, ptcl_cot, cosmo_cot


def nbody_fwd(ptcl, obsvbl, cosmo, conf, reverse):
    """Forward rule of ``nbody``: run it and save the residuals for ``nbody_bwd``.

    The residuals are the final ``ptcl`` and ``obsvbl`` together with ``cosmo``
    and ``conf``.
    """
    ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf, reverse)
    return (ptcl, obsvbl), (ptcl, obsvbl, cosmo, conf)

def nbody_bwd(reverse, res, cotangents):
    """Backward rule of ``nbody``: propagate cotangents with ``nbody_adj``.

    ``res`` is the residual tuple saved by ``nbody_fwd`` and ``cotangents``
    holds the cotangents of the ``(ptcl, obsvbl)`` output.

    Returns four entries, of which only the particle and cosmology cotangents
    are non-None.
    NOTE(review): presumably ordered as nbody's arguments
    (ptcl, obsvbl, cosmo, conf) -- confirm against the custom_vjp declaration.
    """
    ptcl, obsvbl, cosmo, conf = res
    ptcl_cot, obsvbl_cot = cotangents

    ptcl, ptcl_cot, cosmo_cot = nbody_adj(
        ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, conf, reverse=reverse)

    return ptcl_cot, None, cosmo_cot, None

# register nbody_fwd / nbody_bwd as the forward and backward rules of nbody
nbody.defvjp(nbody_fwd, nbody_bwd)
107 changes: 107 additions & 0 deletions pmwd/obs_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from jax import jit, vjp
import jax.numpy as jnp

from pmwd.particles import Particles
from pmwd.cosmology import E2


def itp_prev(ptcl0, a0, a1, a, cosmo):
    """Earlier-particle part of the cubic Hermite interpolation.

    The interpolation between the particles at ``a0`` and ``a1`` is a linear
    combination of the two; this returns the displacement and velocity terms
    coming from ``ptcl0`` at ``a0``, evaluated at ``a``. Velocities are turned
    into displacement slopes by dividing by ``a0**3 * sqrt(E2(a0))`` and the
    interpolated slope is turned back by multiplying with
    ``a**3 * sqrt(E2(a))``.

    Returns
    -------
    disp, vel : arrays cast to ``ptcl0.conf.float_dtype``

    """
    step = a1 - a0
    t = (a - a0) / step
    t2, t3 = t**2, t**3
    a3E_0 = a0**3 * jnp.sqrt(E2(a0, cosmo))
    a3E_a = a**3 * jnp.sqrt(E2(a, cosmo))

    # Hermite basis functions attached to the first node
    h00 = 2 * t3 - 3 * t2 + 1
    h10 = t3 - 2 * t2 + t
    disp = h00 * ptcl0.disp + h10 * step / a3E_0 * ptcl0.vel

    # and their derivatives with respect to t
    dh00 = 6 * t2 - 6 * t
    dh10 = 3 * t2 - 4 * t + 1
    vel = (dh00 / step * ptcl0.disp + dh10 / a3E_0 * ptcl0.vel) * a3E_a

    float_dtype = ptcl0.conf.float_dtype
    return disp.astype(float_dtype), vel.astype(float_dtype)


def itp_prev_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl0, a0, a1, a, cosmo):
    """Accumulate the cotangents of ``itp_prev`` onto ``ptcl_cot`` and ``cosmo_cot``.

    ``iptcl_cot`` is the cotangent of the interpolated particles; its ``disp``
    and ``vel`` are pulled back through ``itp_prev``. The cotangents with
    respect to ``a0``, ``a1`` and ``a`` are discarded.

    Returns
    -------
    ptcl_cot, cosmo_cot

    """
    _, pullback = vjp(itp_prev, ptcl0, a0, a1, a, cosmo)
    ptcl0_cot, _, _, _, cosmo_cot_itp = pullback((iptcl_cot.disp, iptcl_cot.vel))

    ptcl_cot = ptcl_cot.replace(
        disp=ptcl_cot.disp + ptcl0_cot.disp,
        vel=ptcl_cot.vel + ptcl0_cot.vel,
    )
    cosmo_cot += cosmo_cot_itp
    return ptcl_cot, cosmo_cot


def itp_next(ptcl1, a0, a1, a, cosmo):
    """Later-particle part of the cubic Hermite interpolation.

    The interpolation between the particles at ``a0`` and ``a1`` is a linear
    combination of the two; this returns the displacement and velocity terms
    coming from ``ptcl1`` at ``a1``, evaluated at ``a``. Velocities are turned
    into displacement slopes by dividing by ``a1**3 * sqrt(E2(a1))`` and the
    interpolated slope is turned back by multiplying with
    ``a**3 * sqrt(E2(a))``.

    Returns
    -------
    disp, vel : arrays cast to ``ptcl1.conf.float_dtype``

    """
    step = a1 - a0
    t = (a - a0) / step
    t2, t3 = t**2, t**3
    a3E_1 = a1**3 * jnp.sqrt(E2(a1, cosmo))
    a3E_a = a**3 * jnp.sqrt(E2(a, cosmo))

    # Hermite basis functions attached to the second node
    h01 = - 2 * t3 + 3 * t2
    h11 = t3 - t2
    disp = h01 * ptcl1.disp + h11 * step / a3E_1 * ptcl1.vel

    # and their derivatives with respect to t
    dh01 = - 6 * t2 + 6 * t
    dh11 = 3 * t2 - 2 * t
    vel = (dh01 / step * ptcl1.disp + dh11 / a3E_1 * ptcl1.vel) * a3E_a

    float_dtype = ptcl1.conf.float_dtype
    return disp.astype(float_dtype), vel.astype(float_dtype)


def itp_next_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl1, a0, a1, a, cosmo):
    """Accumulate the cotangents of ``itp_next`` onto ``ptcl_cot`` and ``cosmo_cot``.

    ``iptcl_cot`` is the cotangent of the interpolated particles; its ``disp``
    and ``vel`` are pulled back through ``itp_next``. The cotangents with
    respect to ``a0``, ``a1`` and ``a`` are discarded.

    Returns
    -------
    ptcl_cot, cosmo_cot

    """
    _, pullback = vjp(itp_next, ptcl1, a0, a1, a, cosmo)
    ptcl1_cot, _, _, _, cosmo_cot_itp = pullback((iptcl_cot.disp, iptcl_cot.vel))

    ptcl_cot = ptcl_cot.replace(
        disp=ptcl_cot.disp + ptcl1_cot.disp,
        vel=ptcl_cot.vel + ptcl1_cot.vel,
    )
    cosmo_cot += cosmo_cot_itp
    return ptcl_cot, cosmo_cot


def interptcl(ptcl0, ptcl1, a0, a1, a, cosmo):
    """Interpolate particles between two states with cubic Hermite splines.

    ``ptcl0`` at ``a0`` and ``ptcl1`` at ``a1`` provide the node values
    (``disp``) and, after division by ``a**3 * sqrt(E2(a))`` at the respective
    node, the node slopes (``vel``). The interpolated slope at ``a`` is
    multiplied by ``a**3 * sqrt(E2(a))`` to give the returned velocity.

    Returns
    -------
    iptcl : Particles
        Built from ``ptcl0.conf`` and ``ptcl0.pmid`` with the interpolated
        ``disp`` and ``vel``.

    """
    step = a1 - a0
    t = (a - a0) / step
    t2, t3 = t**2, t**3
    a3E_0 = a0**3 * jnp.sqrt(E2(a0, cosmo))
    a3E_1 = a1**3 * jnp.sqrt(E2(a1, cosmo))
    a3E_a = a**3 * jnp.sqrt(E2(a, cosmo))

    # Hermite basis functions
    h00 = 2 * t3 - 3 * t2 + 1
    h10 = t3 - 2 * t2 + t
    h01 = - 2 * t3 + 3 * t2
    h11 = t3 - t2
    disp = (h00 * ptcl0.disp + h10 * step / a3E_0 * ptcl0.vel +
            h01 * ptcl1.disp + h11 * step / a3E_1 * ptcl1.vel)

    # derivatives of the Hermite basis functions with respect to t
    dh00 = 6 * t2 - 6 * t
    dh10 = 3 * t2 - 4 * t + 1
    dh01 = - 6 * t2 + 6 * t
    dh11 = 3 * t2 - 2 * t
    vel = (dh00 / step * ptcl0.disp + dh10 / a3E_0 * ptcl0.vel +
           dh01 / step * ptcl1.disp + dh11 / a3E_1 * ptcl1.vel)
    vel = vel * a3E_a

    return Particles(ptcl0.conf, ptcl0.pmid, disp, vel=vel)


def interptcl_adj(iptcl_cot, ptcl0, ptcl1, a0, a1, a, cosmo):
    """Pull the cotangent of the interpolated particles back through ``interptcl``.

    Returns the cotangents with respect to ``ptcl0``, ``ptcl1`` and ``cosmo``;
    those with respect to ``a0``, ``a1`` and ``a`` are discarded.
    """
    _, pullback = vjp(interptcl, ptcl0, ptcl1, a0, a1, a, cosmo)
    ptcl0_cot, ptcl1_cot, _, _, _, cosmo_cot_itp = pullback(iptcl_cot)
    return ptcl0_cot, ptcl1_cot, cosmo_cot_itp

0 comments on commit 6c23227

Please sign in to comment.