The current Chainsail prototype focuses on providing an automatically tuned, scalable Replica Exchange (RE) implementation. For that reason, only two very simple "local" samplers (the ones that sample between RE swaps) are implemented so far: a Metropolis algorithm with a uniform proposal density and a bare-bones HMC sampler, both found in `lib/common/chainsail/common/samplers/`. Even with these basic samplers, RE can offer large advantages when sampling multimodal distributions, but Chainsail would sample a wider range of problems well if it used a state-of-the-art method such as auto-tuning NUTS.
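For readers unfamiliar with the two ingredients mentioned above, here is a minimal sketch of a Metropolis update with a uniform proposal and of the acceptance test for an RE swap. It is not Chainsail's actual code: the function names, the plain-list state representation, and the assumption that replica k samples from the target density raised to an inverse temperature `beta_k` are all illustrative choices.

```python
import math
import random


def metropolis_step(log_prob, state, stepsize, rng):
    """One Metropolis update with a symmetric uniform proposal.

    Hypothetical helper: `state` is a plain list of floats and `log_prob`
    returns the unnormalized log-density of such a list.
    """
    proposal = [x + rng.uniform(-stepsize, stepsize) for x in state]
    log_ratio = log_prob(proposal) - log_prob(state)
    # 1 - rng.random() lies in (0, 1], so the logarithm is always defined.
    if math.log(1.0 - rng.random()) < log_ratio:
        return proposal, True
    return state, False


def swap_accepted(log_prob, beta_i, beta_j, state_i, state_j, rng):
    """Metropolis criterion for exchanging the states of two replicas.

    Assumption: replica k samples from exp(beta_k * log_prob(x)).
    """
    log_ratio = (beta_i - beta_j) * (log_prob(state_j) - log_prob(state_i))
    return math.log(1.0 - rng.random()) < log_ratio
```

A more capable local sampler such as NUTS would replace only `metropolis_step` in this picture; the swap test depends on the log-density and the replica states alone.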
One could reimplement NUTS for Chainsail, but good implementations with various improvements, such as mass matrix adaptation, already exist, for example in the BlackJAX library. BlackJAX, however, requires a log-density function that is compatible with JAX. So while we could use BlackJAX to sample between RE swaps, performing the exchanges requires converting between Python lists or NumPy arrays and JAX arrays, and being able to update the state of the BlackJAX sampler "from the outside".
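One possible shape for such a wrapper is sketched below. It is deliberately backend-agnostic and uses only the standard library so that it runs as-is: the class name, the `init_fn` / `step_fn` / `to_backend` / `from_backend` parameters, and the toy random-walk kernel are all hypothetical stand-ins. With BlackJAX, `init_fn` would presumably be the `init` function of the object returned by `blackjax.nuts(...)`, `step_fn` a thin adaptor around its `step(rng_key, state)` that drops the returned info, and the converters something like `jnp.asarray` and `np.asarray`; none of that is tested here.

```python
import math
import random
from typing import NamedTuple


class ExternallyUpdatableSampler:
    """Wrap a functional init/step kernel so RE swaps can overwrite its position."""

    def __init__(self, init_fn, step_fn, position, to_backend, from_backend):
        self._init_fn = init_fn
        self._step_fn = step_fn
        self._to_backend = to_backend
        self._from_backend = from_backend
        self._kernel_state = init_fn(to_backend(position))

    @property
    def state(self):
        # Assumes the kernel state exposes a `.position` attribute.
        return self._from_backend(self._kernel_state.position)

    @state.setter
    def state(self, position):
        # Re-initialize instead of patching one field, so that cached
        # quantities (log-density, gradients, ...) stay consistent.
        self._kernel_state = self._init_fn(self._to_backend(position))

    def sample(self, rng):
        self._kernel_state = self._step_fn(rng, self._kernel_state)
        return self.state


# Toy stand-in kernel (NOT BlackJAX): random-walk Metropolis on a standard normal.
class ToyState(NamedTuple):
    position: tuple
    log_density: float


def log_density(x):
    return -0.5 * sum(v * v for v in x)


def toy_init(position):
    return ToyState(tuple(position), log_density(position))


def toy_step(rng, state):
    proposal = tuple(x + rng.uniform(-0.5, 0.5) for x in state.position)
    new_log_density = log_density(proposal)
    if math.log(1.0 - rng.random()) < new_log_density - state.log_density:
        return ToyState(proposal, new_log_density)
    return state


sampler = ExternallyUpdatableSampler(
    toy_init, toy_step, [3.0, -1.0], to_backend=tuple, from_backend=list
)
sampler.sample(random.Random(0))  # local sampling between swaps
sampler.state = [0.0, 0.0]        # what an accepted RE swap would do
```

Re-running `init_fn` on every external update costs one extra log-density (and, for gradient-based kernels, gradient) evaluation per accepted swap, which seems an acceptable price for keeping the kernel's cached values in sync.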
Once we have either our own NUTS implementation or a wrapper around BlackJAX or another NUTS implementation, adding the new sampler is straightforward: implement it, add it to `lib/common/chainsail/common/samplers/`, and extend the job specification, the front end, and the MPI runner script so that the sampler is permitted and can be selected based on user input.
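To illustrate the last step, selecting a sampler from user input could look roughly like the sketch below. Everything in it is hypothetical: the `local_sampler` key, the enum values, and the placeholder classes are illustrative and do not reflect Chainsail's actual job specification or runner code.

```python
from enum import Enum


class LocalSampler(Enum):
    """Hypothetical set of sampler names a job specification could allow."""

    METROPOLIS = "metropolis"
    HMC = "hmc"
    NUTS = "nuts"


# Placeholder classes standing in for the real sampler implementations.
class MetropolisSampler: ...
class HMCSampler: ...
class NUTSSampler: ...


REGISTRY = {
    LocalSampler.METROPOLIS: MetropolisSampler,
    LocalSampler.HMC: HMCSampler,
    LocalSampler.NUTS: NUTSSampler,
}


def select_sampler(job_spec, registry=REGISTRY):
    """Return the sampler class named in a (hypothetical) job spec dict."""
    name = job_spec["local_sampler"]
    try:
        key = LocalSampler(name)
    except ValueError:
        allowed = [s.value for s in LocalSampler]
        raise ValueError(
            f"unknown local sampler {name!r}; choose one of {allowed}"
        ) from None
    return registry[key]
```

Validating the name against an enum keeps the job specification, the front end, and the runner script in agreement about which samplers exist.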