Skip to content

Commit

Permalink
Blackjax sampler fix for breaking change / enable progress bar under …
Browse files Browse the repository at this point in the history
…parallel chain_method (pymc-devs#7453)

* remove blackjax pmap warning

* use gen_scan_fn

* remove labels

* retrigger checks

* retrigger checks
  • Loading branch information
andrewdipper authored Aug 12, 2024
1 parent f3cff73 commit 8cdc9ee
Showing 1 changed file with 2 additions and 15 deletions.
17 changes: 2 additions & 15 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,10 @@ def _one_step(state, xs):
return state, (position, stats)

progress_bar = adaptation_kwargs.pop("progress_bar", False)
if progress_bar:
from blackjax.progress_bar import progress_bar_scan

one_step = jax.jit(progress_bar_scan(draws)(_one_step))
else:
one_step = jax.jit(_one_step)

keys = jax.random.split(seed, draws)
_, (samples, stats) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))
scan_fn = blackjax.progress_bar.gen_scan_fn(draws, progress_bar)
_, (samples, stats) = scan_fn(_one_step, last_state, (jnp.arange(draws), keys))

return samples, stats

Expand Down Expand Up @@ -365,14 +360,6 @@ def _sample_blackjax_nuts(
# Adapted from numpyro
if chain_method == "parallel":
map_fn = jax.pmap
if progressbar:
import warnings

warnings.warn(
"BlackJax currently only display progress bar correctly under "
"`chain_method == 'vectorized'`. Setting `progressbar=False`."
)
progressbar = False
elif chain_method == "vectorized":
map_fn = jax.vmap
else:
Expand Down

0 comments on commit 8cdc9ee

Please sign in to comment.