GMM notebook example: MCMC/NUTS simulation is not reproducible #1616

Open
inarighas opened this issue Jun 27, 2023 · 2 comments
Labels: help wanted, Tutorials/Examples

Comments

@inarighas

inarighas commented Jun 27, 2023

source: https://num.pyro.ai/en/stable/tutorials/gmm.html#MCMC
numpyro.__version__: 0.12.1
jax.__version__: 0.4.13

--

When running the collapsed NUTS sampler to explore the full posterior, the results I obtained did not match those presented in the tutorial.

from jax import random
from numpyro.infer import MCMC, NUTS

# `model` and `data` are defined earlier in the GMM tutorial
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=50, num_samples=250)
mcmc.run(random.PRNGKey(2), data)
mcmc.print_summary()
posterior_samples = mcmc.get_samples()

Obtained posterior density:
[image: obtained posterior density]

But with a longer warmup (num_warmup of 150 or more), I get roughly the expected behaviour:

[image: posterior density with num_warmup >= 150]

With more samples (~2500), the pattern is better:
[image: posterior density with ~2500 samples]

To clarify: in my previous attempts I used exactly the same values and parameters as the tutorial. When I ran the code on Google Drive, the results matched those in the documentation; when I ran it on my laptop, they differed significantly. I can provide more details about this. Given the fixed random seed and the simplicity of the example, the difference seems quite substantial.

I would also like to thank all the contributors to this library! I am impressed and excited by the remarkable work done by its developers.
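One plausible way the same seed can yield different chains on different machines or library versions is floating-point rounding. The snippet below is a minimal, self-contained illustration in plain Python (it does not use NumPyro, and `u` is a hypothetical stand-in for a uniform draw). Reordering an addition changes the last bit of the result, and an identical random draw compared against the two results takes different branches. Whether this is the exact cause in this issue is an assumption; the point is only that a fixed `PRNGKey` does not pin down the floating-point arithmetic that consumes the random numbers.

```python
# Illustration only (not NumPyro code): why a fixed seed need not give
# bit-identical MCMC output across JAX versions or machines.

# Floating-point addition is not associative: the same three numbers summed
# in a different order differ in the last bit.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
print(a, b, a == b)  # 0.6000000000000001 0.6 False

# Samplers make branch decisions by comparing a random draw against a
# computed quantity. Here the "same" quantity, computed two ways, flips the
# decision for an identical draw u (a hypothetical stand-in for a uniform draw).
u = 0.6
accept_a = u < a
accept_b = u < b
print(accept_a, accept_b)  # True False
```

In a Markov chain, once a single such decision differs, every later state differs too, so short, poorly converged runs such as num_warmup=50 can look very different from one setup to another.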

@ordabayevy
Member

ordabayevy commented Jun 27, 2023

@fehiepsi, could this be due to a newer (different) version of JAX?

@fehiepsi
Member

Yes, it's likely caused by numerical changes. It makes sense to use a longer warmup (num_warmup) and a larger num_samples. Currently, r_hat and n_eff in the checked-in version are pretty poor.
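The `r_hat` column reported by `mcmc.print_summary()` is a split-R-hat (Gelman-Rubin) convergence diagnostic. As a rough sketch of what it measures, here is the standard split-R-hat formula for one scalar parameter in plain Python. This is an illustrative reimplementation, not NumPyro's code (NumPyro's own version is `numpyro.diagnostics.split_gelman_rubin`), and `split_rhat` plus the toy chains are hypothetical.

```python
import math
from statistics import mean, variance


def split_rhat(chains):
    """Split-R-hat for one scalar parameter (illustrative sketch).

    chains: list of equal-length lists of draws, one list per chain,
    each with at least 4 draws.
    """
    n = len(chains[0])
    half = n // 2
    # Split each chain into its first and last halves (dropping the middle
    # draw when n is odd), so a chain that is still drifting looks like two
    # disagreeing chains.
    split = []
    for c in chains:
        split.append(c[:half])
        split.append(c[n - half:])
    chain_means = [mean(s) for s in split]
    w = mean([variance(s) for s in split])      # W: mean within-chain variance
    b_over_n = variance(chain_means)            # B / N: variance of chain means
    var_plus = (half - 1) / half * w + b_over_n  # pooled variance estimate
    return math.sqrt(var_plus / w)


mixed = [[0, 1, 0, 1, 0, 1, 0, 1], [1, 0, 1, 0, 1, 0, 1, 0]]
stuck = [[0, 1, 0, 1], [5, 6, 5, 6]]   # two chains sitting in different modes
drifting = [[0, 1, 2, 3, 4, 5, 6, 7]]  # one chain still trending (under-warmed)

print(split_rhat(mixed))     # about 1 (slightly below for such short chains)
print(split_rhat(stuck))     # well above 1
print(split_rhat(drifting))  # well above 1
```

Values near 1 suggest the (split) chains agree; a common rule of thumb treats values noticeably above 1 (thresholds of 1.01 to 1.1 are often quoted) as a sign that more warmup or more samples are needed. The `drifting` case shows why splitting matters: even a single chain that has not finished warming up is flagged.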

fehiepsi added the help wanted label Aug 1, 2024