Skip to content

Commit

Permalink
Add coverage metric
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Dec 30, 2024
1 parent 833c9c7 commit c23839b
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions neurobayes/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,27 @@ def set_fn(func: Callable) -> Callable:
return local_namespace[func.__name__]


def plot_rhats(samples):
def get_rhats(model):
samples = model.get_samples(1)
sgr = numpyro.diagnostics.split_gelman_rubin
rhats = [sgr(v).flatten() for (k, v) in samples.items() if k.endswith('kernel')]
rhats = np.concatenate(rhats)
return np.concatenate(rhats)


def plot_rhats(model):
rhats = get_rhats(model)
plt.hist(rhats, bins=20, color='green', alpha=0.6)
plt.xlabel('r_hat', fontsize=14)
plt.ylabel('Count', fontsize=14)
plt.ylabel('Count', fontsize=14)


def confidence_interval(mean,var):
std = np.sqrt(var)
return (mean-1.96*std, mean+1.96*std)

def coverage(y_true, y_pred_mean, y_pred_var):
ci = confidence_interval(y_pred_mean, y_pred_var)
lower_bounds = ci[0]
upper_bounds = ci[1]
in_ci = (y_true >= lower_bounds) & (y_true <= upper_bounds)
return np.mean(in_ci)

0 comments on commit c23839b

Please sign in to comment.