diff --git a/forecasttools/sbc.py b/forecasttools/sbc.py new file mode 100644 index 0000000..1e9d4e1 --- /dev/null +++ b/forecasttools/sbc.py @@ -0,0 +1,253 @@ +import arviz as az +import jax.numpy as jnp +import numpyro +from jax import random +from numpyro.infer import MCMC +from numpyro.infer.mcmc import MCMCKernel +from tqdm import tqdm + +from forecasttools.sbc_plots import plot_results + + +class SBC: + def __init__( + self, + mcmc_kernel: MCMCKernel, + *args, + observed_vars: dict[str, str], + num_simulations=10, + sample_kwargs=None, + seed=None, + inspection_mode=False, + **kwargs, + ) -> None: + """ + Set up class for doing SBC. + Based on simulation based calibration (Talts et. al. 2018) in PyMC. + + Parameters + ---------- + mcmc_kernel : numpyro.infer.mcmc.MCMCKernel + An instance of a numpyo MCMC kernel object. + observed_vars : dict[str, str] + A dictionary mapping observed/response variable name as a kwarg to + the numpyro model to the corresponding variable name sampled using + `numpyro.sample`. + args : tuple + Positional arguments passed to `numpyro.sample`. + num_simulations : int + How many simulations to run for SBC. + sample_kwargs : dict[str, Any] + Arguments passed to `numpyro.sample`. Defaults to + `dict(num_warmup=500, num_samples=100, progress_bar = False)`. + Which assumes a MCMC sampler e.g. NUTS. + seed : random.PRNGKey + Random seed. + kwargs : dict[str, Any] + Keyword arguments passed to `numpyro` models. + """ + if sample_kwargs is None: + sample_kwargs = dict( + num_warmup=500, num_samples=100, progress_bar=False + ) + if seed is None: + seed = random.PRNGKey(1234) + self.mcmc_kernel = mcmc_kernel + if not hasattr(mcmc_kernel, "model"): + raise ValueError( + "The `mcmc_kernel` must have a 'model' attribute." + ) + + self.model = mcmc_kernel.model + self.args = args + self.kwargs = kwargs + self.observed_vars = observed_vars + + for key in self.observed_vars: + if key in self.kwargs and self.kwargs[key] is not None: + raise ValueError( + f"The value for '{key}' in kwargs must be None for this to" + " be a prior predictive check." + ) + + self.num_simulations = num_simulations + self.sample_kwargs = sample_kwargs + # Initialize the simulations and random seeds + self.simulations = {} + self._simulations_complete = 0 + prior_pred_rng, sampler_rng = random.split(seed) + self._prior_pred_rng = prior_pred_rng + self._sampler_rng = sampler_rng + self.num_samples = None + # Set the inspection mode + # if in inspection mode, store all idata objects from fitting + self.inspection_mode = inspection_mode + if inspection_mode: + self.idatas = [] + + def _get_prior_predictive_samples( + self, + ) -> tuple[dict[str, any], dict[str, any]]: + """ + Generate samples to use for the simulations by prior predictive + sampling. Then splits between observed and unobserved variables based + on the `observed_vars` attribute. + + Returns + ------- + tuple[dict[str, any], dict[str, any]] + The prior and prior predictive samples. + """ + prior_predictive_fn = numpyro.infer.Predictive( + self.mcmc_kernel.model, num_samples=self.num_simulations + ) + prior_predictions = prior_predictive_fn( + self._prior_pred_rng, *self.args, **self.kwargs + ) + prior_pred = { + k: prior_predictions[v] for k, v in self.observed_vars.items() + } + prior = { + k: v + for k, v in prior_predictions.items() + if k not in self.observed_vars.values() + } + return prior, prior_pred + + def _get_posterior_samples( + self, seed: random.PRNGKey, prior_predictive_draw: dict[str, any] + ) -> tuple[az.InferenceData, int]: + """ + Generate posterior samples conditioned to a prior predictive sample. + This returns the posterior samples and the number of samples. The + number of samples are used in scaling plotting and checking that each + inference draw has the same number of samples. + + Parameters + ---------- + seed : random.PRNGKey + Random seed for MCMC sampling. + prior_predictive_draw : dict[str, any] + Prior predictive samples. + + Returns + ------- + tuple[az.InferenceData, int] + Posterior samples as an arviz InferenceData object, with the count + of posterior samples. + """ + mcmc = MCMC(self.mcmc_kernel, **self.sample_kwargs) + obs_vars = {**self.kwargs, **prior_predictive_draw} + mcmc.run(seed, *self.args, **obs_vars) + num_samples = mcmc.num_samples + # Check that the number of samples is consistent + if self.num_samples is None: + self.num_samples = num_samples + if self.num_samples != num_samples: + raise ValueError( + "The number of samples from the posterior is not consistent." + ) + idata = az.from_numpyro(mcmc) + return idata + + def _increment_rank_statistics(self, prior_draw, posterior) -> None: + """ + Increment the rank statistics for each parameter in the prior draw. + + This method updates the `self.simulations` dictionary with the rank + statistics for each parameter in the `prior_draw` compared to the + `posterior`. + + Returns: + None + """ + for name in prior_draw: + num_dims = jnp.ndim(prior_draw[name]) + if num_dims == 0: + rank_statistics = ( + (posterior[name].sel(chain=0) < prior_draw[name]) + .sum() + .values + ) + self.simulations[name].append(rank_statistics) + else: + rank_statistics = ( + (posterior[name].sel(chain=0) < prior_draw[name]) + .sum(axis=0) + .values + ) + self.simulations[name].append(rank_statistics) + + def run_simulations(self) -> None: + """ + The main method of `SBC` class that runs the simulations for + simulation based calibration and fills the `simulations` attribute + with the results. + """ + prior, prior_pred = self._get_prior_predictive_samples() + sampler_seeds = random.split(self._sampler_rng, self.num_simulations) + self.simulations = {name: [] for name in prior} + progress = tqdm( + initial=self._simulations_complete, + total=self.num_simulations, + ) + if self.inspection_mode: + self.prior = prior + self.prior_pred = prior_pred + try: + while self._simulations_complete < self.num_simulations: + idx = self._simulations_complete + prior_draw = {k: v[idx] for k, v in prior.items()} + prior_predictive_draw = { + k: v[idx] for k, v in prior_pred.items() + } + idata = self._get_posterior_samples( + sampler_seeds[idx], prior_predictive_draw + ) + if self.inspection_mode: + self.idatas.append(idata) + self._increment_rank_statistics(prior_draw, idata["posterior"]) + self._simulations_complete += 1 + progress.update() + finally: + self.simulations = { + k: v[: self._simulations_complete] + for k, v in self.simulations.items() + } + progress.close() + + def plot_results(self, kind="ecdf", var_names=None, color="C0"): + """ + Visual diagnostic for SBC. + + Currently it support two options: `ecdf` for the empirical CDF plots + of the difference between prior and posterior. `hist` for the rank + histogram. + + Parameters + ---------- + simulations + The SBC.simulations dictionary. + kind : str + What kind of plot to make. Supported values are 'ecdf' (default) + and 'hist' + var_names : list[str] + Variables to plot (defaults to all) + figsize : tuple + Figure size for the plot. If None, it will be defined + automatically. + color : str + Color to use for the eCDF or histogram + + Returns + ------- + fig, axes + matplotlib figure and axes + """ + return plot_results( + self.simulations, + self.num_samples, + kind=kind, + var_names=var_names, + color=color, + ) diff --git a/forecasttools/sbc_plots.py b/forecasttools/sbc_plots.py new file mode 100644 index 0000000..43f68e2 --- /dev/null +++ b/forecasttools/sbc_plots.py @@ -0,0 +1,137 @@ +""" +Plots for the simulation based calibration +""" + +import itertools + +import arviz as az +import matplotlib.pyplot as plt +import numpy as np +import numpyro.distributions as dist +from scipy.special import bdtrik + + +def plot_results( + simulations, ndraws, kind="ecdf", var_names=None, figsize=None, color="C0" +): + """ + Visual diagnostic for SBC. + + Currently it support two options: `ecdf` for the empirical CDF plots + of the difference between prior and posterior. `hist` for the rank + histogram. + + Parameters + ---------- + simulations : dict[str, Any] + The SBC.simulations dictionary. + ndraws : int + Number of draws in each posterior predictive sample + kind : str + What kind of plot to make. Supported values are 'ecdf' (default) + and 'hist' + var_names : list[str] + Variables to plot (defaults to all) + figsize : tuple + Figure size for the plot. If None, it will be defined automatically. + color : str + Color to use for the eCDF or histogram + + Returns + ------- + fig, axes + matplotlib figure and axes + """ + + if kind not in ["ecdf", "hist"]: + raise ValueError(f"kind must be 'ecdf' or 'hist', not {kind}") + + if var_names is None: + var_names = list(simulations.keys()) + + sims = {} + for k in var_names: + ary = np.array(simulations[k]) + while ary.ndim < 2: + ary = np.expand_dims(ary, -1) + sims[k] = ary + + n_plots = sum(np.prod(v.shape[1:]) for v in sims.values()) + + if n_plots > 1: + if figsize is None: + figsize = (8, n_plots * 1.0) + + fig, axes = plt.subplots( + nrows=(n_plots + 1) // 2, ncols=2, figsize=figsize, sharex=True + ) + axes = axes.flatten() + else: + if figsize is None: + figsize = (8, 1.5) + + fig, axes = plt.subplots(nrows=1, ncols=1, figsize=figsize) + axes = [axes] + + if kind == "ecdf": + cdf = dist.DiscreteUniform(high=ndraws).cdf + + idx = 0 + for var_name, var_data in sims.items(): + plot_idxs = list( + itertools.product(*(np.arange(s) for s in var_data.shape[1:])) + ) + + for indices in plot_idxs: + if len(plot_idxs) > 1: # has dims + dim_label = f"{var_name}[{']['.join(map(str, indices))}]" + else: + dim_label = var_name + ax = axes[idx] + ary = var_data[(...,) + indices] + if kind == "ecdf": + az.plot_ecdf( + ary, + cdf=cdf, + difference=True, + pit=True, + confidence_bands="auto", + plot_kwargs={"color": color}, + fill_kwargs={"color": color}, + ax=ax, + ) + else: + hist(ary, color=color, ax=ax) + ax.set_title(dim_label) + ax.set_yticks([]) + idx += 1 + + for extra_ax in range(n_plots, len(axes)): + fig.delaxes(axes[extra_ax]) + + return fig, axes + + +def hist(ary, color, ax): + hist, bins = np.histogram(ary, bins="auto") + bin_centers = 0.5 * (bins[:-1] + bins[1:]) + max_rank = np.ceil(bins[-1]) + len_bins = len(bins) + n_sims = len(ary) + + band = np.ceil(bdtrik([0.025, 0.5, 0.975], n_sims, 1 / len_bins)) + ax.bar( + bin_centers, + hist, + width=bins[1] - bins[0], + color=color, + edgecolor="black", + ) + ax.axhline(band[1], color="0.5", ls="--") + ax.fill_between( + np.linspace(0, max_rank, len_bins), + band[0], + band[2], + color="0.5", + alpha=0.5, + ) diff --git a/notebooks/references.bib b/notebooks/references.bib new file mode 100644 index 0000000..f347b61 --- /dev/null +++ b/notebooks/references.bib @@ -0,0 +1,16 @@ + +@book{gelman2013, + title = {Bayesian Data Analysis}, + author = {Gelman, Andrew and Carlin, John B and Stern, Hal S and Dunson, David B and Vehtari, Aki and Rubin, Donald B}, + year = {2013}, + date = {2013}, + publisher = {CRC Press}, + edition = {third}, + langid = {en} +} + +@article{talts, + title = {Validating Bayesian Inference Algorithms with Simulation-Based Calibration}, + author = {Talts, Sean and Betancourt, Michael and Simpson, Daniel and Vehtari, Aki and Gelman, Andrew}, + langid = {en} +} diff --git a/notebooks/sbc_model_checking.qmd b/notebooks/sbc_model_checking.qmd new file mode 100644 index 0000000..3e133a4 --- /dev/null +++ b/notebooks/sbc_model_checking.qmd @@ -0,0 +1,184 @@ +--- +title: "Using the `SBC` class to do sim-based calibration for `numpyro` models" +format: gfm +engine: jupyter +execute: + warning: false + output: false + cache: true +bibliography: references.bib +--- + +This notebook covers: + +- A brief introduction to simulation-based calibration (SBC) as a method for testing the self-consistency of a Bayesian inference method. +- Tuning and running SBC for a `numpyro` model using the `SBC` from `forecasttools`. +- Plotting results and interpreting them. + +The example we will use to illustrate our ideas is the classic "eight schools" inference problem from section 5.5 of Gelman *et al* [@gelman2013] with posterior sampling done using the NUTS algorithm implemented by `numpyro` PPL. + +## Simulation-based calibration + +"Unit" testing the correctness of Bayesian probabilistic programs is challenging because the target quantity is a distribution (the posterior distribution) typically sampled from via a random process, e.g. Markov chain Monte Carlo sampling. The range of models where the correct posterior distribution is known analytically, and therefore, we can test sampling against distributional checks (e.g. Kolmogorov-Smirnov tests, chi-squared tests, etc) is limited. + +Simulation-based calibration (SBC) [@talts] aims to provide a correctness check for a Bayesian model and sampler using the self-consistency property of Bayesian inference: + +$$ +p(\theta) = \int p(y,\theta) dy = \int p(\theta | y) p(y) dy = \int p(\theta|y)\int p(y | \theta')p(\theta')d\theta' dy. +$$ + +In words, this states that the distribution of posterior samples from Bayesian inference gathered over many datasets $(y^{(i)})_{i=1, 2, \dots}$ is the prior distribution $p(\theta)$ *if* the datasets are generated using the prior distribution $y^{(i)} \sim p(y | \theta') p(\theta')$. + +This self-consistency property of Bayesian inference provides a convenient method for checking if the model inference process $p(\theta | y)$ is performing as desired, since generating datasets $(y^{(i)})_{i=1, 2, \dots}$ from a generative model is typically quite easy. Since for any individual $k$th parameter $\theta[k] \in \theta$ we know the prior distribution $p(\theta[k])$ we can check correctness against: + +$$ +F^{(-1)}(\theta[k]) \sim \mathcal{U}(0,1). +$$ + +Where $F^{(-1)}$ is the (pseudo) inverse distribution function of the known distribution $p(\theta[k])$. This is often called the probability integral transform (PIT). Significant (in a classical statistics sense) deviations in the PIT of parameters from the uniform distribution form evidence that either the model is incorrectly coded, its priors are cover hard to sample regimes, and/or the Bayesian inference method is approximate/invalid. See Talts *et al* [@talts] for more details on understanding the returned tests. + +More precisely, in SBC: + +1. For $i = 1,\dots,n$: sample the observable and parameters from the model using the prior distribution, $(y^{(i)}, \theta^{(i)})$. +2. For $i = 1,\dots,n$: sample posterior parameters using the Bayesian method and model under evaluation, $\theta_p^{(i)} \sim p(\theta_p | y^{(i)})$. +3. Generate the probability integral transform (PIT) of each of $k = 1,\dots,P$ parameters $\theta[k]$ with respect to the known prior distribution. +4. Assess the PIT distributions for deviation against an assumption of being $\mathcal{U}([0,1])$. + +In `SBC` we form the PIT by looking at the distribution of proportion of posterior samples that are less than the "true" parameter values cached along with the generated data $y^{(i)}$: $P(\theta_p^{(i)}[k] < \theta^{(i)}[k])$. These are commonly called the rank statistics of the sampling process. Using rank statistics is a more convenient approach in `numpyro` than trying to solve the inverse distribution function directly. + +## Example: Eight schools + +The eight schools example is a classic example of using partial pooling to share inferential strength between groups, cf Gelman *et al* [@gelman2013]: + +> *A study was performed for the Educational Testing Service to analyze the effects of special coaching programs for SAT-V (Scholastic Aptitude Test-Verbal) in each of eight high schools. The outcome variable in each study was the score on a special administration of the SAT-V, a standardized multiple choice test administered by the Educational Testing Service and used to help colleges make admissions decisions; the scores can vary between 200 and 800, with mean about 500 and standard deviation about 100. The SAT examinations are designed to be resistant to short-term efforts directed specifically toward improving performance on the test; instead they are designed to reflect knowledge acquired and abilities developed over many years of education. Nevertheless, each of the eight schools in this study considered its short-term coaching program to be very successful at increasing SAT scores. Also, there was no prior reason to believe that any of the eight programs was more effective than any other or that some were more similar in effect to each other than to any other.* + +The statistical model for the SAT scores in each of the $J=8$ schools $y_j$ is: + +$$ +\begin{aligned} +\mu & \sim \mathcal{N}(0, 5), \\ +\tau & \sim \text{HalfCauchy}(5),\\ +\theta_j & \sim \mathcal{N}(\mu, \tau),~ j = 1,\dots,J, \\ +y_j & \sim \mathcal{N}(\theta_j,\sigma_j),~ j = 1,\dots,J. +\end{aligned} +$$ + +Where the the SAT standard deviations per high school $\sigma_j$ are treated as known along with the scores. + +Gelman *et al* use the eight schools example to illustrate partial pooling, and to demonstrate the importance of choosing the variance priors carefully. + +We start by setting the dependencies and the basic data and random seed. + +```{python} +# Dependencies +import arviz as az +import numpyro +import jax.numpy as jnp +import numpyro.distributions as dist +from jax import random +from numpyro.infer import NUTS + +import forecasttools.sbc as sbc + +``` + +```{python} +rng_key = random.PRNGKey(0) +J = 8 +y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) +sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) +``` + +### Running SBC for example with HalfCauchy priors + +Next we define the model above as a `numpyro` model. + +```{python} +def eight_schools_cauchy_prior(J, sigma, y=None): + mu = numpyro.sample('mu', dist.Normal(0, 5)) + tau = numpyro.sample('tau', dist.HalfCauchy(5)) + with numpyro.plate('J', J): + theta = numpyro.sample('theta', dist.Normal(mu, tau)) + numpyro.sample('obs', dist.Normal(theta, sigma), obs=y) + +nuts_kernel_cauchy_prior = NUTS(eight_schools_cauchy_prior) +``` + +Note that we define a `numpyro` `MCMCKernel` object to pass to `SBC`. This wraps the *model* with the *sampling approach*, which are both checked by SBC. + +Next, we create an instance of the `SBC` class. There are a few things to note here: + +- The structure here is that we must pass the `MCMCKernel` object as well as the positional and keyword arguments that define a specific model, for example the `J` and `sigma` arguments. This follows a standard pattern in `numpyro`. +- Note that we don't pass the data argument `y` (and indeed this will throw an error). The reason is that we will be generating datasets for SBC rather than using any particular observed dataset. +- `observed_vars` is an important argument which lets `SBC` know what variables to treat as observed in the SBC process. It is a dictionary rather than a list because it maps the data name *as it is passed to the* `numpyro` *model* to the random variable name inside the `numpyro` model (which can be distinct). For example, in `eight_schools_cauchy_prior` the data value of the observed SAT scores is called `y` but the random sample is traced as `obs` . Hence, we must pass `observed_vars = dict(y = "obs")` . +- `sample_kwargs` gets passed through to the sampler. Note that `progress_bar = False` is important if we don't want to see alot of progress bars. +- `num_simulations` sets the number of SBC trials to run. + +```{python} +seed1, seed2 = random.split(rng_key) +S = sbc.SBC(nuts_kernel_cauchy_prior, J, sigma, + observed_vars = dict(y = "obs"), + sample_kwargs=dict(num_warmup=500, num_samples=1000, progress_bar = False), + num_simulations=100, + seed = seed1) +``` + +`SBC` class instances have methods for running the SBC simulations and plotting the results as a histogram (against binomial sampling within a bin under the uniform target distribution) and as an empirical CDF function in PIT model (this is from the `arviz` plotting utilities). + +```{python} +S.run_simulations() +``` + +```{python} +#| output: true +#| fig-cap: "Histogram plot of SBC results" +S.plot_results(kind = "hist") +``` + +```{python} +#| output: true +#| fig-cap: "ECDF/PIT plot of SBC results" +S.plot_results() +``` + +We can see that the SBC has identified a problem with the model. The Bayesian inference looks to be slightly, but systematically, over-estimating the inter-school variation parameter $\tau$ compared to simulated values. + +Lets use a more informative prior from [this implementation](https://www.tensorflow.org/probability/examples/Eight_Schools). + +$$ +\tau \sim \text{LogNormal}(5, 1). +$$ + +```{python} +def eight_schools_lognormal_prior(J, sigma, y=None): + mu = numpyro.sample('mu', dist.Normal(0, 5)) + tau = numpyro.sample('tau', dist.LogNormal(5, 1)) + with numpyro.plate('J', J): + theta = numpyro.sample('theta', dist.Normal(mu, tau)) + numpyro.sample('obs', dist.Normal(theta, sigma), obs=y) + +nuts_kernel_lognormal_prior = NUTS(eight_schools_lognormal_prior) + +S2 = sbc.SBC(nuts_kernel_lognormal_prior, J, sigma, + observed_vars = dict(y = "obs"), + sample_kwargs=dict(num_warmup=500, num_samples=1000, progress_bar = False), + num_simulations=100, + seed = seed2) +S2.run_simulations() +``` + +```{python} +#| output: true +#| fig-cap: "Histogram plot of SBC results" +S2.plot_results(kind = "hist") +``` + +```{python} +#| output: true +#| fig-cap: "ECDF/PIT plot of SBC results" +S2.plot_results() +``` + +The results of this SBC approach still indicate potential problems with self-consistent identification of every high school specific treatment effect but the self-consistent ability to infer the population parameters $\mu$ and $\tau$ seems to now work. + +## References diff --git a/pyproject.toml b/pyproject.toml index d7e4f24..dfcab5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,8 @@ matplotlib = "^3.9.2" epiweeks = "^2.3.0" metaflow = "^2.13.9" numpyro = "^0.17.0" +numpy = "^2.2.2" +tqdm = "^4.67.1" [tool.poetry.group.dev.dependencies] @@ -45,6 +47,7 @@ nbclient = "^0.10.0" jupyter = "^1.1.1" pandas = "^2.2.3" metaflow = "^2.13.9" +jupyter-cache = "^1.0.1" [tool.poetry.group.test.dependencies] diff --git a/tests/test_sbc.py b/tests/test_sbc.py new file mode 100644 index 0000000..d5831ed --- /dev/null +++ b/tests/test_sbc.py @@ -0,0 +1,125 @@ +""" +Test the SBC class using a simple model. + +```math +\begin{aligned} +\\mu &\\sim \text{Normal}(0, 1), \\ +z &\\sim \text{Normal}(\\mu, 1). +\\end{aligned} +``` +""" + +import numpyro +import pytest +from jax import random +from numpyro.infer import NUTS + +from forecasttools.sbc import SBC + + +@pytest.fixture +def simple_model(): + def model(y=None): + mu = numpyro.sample("mu", numpyro.distributions.Normal(0, 1)) + numpyro.sample("z", numpyro.distributions.Normal(mu, 1), obs=y) + + return model + + +@pytest.fixture +def mcmc_kernel(simple_model): + return NUTS(simple_model) + + +@pytest.fixture +def observed_vars(): + return {"y": "z"} + + +@pytest.fixture +def sbc_instance(mcmc_kernel, observed_vars): + return SBC(mcmc_kernel, y=None, observed_vars=observed_vars) + + +@pytest.fixture +def sbc_instance_inspection_on(mcmc_kernel, observed_vars): + return SBC( + mcmc_kernel, y=None, observed_vars=observed_vars, inspection_mode=True + ) + + +def test_sbc_initialization(sbc_instance, mcmc_kernel, observed_vars): + """ + Test that the SBC class is initialized correctly. + """ + assert sbc_instance.mcmc_kernel == mcmc_kernel + assert sbc_instance.observed_vars == observed_vars + assert sbc_instance.num_simulations == 10 + assert sbc_instance.sample_kwargs == dict( + num_warmup=500, num_samples=100, progress_bar=False + ) + assert sbc_instance._simulations_complete == 0 + + +def test_get_prior_predictive_samples(sbc_instance): + """ + Test that the prior and prior predictive samples are generated correctly. + """ + prior, prior_pred = sbc_instance._get_prior_predictive_samples() + assert "y" in prior_pred + assert "mu" in prior + + +def test_get_posterior_samples(sbc_instance): + """ + Test that the posterior samples are generated correctly. + """ + prior, prior_pred = sbc_instance._get_prior_predictive_samples() + seed = random.PRNGKey(0) + idata = sbc_instance._get_posterior_samples(seed, prior_pred) + assert "posterior" in idata + + +def test_increment_rank_statistics(sbc_instance): + """ + Test that the rank statistics are incremented correctly. + """ + prior, prior_pred = sbc_instance._get_prior_predictive_samples() + seed = random.PRNGKey(0) + idata = sbc_instance._get_posterior_samples(seed, prior_pred) + sbc_instance.simulations = {name: [] for name in prior} + prior_draw = {k: v[0] for k, v in prior.items()} + sbc_instance._increment_rank_statistics(prior_draw, idata["posterior"]) + + for name in prior_draw: + assert name in sbc_instance.simulations + assert len(sbc_instance.simulations[name]) == 1 + + +def test_run_simulations(sbc_instance): + """ + Test that the simulations for SBC are run correctly. + """ + sbc_instance.run_simulations() + assert sbc_instance._simulations_complete == sbc_instance.num_simulations + assert "mu" in sbc_instance.simulations + + +def test_run_simulations_with_inspection(sbc_instance_inspection_on): + """ + Test that the simulations for SBC are run correctly. + """ + sbc_instance_inspection_on.run_simulations() + assert isinstance(sbc_instance_inspection_on.idatas, list) + assert isinstance(sbc_instance_inspection_on.prior, dict) + assert isinstance(sbc_instance_inspection_on.prior_pred, dict) + + +def test_plot_results(sbc_instance): + """ + Test that the results are plotted. + """ + sbc_instance.run_simulations() + fig, axes = sbc_instance.plot_results() + assert fig is not None + assert axes is not None