diff --git a/neurobayes/utils/utils.py b/neurobayes/utils/utils.py index 093896d..dddc0d9 100644 --- a/neurobayes/utils/utils.py +++ b/neurobayes/utils/utils.py @@ -283,10 +283,15 @@ def set_fn(func: Callable) -> Callable: return local_namespace[func.__name__] -def plot_rhats(samples): +def get_rhats(model): + samples = model.get_samples(1) sgr = numpyro.diagnostics.split_gelman_rubin rhats = [sgr(v).flatten() for (k, v) in samples.items() if k.endswith('kernel')] - rhats = np.concatenate(rhats) + return np.concatenate(rhats) + + +def plot_rhats(model): + rhats = get_rhats(model) plt.hist(rhats, bins=20, color='green', alpha=0.6) plt.xlabel('r_hat', fontsize=14) plt.ylabel('Count', fontsize=14)