mean_accept_prob significantly different after warmup #1786
Comments
In the early phase, I guess the sampler tends to reject many samples, which is why you see a smaller accept_prob than in the sampling phase. We use dual averaging to adapt the step size, and we update the step size at the end of the warmup phase (see numpyro/infer/hmc_util.py, line 663 at commit 2f1bccd).
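For readers unfamiliar with dual averaging, here is a minimal, stdlib-only sketch of the update from Algorithm 5 of Hoffman & Gelman (2014). It is not numpyro's code. The function name, the deterministic toy acceptance curve `toy_accept_prob`, and the constants (gamma = 0.05, t0 = 10, kappa = 0.75, mu = log(10 * eps0), the defaults suggested in that paper) are assumptions made for illustration. Note that the sampler runs with the noisy iterate `exp(log_eps)` during warmup but switches to the averaged `exp(log_eps_bar)` afterwards, which is one reason warmup and post-warmup acceptance rates can differ.

```python
import math


def dual_averaging_step_size(accept_prob_fn, num_steps, init_step_size=1.0,
                             target=0.8, gamma=0.05, t0=10.0, kappa=0.75):
    """Tune a step size with dual averaging (sketch of Hoffman & Gelman, Alg. 5).

    accept_prob_fn(step_size) -> acceptance probability in [0, 1]. It stands in
    for running one HMC transition and reading off its accept_prob (assumption).
    """
    mu = math.log(10.0 * init_step_size)  # shrinkage point for the log step size
    h_bar = 0.0                           # running average of (target - accept_prob)
    log_eps = math.log(init_step_size)    # step size the sampler actually uses
    log_eps_bar = 0.0                     # averaged iterate, used after warmup
    for m in range(1, num_steps + 1):
        alpha = accept_prob_fn(math.exp(log_eps))
        eta = 1.0 / (m + t0)
        h_bar = (1.0 - eta) * h_bar + eta * (target - alpha)
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        weight = m ** (-kappa)
        log_eps_bar = weight * log_eps + (1.0 - weight) * log_eps_bar
    # Return both: the last warmup iterate and the averaged final step size.
    return math.exp(log_eps), math.exp(log_eps_bar)


def toy_accept_prob(step_size):
    # Hypothetical, noise-free acceptance curve that decreases with step size.
    return math.exp(-step_size ** 2)


last_eps, final_eps = dual_averaging_step_size(toy_accept_prob, num_steps=2000)
print(final_eps, math.sqrt(math.log(1.25)))  # solution of exp(-eps^2) = 0.8
```

With a deterministic acceptance curve the averaged step size should settle near the exact solution sqrt(ln 1.25), about 0.472. In a real sampler `accept_prob` is noisy, so how many iterations the average covers matters much more.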
I see that they won't be the same, but the eventual accept rate is almost 100%, which suggests the learned step size is too small. Note that I am targeting the default accept rate of 80%. Could this be the same issue the Stan developers discussed in stan-dev/stan#3105?
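The reasoning here is that acceptance falls as the step size grows, so an acceptance rate near 100% against an 80% target points to a step size that is smaller than necessary. A self-contained sketch on the same kind of target (isotropic Gaussian, d = 10) illustrates the trend. It assumes plain HMC with a fixed number of leapfrog steps and an identity mass matrix, which is simpler than the example in this thread; the helper `hmc_mean_accept_prob` is hypothetical.

```python
import math
import random


def hmc_mean_accept_prob(step_size, num_leapfrog=10, dim=10, num_iters=1000, seed=0):
    """Average Metropolis acceptance probability of plain HMC on N(0, I_dim)."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    total = 0.0
    for _ in range(num_iters):
        p = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        # U(x) = 0.5 * |x|^2 (so grad U(x) = x) and K(p) = 0.5 * |p|^2
        h0 = 0.5 * sum(v * v for v in x) + 0.5 * sum(v * v for v in p)
        q, r = list(x), list(p)
        for _ in range(num_leapfrog):
            r = [ri - 0.5 * step_size * qi for ri, qi in zip(r, q)]  # half momentum step
            q = [qi + step_size * ri for qi, ri in zip(q, r)]        # full position step
            r = [ri - 0.5 * step_size * qi for ri, qi in zip(r, q)]  # half momentum step
        h1 = 0.5 * sum(v * v for v in q) + 0.5 * sum(v * v for v in r)
        accept = math.exp(min(0.0, h0 - h1))  # min(1, exp(-delta_H)) without overflow
        total += accept
        if rng.random() < accept:
            x = q
    return total / num_iters


for eps in (0.05, 0.3, 0.6, 1.0):
    print(eps, round(hmc_mean_accept_prob(eps), 3))
```

The printed averages should decrease as `eps` grows; the exact values depend on the seed and on `num_leapfrog`.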
You're right, the step size seems to be small. I'll look into the adaptation dynamics later this week. If you are interested, you can extract more information from …
@jonny-so This turns out to be an issue with the dual averaging algorithm that we use:

```python
import jax.numpy as jnp
from jax.lax import scan
from numpyro.infer.hmc import hmc


# Standard isotropic Gaussian target: U(x) = 0.5 * |x|^2
def potential(x):
    return 0.5 * jnp.sum(x**2)


d = 10
nwarmup = 10000
nsamples = 10000

init_kernel, sample_kernel = hmc(potential, algo='HMC')
hmc_state = init_kernel(init_params=jnp.zeros(d), num_warmup=nwarmup, adapt_step_size=True, adapt_mass_matrix=False)

# Warmup: run the kernel and record the step size in use at each iteration
hmc_state_warmup, step_sizes = scan(lambda s, _: (sample_kernel(s), s.adapt_state.step_size), hmc_state, None, length=nwarmup)
print("post warmup", hmc_state_warmup.mean_accept_prob)

# Sampling: the step size is now fixed at its adapted value
hmc_state = scan(lambda s, _: (sample_kernel(s), None), hmc_state_warmup, None, length=nsamples)[0]
print("post samples", hmc_state.mean_accept_prob)

# Geometric vs arithmetic mean of the step sizes in the final 50-step window
print("exp(mean(log(last_50_step_sizes)))", jnp.exp(jnp.log(step_sizes[-50:]).mean()))
print("mean(last_50_step_sizes)", step_sizes[-50:].mean())
```
We use dual averaging over the last window buffer (50 steps) of the warmup phase. With that, the estimation for … cc @martinjankowiak, do you have any suggestions for dealing with this issue?
I'm not sure, but if you wanted to reduce that specific bias, I guess you could use the formula for the mean of a log-normal distribution...
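To make the bias concrete: dual averaging averages log step sizes, so its output behaves like a (weighted) geometric mean, exp(mean(log eps)). If log eps is roughly N(mu, sigma^2), that estimates the median exp(mu), while the log-normal mean is exp(mu + sigma^2 / 2), which is never smaller. The sketch below compares the three summaries on a synthetic window of 50 step sizes. The window is made up for illustration (it is not data from this thread), `step_size_summaries` is a hypothetical helper, and whether this correction is the right fix is left open in the discussion.

```python
import math
import random
import statistics


def step_size_summaries(step_sizes):
    """Compare three ways of summarising a window of adapted step sizes."""
    logs = [math.log(s) for s in step_sizes]
    log_mean = statistics.fmean(logs)
    log_var = statistics.variance(logs)  # sample variance of the log step sizes
    return {
        # exp(mean(log(step_size))): what averaging in log space yields
        "geometric_mean": math.exp(log_mean),
        "arithmetic_mean": statistics.fmean(step_sizes),
        # log-normal mean exp(mu + sigma^2 / 2), mu and sigma^2 estimated from the logs
        "lognormal_mean": math.exp(log_mean + 0.5 * log_var),
    }


rng = random.Random(0)
# Hypothetical window: 50 step sizes whose logs are N(log 0.5, 0.4^2)
window = [math.exp(rng.gauss(math.log(0.5), 0.4)) for _ in range(50)]
print(step_size_summaries(window))
```

The geometric mean is always at most the arithmetic mean, and the log-normal correction moves the estimate back up toward it.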
It looks like the implementation agrees with Algorithm 5 in Hoffman & Gelman (https://jmlr.org/papers/volume15/hoffman14a/hoffman14a.pdf#page=18.62). I guess it would be better to let users control the size of the last window.
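Letting users control the last window means exposing the terminal-buffer length of the warmup schedule. As a rough illustration, here is a simplified Stan-style schedule builder: an initial fast buffer, slow windows that double in length, and a final step-size-only window of `term_buffer` iterations. The defaults (75 / 25 / 50) are Stan's documented ones; the function `build_schedule`, its keyword arguments, and the 15% / 75% / 10% fallback for short warmups are assumptions for this sketch and not numpyro's actual API.

```python
def build_schedule(num_warmup, init_buffer=75, term_buffer=50, base_window=25):
    """Split warmup into inclusive (start, end) windows, Stan-style (sketch)."""
    if num_warmup < 20:
        # Too short to split: a single window for the whole warmup.
        return [(0, num_warmup - 1)]
    if init_buffer + base_window + term_buffer > num_warmup:
        # Fallback for short warmups: 15% / 75% / 10% split (assumption).
        init_buffer = int(0.15 * num_warmup)
        term_buffer = int(0.1 * num_warmup)
        base_window = num_warmup - init_buffer - term_buffer
    schedule = [(0, init_buffer - 1)]  # fast window: step size only
    start, window = init_buffer, base_window
    end_of_slow = num_warmup - term_buffer
    while start < end_of_slow:
        end = start + window
        # If the next (doubled) window would not fit, stretch this one instead.
        if end + 2 * window > end_of_slow:
            end = end_of_slow
        schedule.append((start, end - 1))  # slow window: step size + mass matrix
        start, window = end, 2 * window
    schedule.append((end_of_slow, num_warmup - 1))  # final step-size-only window
    return schedule


print(build_schedule(1000)[-1])                   # -> (950, 999)
print(build_schedule(1000, term_buffer=200)[-1])  # -> (800, 999)
```

Raising `term_buffer` lengthens the final step-size-only window and shortens the slow windows before it; the total warmup length is unchanged.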
Sorry for the delay, I've been flat out with the NeurIPS deadline. I need to think about this a bit, but I'm taking a week off to recover... I'll get back to you soon.
Increasing the fixed window size to more than 50 does indeed seem to resolve the issue; some comments on the Stan ticket I linked suggest they have observed the same. Exposing it as an option would be fine, although I am curious why fixing it to such a small value hasn't been a problem before.
I spoke with some of the Stan developers recently, and they said this has come up a number of times before and that the default should probably just be bigger. It would make a little more sense to me to have it proportional to the length of the warmup period, or something like that.
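One way to make the final window proportional to the warmup length is a simple clamp rule. The 10% fraction and the 50-iteration floor below are illustrative assumptions, not values anyone in the thread proposed, and `terminal_window_size` is a hypothetical helper.

```python
def terminal_window_size(num_warmup, fraction=0.1, minimum=50):
    """Hypothetical rule: the final step-size-only window gets `fraction` of
    warmup, but at least `minimum` iterations and never more than warmup itself."""
    return min(num_warmup, max(minimum, int(fraction * num_warmup)))


for n in (100, 1000, 10000):
    print(n, terminal_window_size(n))
# 100 50
# 1000 100
# 10000 1000
```

For the `nwarmup = 10000` setting used in the example earlier in the thread, this rule would give a 1000-iteration final window instead of 50.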
Hi @jonny-so, I think we can expose this configuration in the HMC/NUTS constructors. Do you want to make a PR for this?
I notice that after warmup, the mean_accept_prob is significantly higher than both target_accept_prob and the mean_accept_prob observed during warmup, even on a trivial isotropic Gaussian example.
Minimum working example:
outputs:
Am I misusing something here?