Skip to content

Commit

Permalink
WIP minor
Browse files Browse the repository at this point in the history
  • Loading branch information
Yucheng-Zhang committed Aug 8, 2023
1 parent 6c23227 commit 10b4de5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
22 changes: 11 additions & 11 deletions pmwd/nbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def itp(a, obsvbl):
if conf.a_snapshots is not None:
for a in conf.a_snapshots:
obsvbl = cond(jnp.logical_and(a_prev < a, a <= a_next),
partial(itp, a), lambda *args: obsvbl, obsvbl)
partial(itp, a), lambda *args: obsvbl, obsvbl)

obsvbl['ptcl_prev'] = ptcl

Expand All @@ -212,7 +212,7 @@ def observe_init(a, ptcl, obsvbl, cosmo, conf):
# all output snapshots
obsvbl['snapshots'] = {
a_snap: Particles(ptcl.conf, ptcl.pmid, jnp.zeros_like(ptcl.disp),
vel=jnp.zeros_like(ptcl.vel))
vel=jnp.zeros_like(ptcl.vel))
for a_snap in conf.a_snapshots
}
# the nbody a step of output snapshots, (,]
Expand All @@ -227,13 +227,13 @@ def observe_adj(a_prev, a_next, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo
if conf.a_snapshots is not None:
for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']):
ptcl_cot, cosmo_cot = cond(a_step[1] == a_next, itp_next_adj,
lambda *args: (ptcl_cot, cosmo_cot),
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
ptcl, a_step[0], a_step[1], a_snap, cosmo)
lambda *args: (ptcl_cot, cosmo_cot),
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
ptcl, a_step[0], a_step[1], a_snap, cosmo)
ptcl_cot, cosmo_cot = cond(a_step[1] == a_prev, itp_prev_adj,
lambda *args: (ptcl_cot, cosmo_cot),
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
ptcl, a_step[0], a_step[1], a_snap, cosmo)
lambda *args: (ptcl_cot, cosmo_cot),
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
ptcl, a_step[0], a_step[1], a_snap, cosmo)

return ptcl_cot, cosmo_cot

Expand All @@ -244,9 +244,9 @@ def observe_adj_init(a, ptcl, ptcl_cot, obsvbl, obsvbl_cot, cosmo, cosmo_cot, co
# check if the last ptcl is used in interpolation
for a_snap, a_step in zip(conf.a_snapshots, obsvbl['snap_a_step']):
ptcl_cot, cosmo_cot = cond(a_step[1] == a, itp_next_adj,
lambda *args: (ptcl_cot, cosmo_cot),
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
ptcl, a_step[0], a_step[1], a_snap, cosmo)
lambda *args: (ptcl_cot, cosmo_cot),
ptcl_cot, cosmo_cot, obsvbl_cot['snapshots'][a_snap],
ptcl, a_step[0], a_step[1], a_snap, cosmo)

return ptcl_cot, cosmo_cot

Expand Down
6 changes: 4 additions & 2 deletions pmwd/obs_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def itp_prev(ptcl0, a0, a1, a, cosmo):
"""Cubic Hermite interpolation is a linear combination of two ptcls, this
function returns the disp and vel from the first ptcl at a0."""
function returns the disp and vel from the first ptcl at a0."""
Da = a1 - a0
t = (a - a0) / Da
a3E0 = a0**3 * jnp.sqrt(E2(a0, cosmo))
Expand All @@ -27,6 +27,7 @@ def itp_prev(ptcl0, a0, a1, a, cosmo):


def itp_prev_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl0, a0, a1, a, cosmo):
"""Update ptcl_cot and cosmo_cot given the iptcl_cot and the vjp with itp_prev."""
# iptcl_cot is the cotangent of the interpolated ptcl
(disp, vel), itp_prev_vjp = vjp(itp_prev, ptcl0, a0, a1, a, cosmo)
ptcl0_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = itp_prev_vjp(
Expand All @@ -41,7 +42,7 @@ def itp_prev_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl0, a0, a1, a, cosmo):

def itp_next(ptcl1, a0, a1, a, cosmo):
"""Cubic Hermite interpolation is a linear combination of two ptcls, this
function returns the disp and vel from the second ptcl at a1."""
function returns the disp and vel from the second ptcl at a1."""
Da = a1 - a0
t = (a - a0) / Da
a3E1 = a1**3 * jnp.sqrt(E2(a1, cosmo))
Expand All @@ -61,6 +62,7 @@ def itp_next(ptcl1, a0, a1, a, cosmo):


def itp_next_adj(ptcl_cot, cosmo_cot, iptcl_cot, ptcl1, a0, a1, a, cosmo):
"""Update ptcl_cot and cosmo_cot given the iptcl_cot and the vjp with itp_next."""
# iptcl_cot is the cotangent of the interpolated ptcl
(disp, vel), itp_next_vjp = vjp(itp_next, ptcl1, a0, a1, a, cosmo)
ptcl1_cot, a0_cot, a1_cot, a_cot, cosmo_cot_itp = itp_next_vjp(
Expand Down

0 comments on commit 10b4de5

Please sign in to comment.