Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using a lax.scan to run the solver #11

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

EiffL
Copy link

@EiffL EiffL commented Nov 22, 2022

This draft PR is in response to #9 and presents a prototype implementation of the leap-frog solver that uses a lax.scan instead of a for loop in the nbody function.

Here are the results on the baseline default configuration.

Current master

  • No extra jitting (so, normal pmwd)
First result obtained in  16.21842312812805
Second result obtained in  9.312845468521118
  • jitted pmwd (what I'm told not to do ^^)
First result obtained in  184.43911576271057
Second result obtained in  9.710259437561035

This PR (using scan, and I actually removed all lower level jit)

  • No jitting at all (I removed all the jit in pmwd.nbody)
First result obtained in  19.26646661758423
Second result obtained in  9.970827102661133
  • jitted pmwd
First result obtained in  13.62941026687622
Second result obtained in  9.098160743713379

And here the notebook to reproduce this test (working off my fork):
https://gist.github.com/EiffL/aa6a651141f694ca257fb5ff83e829d6

So I would advocate using lax.scan.

In this draft implementation, I chose not to output intermediate ptcl and obsvl, exactly like what is done on master, but if you want to export intermediate snapshots, it's easy you can export them as the output of the scan fn :-)

If you look at the implementation of odeint in jax, you can also have a slightly more complicated logic that exports the state of the system only at some desired pre-defined steps, and not necessarily at all time steps:
https://github.com/google/jax/blob/518fe6656ca2aab66dcfc8cd7866c10f476a17b1/jax/experimental/ode.py#L189

And finally, if you want to save the sims to disk, then nothing prevents you from using the nbody step function directly/manually in a for loop.

@EiffL
Copy link
Author

EiffL commented Nov 22, 2022

And actually ^^ it's generally not a good idea ^^' but if we want to, we can definitely write a custom CPU op that will dump the simulation in hdf5 from within jitted code, and from within the lax.scan.

In this particular instance, I think it would be pretty cool

@EiffL
Copy link
Author

EiffL commented Nov 22, 2022

So yeah I don't see any drawbacks of using a scan :-)

@eelregit
Copy link
Owner

Thanks! This was very much how it was done here. Also here for the adjoint.

So it's good to know that XLA or JAX has gotten better on this.

but if you want to export intermediate snapshots, it's easy you can export them as the output of the scan fn :-)

I guess you meant nested scan's. We want interpolation between two steps. It looks like odeint is extrapolating from the last step? But interpolation should also be okay with nested scan's.

@EiffL
Copy link
Author

EiffL commented Nov 23, 2022

In odeint they use a while inside the scan function yes.

Would you be ok with an API with an argument which would be the array a to use in the solver, and maybe another optional array save_at which would contain the indices of the snapshots to export. By default it would be [-1]. If so I'm happy to implement it :-)

And then, I think it would be very cool to have the ability to do IO directly from jitted code :-) And I think I know how to do it, but probably that's for a different PR.

@eelregit eelregit force-pushed the master branch 2 times, most recently from 418337c to a5329ae Compare November 26, 2022 06:26
@eelregit
Copy link
Owner

eelregit commented Nov 29, 2022

Let's try switching to scan following the odeint way, once the checkpoint (exactly at a time step, directly copying disp and vel) and snapshot (interpolation between 2 steps) observables are implemented. @Yucheng-Zhang is working on those observables.

Yes, it'd be super cool to have a custom IO op ^^

@eelregit
Copy link
Owner

id_tap seems to be useful in writing snapshots

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants