The Pareto Smoothed Importance Sampling (PSIS) $\hat{k}$ diagnostic (as described by Yao et al. (2018)) can be used to determine whether a surrogate posterior fitted with variational inference is a good approximation of the true posterior. It would be great if NumPyro had a built-in method to calculate $\hat{k}$ (as Pyro does).
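For reference, here are the quantities involved, roughly. Write $y$ for the observed data, $\theta$ for the latent variables, $p$ for the model and $q$ for the fitted guide (this notation is chosen here for illustration). The diagnostic draws $S$ samples from the guide and forms log importance ratios, and $\hat{k}$ is the shape parameter of a generalized Pareto distribution fitted to the largest ratios $r_s$. The unknown evidence $p(y)$ shifts every $\log r_s$ by the same constant, so it does not affect $\hat{k}$:

$$
\log r_s = \log p(y, \theta_s) - \log q(\theta_s), \qquad \theta_s \sim q(\theta), \quad s = 1, \dots, S,
\qquad
\frac{p(y, \theta)}{q(\theta)} = p(y)\,\frac{p(\theta \mid y)}{q(\theta)}
$$

Yao et al. (2018) suggest reading $\hat{k} < 0.5$ as $q$ being close to the posterior, $0.5 \le \hat{k} < 0.7$ as $q$ being useful mainly as an importance-sampling proposal, and $\hat{k} > 0.7$ as $q$ being unreliable.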
I have created a preliminary implementation for my work (seen here) that calculates the log importance ratios and uses the `arviz.stats.stats._psislw()` function to calculate $\hat{k}$. It could easily be generalized to work with any model and guide.
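To show what the $\hat{k}$ step does with the log importance ratios, here is a simplified NumPy re-implementation written for illustration. It is not ArviZ's code, and `pareto_khat` and `_gpd_fit` are hypothetical names. It takes the PSIS tail of the largest weights, fits a generalized Pareto distribution with the Zhang and Stephens (2009) estimator plus the weak prior on $k$ used by PSIS, and returns the fitted shape. The smoothing of the weights themselves is omitted.

```python
import numpy as np


def _gpd_fit(x):
    """Fit a generalized Pareto distribution to sorted, positive exceedances x.

    Empirical-Bayes estimator of Zhang & Stephens (2009), with the weak prior
    on the shape k (10 pseudo-observations at k = 0.5) used by PSIS.
    Returns (k, sigma).
    """
    n = x.size
    m = 30 + int(np.sqrt(n))
    # Grid of candidate values for theta = -k / sigma (all below 1 / max(x)).
    theta = 1.0 - np.sqrt(m / (np.arange(1, m + 1) - 0.5))
    theta /= 3.0 * x[int(n / 4 + 0.5) - 1]
    theta += 1.0 / x[-1]
    # Profile estimate of k for every theta, and the profile log-likelihood.
    k = np.log1p(-theta[:, None] * x).mean(axis=1)
    loglik = n * (np.log(-theta / k) - k - 1.0)
    # Posterior weights over the grid (a numerically stable softmax).
    w = np.exp(loglik - loglik.max())
    w /= w.sum()
    theta_hat = np.sum(theta * w)
    k_hat = np.log1p(-theta_hat * x).mean()
    sigma = -k_hat / theta_hat
    # Shrink k towards 0.5 with the weak prior.
    k_hat = (n * k_hat + 10 * 0.5) / (n + 10)
    return k_hat, sigma


def pareto_khat(log_ratios, r_eff=1.0):
    """Estimate k-hat from a 1-D array of log importance ratios (no smoothing)."""
    lw = np.asarray(log_ratios, dtype=float)
    lw = lw - lw.max()  # shift so the largest log ratio is 0
    n = lw.size
    # PSIS tail length: M = ceil(min(0.2 * S, 3 * sqrt(S / r_eff))).
    m_tail = int(np.ceil(min(0.2 * n, 3.0 * np.sqrt(n / r_eff))))
    lw_sorted = np.sort(lw)
    cutoff = max(lw_sorted[-m_tail - 1], np.log(np.finfo(float).tiny))
    tail = lw_sorted[lw_sorted > cutoff]
    if tail.size <= 4:
        return np.inf  # too few tail draws to fit a Pareto tail
    # Exceedances over the cutoff on the weight (not log) scale, ascending.
    exceed = np.exp(tail) - np.exp(cutoff)
    k_hat, _ = _gpd_fit(exceed)
    return k_hat
```

As a sanity check, log ratios of the form `-np.log(u) / 2` with `u` uniform on (0, 1] give weights with a Pareto tail of index 2. For those, the estimate should land near $k = 0.5$. Independent draws from the guide correspond to `r_eff=1.0`.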
Any thoughts would be much appreciated!
That's very cool! I didn't know there was a nice diagnostic for how "good" a variational approximation is. Would this be a good contribution to ArviZ itself, so that the method could also be used with other inference frameworks such as Stan?
ArviZ already has a function (`arviz.stats.stats._psislw()`) that computes $\hat{k}$ from the log importance ratios. But computing the log importance ratios themselves requires the NumPyro model, the guide, the fitted guide parameters, and the samples being evaluated. That's why I'm suggesting a new method that does all of this in one go.