generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-inferencedata-into-tidy_draws-format
- Loading branch information
Showing
6 changed files
with
718 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
import arviz as az | ||
import jax.numpy as jnp | ||
import numpyro | ||
from jax import random | ||
from numpyro.infer import MCMC | ||
from numpyro.infer.mcmc import MCMCKernel | ||
from tqdm import tqdm | ||
|
||
from forecasttools.sbc_plots import plot_results | ||
|
||
|
||
class SBC:
    """
    Simulation based calibration (SBC) for numpyro models.

    Parameters and data are drawn from the prior predictive distribution,
    the model is refitted to every simulated data set with MCMC, and for each
    parameter the number of posterior draws lying below the generating prior
    draw (the rank statistic) is recorded in `simulations`.
    """

    def __init__(
        self,
        mcmc_kernel: MCMCKernel,
        *args,
        observed_vars: dict[str, str],
        num_simulations=10,
        sample_kwargs=None,
        seed=None,
        inspection_mode=False,
        **kwargs,
    ) -> None:
        """
        Set up class for doing SBC.

        Based on simulation based calibration (Talts et. al. 2018) in PyMC.

        Parameters
        ----------
        mcmc_kernel : numpyro.infer.mcmc.MCMCKernel
            An instance of a numpyro MCMC kernel object. It must expose the
            model through a `model` attribute.
        observed_vars : dict[str, str]
            A dictionary mapping observed/response variable name as a kwarg to
            the numpyro model to the corresponding variable name sampled using
            `numpyro.sample`.
        args : tuple
            Positional arguments passed to the numpyro model.
        num_simulations : int
            How many simulations to run for SBC.
        sample_kwargs : dict[str, Any]
            Keyword arguments passed to `numpyro.infer.MCMC`. Defaults to
            `dict(num_warmup=500, num_samples=100, progress_bar=False)`.
        seed : random.PRNGKey
            Random seed. Defaults to `random.PRNGKey(1234)`.
        inspection_mode : bool
            If True, every fitted `InferenceData` object is kept in
            `self.idatas`, and the prior / prior predictive draws are kept in
            `self.prior` / `self.prior_pred` once simulations are run.
        kwargs : dict[str, Any]
            Keyword arguments passed to the numpyro model.

        Raises
        ------
        ValueError
            If `mcmc_kernel` has no `model` attribute, or if a key of
            `observed_vars` is given a non-None value in `kwargs`.
        """
        if not hasattr(mcmc_kernel, "model"):
            raise ValueError(
                "The `mcmc_kernel` must have a 'model' attribute."
            )
        if sample_kwargs is None:
            sample_kwargs = dict(
                num_warmup=500, num_samples=100, progress_bar=False
            )
        if seed is None:
            seed = random.PRNGKey(1234)

        self.mcmc_kernel = mcmc_kernel
        self.model = mcmc_kernel.model
        self.args = args
        self.kwargs = kwargs
        self.observed_vars = observed_vars

        # The observed variables must be left unset so that the prior
        # predictive pass samples them instead of conditioning on them.
        for key in self.observed_vars:
            if self.kwargs.get(key) is not None:
                raise ValueError(
                    f"The value for '{key}' in kwargs must be None for this to"
                    " be a prior predictive check."
                )

        self.num_simulations = num_simulations
        self.sample_kwargs = sample_kwargs

        # Rank statistics per variable, and how many simulations are done.
        self.simulations = {}
        self._simulations_complete = 0

        # Independent random streams for prior predictive draws and for MCMC.
        self._prior_pred_rng, self._sampler_rng = random.split(seed)

        # Number of posterior draws per fit; filled in by the first fit.
        self.num_samples = None

        # In inspection mode, store all idata objects from fitting.
        self.inspection_mode = inspection_mode
        if inspection_mode:
            self.idatas = []

    def _get_prior_predictive_samples(
        self,
    ) -> tuple[dict[str, jnp.ndarray], dict[str, jnp.ndarray]]:
        """
        Generate samples to use for the simulations by prior predictive
        sampling. Then splits between observed and unobserved variables based
        on the `observed_vars` attribute.

        Returns
        -------
        tuple[dict, dict]
            The prior samples, keyed by sampled site name, and the prior
            predictive samples, keyed by the model kwarg name given in
            `observed_vars`.
        """
        prior_predictive_fn = numpyro.infer.Predictive(
            self.mcmc_kernel.model, num_samples=self.num_simulations
        )
        prior_predictions = prior_predictive_fn(
            self._prior_pred_rng, *self.args, **self.kwargs
        )
        observed_sites = set(self.observed_vars.values())
        prior_pred = {
            kwarg_name: prior_predictions[site_name]
            for kwarg_name, site_name in self.observed_vars.items()
        }
        prior = {
            site_name: draws
            for site_name, draws in prior_predictions.items()
            if site_name not in observed_sites
        }
        return prior, prior_pred

    def _get_posterior_samples(
        self,
        seed: random.PRNGKey,
        prior_predictive_draw: dict[str, jnp.ndarray],
    ) -> az.InferenceData:
        """
        Generate posterior samples conditioned to a prior predictive sample.

        The number of samples reported by the sampler is stored in
        `self.num_samples` on the first call and checked for consistency on
        every later call; it is used when plotting the results.

        Parameters
        ----------
        seed : random.PRNGKey
            Random seed for MCMC sampling.
        prior_predictive_draw : dict
            One prior predictive draw, keyed by model kwarg name.

        Returns
        -------
        az.InferenceData
            Posterior samples as an arviz InferenceData object.

        Raises
        ------
        ValueError
            If the number of samples differs from the one recorded earlier.
        """
        mcmc = MCMC(self.mcmc_kernel, **self.sample_kwargs)
        # The simulated observations override the (None) observed kwargs.
        model_kwargs = {**self.kwargs, **prior_predictive_draw}
        mcmc.run(seed, *self.args, **model_kwargs)

        num_samples = mcmc.num_samples
        if self.num_samples is None:
            self.num_samples = num_samples
        elif self.num_samples != num_samples:
            raise ValueError(
                "The number of samples from the posterior is not consistent."
            )
        return az.from_numpyro(mcmc)

    def _increment_rank_statistics(self, prior_draw, posterior) -> None:
        """
        Increment the rank statistics for each parameter in the prior draw.

        For every name in `prior_draw`, counts the draws of chain 0 in
        `posterior` that are smaller than the prior draw and appends the
        count to `self.simulations[name]`.

        Parameters
        ----------
        prior_draw : dict
            One prior draw per variable name.
        posterior
            The posterior group of the fitted InferenceData; indexed by
            variable name and reduced with `.sel(chain=0)`.

        Returns
        -------
        None
        """
        for name, true_value in prior_draw.items():
            below = posterior[name].sel(chain=0) < true_value
            # Summing over the leading axis handles scalar and array-valued
            # parameters alike (assumes draws are on axis 0 once the chain
            # is selected - the arviz layout the original code relied on).
            self.simulations[name].append(below.sum(axis=0).values)

    def run_simulations(self) -> None:
        """
        The main method of `SBC` class that runs the simulations for
        simulation based calibration and fills the `simulations` attribute
        with the results.

        The method can be called again after an interruption: simulations
        that already completed are kept and only the remaining ones are run.
        """
        prior, prior_pred = self._get_prior_predictive_samples()
        sampler_seeds = random.split(self._sampler_rng, self.num_simulations)

        # Keep ranks collected by an earlier (interrupted) call; resetting
        # them here would desynchronise them from `_simulations_complete`.
        for name in prior:
            self.simulations.setdefault(name, [])

        if self.inspection_mode:
            self.prior = prior
            self.prior_pred = prior_pred

        progress = tqdm(
            initial=self._simulations_complete,
            total=self.num_simulations,
        )
        try:
            while self._simulations_complete < self.num_simulations:
                idx = self._simulations_complete
                prior_draw = {k: v[idx] for k, v in prior.items()}
                prior_predictive_draw = {
                    k: v[idx] for k, v in prior_pred.items()
                }
                idata = self._get_posterior_samples(
                    sampler_seeds[idx], prior_predictive_draw
                )
                if self.inspection_mode:
                    self.idatas.append(idata)
                self._increment_rank_statistics(prior_draw, idata["posterior"])
                self._simulations_complete += 1
                progress.update()
        finally:
            # Drop anything recorded for a simulation that did not finish.
            done = self._simulations_complete
            self.simulations = {
                k: v[:done] for k, v in self.simulations.items()
            }
            if self.inspection_mode:
                del self.idatas[done:]
            progress.close()

    def plot_results(
        self, kind="ecdf", var_names=None, color="C0", figsize=None
    ):
        """
        Visual diagnostic for SBC.

        Currently it support two options: `ecdf` for the empirical CDF plots
        of the difference between prior and posterior. `hist` for the rank
        histogram.

        Parameters
        ----------
        kind : str
            What kind of plot to make. Supported values are 'ecdf' (default)
            and 'hist'
        var_names : list[str]
            Variables to plot (defaults to all)
        color : str
            Color to use for the eCDF or histogram
        figsize : tuple
            Figure size for the plot. If None, it will be defined
            automatically.

        Returns
        -------
        fig, axes
            matplotlib figure and axes
        """
        return plot_results(
            self.simulations,
            self.num_samples,
            kind=kind,
            var_names=var_names,
            figsize=figsize,
            color=color,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
""" | ||
Plots for the simulation based calibration | ||
""" | ||
|
||
import itertools | ||
|
||
import arviz as az | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import numpyro.distributions as dist | ||
from scipy.special import bdtrik | ||
|
||
|
||
def plot_results(
    simulations, ndraws, kind="ecdf", var_names=None, figsize=None, color="C0"
):
    """
    Visual diagnostic for SBC.

    Currently it support two options: `ecdf` for the empirical CDF plots
    of the difference between prior and posterior. `hist` for the rank
    histogram.

    Parameters
    ----------
    simulations : dict[str, Any]
        The SBC.simulations dictionary.
    ndraws : int
        Number of draws in each posterior sample; used as the upper bound of
        the discrete uniform reference distribution of the ranks.
    kind : str
        What kind of plot to make. Supported values are 'ecdf' (default)
        and 'hist'
    var_names : str or list[str]
        Variables to plot (defaults to all). A single name may be given as a
        plain string.
    figsize : tuple
        Figure size for the plot. If None, it will be defined automatically.
    color : str
        Color to use for the eCDF or histogram

    Returns
    -------
    fig, axes
        matplotlib figure and axes

    Raises
    ------
    ValueError
        If `kind` is not 'ecdf' or 'hist'.
    KeyError
        If a requested variable name is not present in `simulations`.
    """
    if kind not in ["ecdf", "hist"]:
        raise ValueError(f"kind must be 'ecdf' or 'hist', not {kind}")

    if var_names is None:
        var_names = list(simulations)
    elif isinstance(var_names, str):
        # A bare string would otherwise be iterated character by character.
        var_names = [var_names]
    else:
        var_names = list(var_names)

    missing = [name for name in var_names if name not in simulations]
    if missing:
        raise KeyError(
            f"var_names not found in simulations: {missing}; "
            f"available variables are {list(simulations)}"
        )

    # One array per variable, with axis 0 indexing the simulations and at
    # least one trailing axis so scalar parameters are handled like vectors.
    sims = {}
    for name in var_names:
        ranks = np.array(simulations[name])
        while ranks.ndim < 2:
            ranks = np.expand_dims(ranks, -1)
        sims[name] = ranks

    # One panel per scalar component of every variable.
    n_plots = sum(int(np.prod(v.shape[1:])) for v in sims.values())

    if n_plots > 1:
        if figsize is None:
            figsize = (8, n_plots * 1.0)
        fig, axes = plt.subplots(
            nrows=(n_plots + 1) // 2, ncols=2, figsize=figsize, sharex=True
        )
        axes = axes.flatten()
    else:
        if figsize is None:
            figsize = (8, 1.5)
        fig, axes = plt.subplots(nrows=1, ncols=1, figsize=figsize)
        axes = [axes]

    if kind == "ecdf":
        cdf = dist.DiscreteUniform(high=ndraws).cdf

    panel = 0
    for var_name, var_data in sims.items():
        component_idxs = list(
            itertools.product(*(range(s) for s in var_data.shape[1:]))
        )
        has_dims = len(component_idxs) > 1

        for indices in component_idxs:
            if has_dims:
                label = f"{var_name}[{']['.join(map(str, indices))}]"
            else:
                label = var_name
            ax = axes[panel]
            # Ranks of this component across all simulations.
            component_ranks = var_data[(...,) + indices]
            if kind == "ecdf":
                az.plot_ecdf(
                    component_ranks,
                    cdf=cdf,
                    difference=True,
                    pit=True,
                    confidence_bands="auto",
                    plot_kwargs={"color": color},
                    fill_kwargs={"color": color},
                    ax=ax,
                )
            else:
                hist(component_ranks, color=color, ax=ax)
            ax.set_title(label)
            ax.set_yticks([])
            panel += 1

    # Remove the unused panel(s) of the two-column grid.
    for extra_ax in range(n_plots, len(axes)):
        fig.delaxes(axes[extra_ax])

    return fig, axes
|
||
|
||
def hist(ary, color, ax):
    """
    Draw a rank histogram with a reference band for uniform ranks.

    The ranks are binned with numpy's "auto" rule and drawn as bars. The
    dashed line and shaded band show the median and the central 95% interval
    of the count expected in one bin when every simulation is equally likely
    to fall in any bin, i.e. quantiles of Binomial(n_sims, 1 / n_bins).

    Parameters
    ----------
    ary : array-like
        Rank statistics, one per simulation.
    color : str
        Color of the bars.
    ax : matplotlib axes
        Axes to draw on.
    """
    counts, bins = np.histogram(ary, bins="auto")
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    max_rank = np.ceil(bins[-1])
    # `bins` holds the edges, so there is one bin fewer than edges.
    n_bins = len(bins) - 1
    n_sims = len(ary)

    band = np.ceil(bdtrik([0.025, 0.5, 0.975], n_sims, 1 / n_bins))
    ax.bar(
        bin_centers,
        counts,
        width=bins[1] - bins[0],
        color=color,
        edgecolor="black",
    )
    ax.axhline(band[1], color="0.5", ls="--")
    ax.fill_between(
        np.linspace(0, max_rank, len(bins)),
        band[0],
        band[2],
        color="0.5",
        alpha=0.5,
    )
Oops, something went wrong.