Skip to content

Commit

Permalink
pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AFg6K7h4fhy2 committed Feb 11, 2025
1 parent 972c32e commit fb7b16d
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 65 deletions.
110 changes: 78 additions & 32 deletions forecasttools/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def __init__(
*args,
observed_vars: dict[str, str],
num_simulations=10,
sample_kwargs=dict(num_warmup=500, num_samples=100, progress_bar = False),
seed=random.PRNGKey(1234),
sample_kwargs=None,
seed=1234,
**kwargs,
):
"""Set up class for doing SBC.
Expand All @@ -28,23 +28,33 @@ def __init__(
mcmc_kernel : numpyro.infer.mcmc.MCMCKernel
An instance of a numpyro MCMC kernel object.
observed_vars : dict[str, str]
A dictionary mapping observed/response variable name as a kwarg to the
numpyro model to the corresponding variable name sampled using `numpyro.sample`.
A dictionary mapping observed/response variable name as a kwarg to
the numpyro model to the corresponding variable name sampled using
`numpyro.sample`.
args : tuple
Positional arguments passed to the `numpyro` model.
num_simulations : int
How many simulations to run for SBC.
sample_kwargs : dict[str] -> Any
Arguments passed to `numpyro.sample`. Defaults to `dict(num_warmup=500, num_samples=100, progress_bar = False)`.
Keyword arguments passed to the `MCMC` constructor. Defaults to
`dict(num_warmup=500, num_samples=100, progress_bar=False)`.
Which assumes a MCMC sampler e.g. NUTS.
seed : random.PRNGKey
seed : int
Random seed.
kwargs : dict
Keyword arguments passed to `numpyro` models.
"""
if sample_kwargs is None:
sample_kwargs = dict(
num_warmup=500, num_samples=100, progress_bar=False
)
seed = random.PRNGKey(seed)

self.mcmc_kernel = mcmc_kernel
if not hasattr(mcmc_kernel, 'model'):
raise ValueError("The `mcmc_kernel` must have a 'model' attribute.")
if not hasattr(mcmc_kernel, "model"):
raise ValueError(
"The `mcmc_kernel` must have a 'model' attribute."
)

self.model = mcmc_kernel.model
self.args = args
Expand All @@ -53,7 +63,10 @@ def __init__(

for key in self.observed_vars:
if key in self.kwargs and self.kwargs[key] is not None:
raise ValueError(f"The value for '{key}' in kwargs must be None for this to be a prior predictive check.")
raise ValueError(
f"The value for '{key}' in kwargs must be None for this to"
" be a prior predictive check."
)

self.num_simulations = num_simulations
self.sample_kwargs = sample_kwargs
Expand All @@ -67,20 +80,32 @@ def __init__(

def _get_prior_predictive_samples(self):
    """
    Generate samples to use for the simulations by prior predictive
    sampling, then split them between observed and unobserved variables
    based on the `observed_vars` attribute.

    Returns
    -------
    prior : dict
        Every sampled site whose name is not one of the values of
        `observed_vars`, keyed by site name.
    prior_pred : dict
        The sampled sites named by the values of `observed_vars`, re-keyed
        by the corresponding model keyword-argument name (the keys of
        `observed_vars`).
    """
    prior_predictive_fn = numpyro.infer.Predictive(
        self.mcmc_kernel.model, num_samples=self.num_simulations
    )
    prior_predictions = prior_predictive_fn(
        self._prior_pred_rng, *self.args, **self.kwargs
    )
    # Build the set of observed site names once instead of re-evaluating
    # `.values()` for every item of the predictions.
    observed_sites = set(self.observed_vars.values())
    prior_pred = {
        k: prior_predictions[v] for k, v in self.observed_vars.items()
    }
    prior = {
        k: v for k, v in prior_predictions.items() if k not in observed_sites
    }
    return prior, prior_pred

def _get_posterior_samples(self, seed, prior_predictive_draw):
"""
Generate posterior samples conditioned to a prior predictive sample.
This returns the posterior samples and the number of samples. The number of samples are used
in scaling plotting and checking that each inference draw has the same number of samples.
This returns the posterior samples and the number of samples. The
number of samples is used in scaling plotting and checking that each
inference draw has the same number of samples.
"""
mcmc = MCMC(self.mcmc_kernel, **self.sample_kwargs)
obs_vars = {**self.kwargs, **prior_predictive_draw}
Expand All @@ -91,42 +116,55 @@ def _get_posterior_samples(self, seed, prior_predictive_draw):

def run_simulations(self):
    """
    The main method of `SBC` class that runs the simulations for
    simulation based calibration and fills the `simulations` attribute
    with the results.

    For each prior predictive draw the model is fit to the simulated
    observations. For every unobserved variable, the number of posterior
    draws from chain 0 that are smaller than the prior draw is appended to
    `self.simulations[name]`.

    Raises
    ------
    ValueError
        If the number of posterior samples is not the same for every
        simulation.
    """
    prior, prior_pred = self._get_prior_predictive_samples()
    sampler_seeds = random.split(self._sampler_rng, self.num_simulations)
    # NOTE(review): `simulations` is re-initialised on every call while
    # `_simulations_complete` is not reset here, so calling this method
    # again after an interruption appears to drop the results gathered
    # earlier -- confirm whether resuming is meant to be supported.
    self.simulations = {name: [] for name in prior}
    progress = tqdm(
        initial=self._simulations_complete,
        total=self.num_simulations,
    )
    try:
        while self._simulations_complete < self.num_simulations:
            idx = self._simulations_complete
            prior_draw = {k: v[idx] for k, v in prior.items()}
            prior_predictive_draw = {
                k: v[idx] for k, v in prior_pred.items()
            }
            idata, num_samples = self._get_posterior_samples(
                sampler_seeds[idx], prior_predictive_draw
            )
            # The first simulation fixes the expected number of posterior
            # samples; every later simulation must match it.
            if self.num_samples is None:
                self.num_samples = num_samples
            if self.num_samples != num_samples:
                raise ValueError(
                    "The number of samples from the posterior is not"
                    " consistent."
                )
            posterior = idata["posterior"]
            for name in prior:
                num_dims = jnp.ndim(prior_draw[name])
                if num_dims == 0:
                    # Scalar variable: a single count over all draws.
                    self.simulations[name].append(
                        (posterior[name].sel(chain=0) < prior_draw[name])
                        .sum()
                        .values
                    )
                else:
                    # Array variable: count along the first axis only, so
                    # one count is kept per element.
                    self.simulations[name].append(
                        (posterior[name].sel(chain=0) < prior_draw[name])
                        .sum(axis=0)
                        .values
                    )
            self._simulations_complete += 1
            progress.update()
    finally:
        # Runs on failure or interruption too: keep at most the completed
        # simulations and always release the progress bar.
        self.simulations = {
            k: v[: self._simulations_complete]
            for k, v in self.simulations.items()
        }
        progress.close()

Expand All @@ -143,11 +181,13 @@ def plot_results(self, kind="ecdf", var_names=None, color="C0"):
simulations
The SBC.simulations dictionary.
kind : str
What kind of plot to make. Supported values are 'ecdf' (default) and 'hist'
What kind of plot to make. Supported values are 'ecdf' (default)
and 'hist'
var_names : list[str]
Variables to plot (defaults to all)
figsize : tuple
Figure size for the plot. If None, it will be defined automatically.
Figure size for the plot. If None, it will be defined
automatically.
color : str
Color to use for the eCDF or histogram
Expand All @@ -156,4 +196,10 @@ def plot_results(self, kind="ecdf", var_names=None, color="C0"):
fig, axes
matplotlib figure and axes
"""
return plot_results(self.simulations, self.num_samples, kind=kind, var_names=var_names, color=color)
return plot_results(
self.simulations,
self.num_samples,
kind=kind,
var_names=var_names,
color=color,
)
71 changes: 46 additions & 25 deletions forecasttools/sbc_plots.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Plots for the simulation based calibration"""

import itertools

import arviz as az
Expand All @@ -7,7 +8,9 @@
from scipy.special import bdtrik


def plot_results(simulations, ndraws, kind="ecdf", var_names=None, figsize=None, color="C0"):
def plot_results(
simulations, ndraws, kind="ecdf", var_names=None, figsize=None, color="C0"
):
"""Visual diagnostic for SBC.
Currently it supports two options: `ecdf` for the empirical CDF plots
Expand All @@ -22,7 +25,8 @@ def plot_results(simulations, ndraws, kind="ecdf", var_names=None, figsize=None,
ndraws : int
Number of draws in each posterior predictive sample
kind : str
What kind of plot to make. Supported values are 'ecdf' (default) and 'hist'
What kind of plot to make. Supported values are 'ecdf' (default)
and 'hist'
var_names : list[str]
Variables to plot (defaults to all)
figsize : tuple
Expand Down Expand Up @@ -53,13 +57,15 @@ def plot_results(simulations, ndraws, kind="ecdf", var_names=None, figsize=None,

if n_plots > 1:
if figsize is None:
figsize=(8, n_plots*1.0)
figsize = (8, n_plots * 1.0)

fig, axes = plt.subplots(nrows=(n_plots + 1) // 2, ncols=2, figsize=figsize, sharex=True)
fig, axes = plt.subplots(
nrows=(n_plots + 1) // 2, ncols=2, figsize=figsize, sharex=True
)
axes = axes.flatten()
else:
if figsize is None:
figsize=(8, 1.5)
figsize = (8, 1.5)

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=figsize)
axes = [axes]
Expand All @@ -69,28 +75,28 @@ def plot_results(simulations, ndraws, kind="ecdf", var_names=None, figsize=None,

idx = 0
for var_name, var_data in sims.items():
plot_idxs = list(itertools.product(*(np.arange(s) for s in var_data.shape[1:])))
if len(plot_idxs) > 1:
has_dims = True
else:
has_dims = False
plot_idxs = list(
itertools.product(*(np.arange(s) for s in var_data.shape[1:]))
)

for indices in plot_idxs:
if has_dims:
dim_label = f'{var_name}[{"][".join(map(str, indices))}]'
if len(plot_idxs) > 1: # has dims
dim_label = f"{var_name}[{']['.join(map(str, indices))}]"
else:
dim_label = var_name
ax = axes[idx]
ary = var_data[(...,) + indices]
if kind == "ecdf":
az.plot_ecdf(ary,
cdf=cdf,
difference = True,
pit = True,
confidence_bands = "auto",
plot_kwargs={"color":color},
fill_kwargs={"color":color},
ax=ax)
az.plot_ecdf(
ary,
cdf=cdf,
difference=True,
pit=True,
confidence_bands="auto",
plot_kwargs={"color": color},
fill_kwargs={"color": color},
ax=ax,
)
else:
hist(ary, color=color, ax=ax)
ax.set_title(dim_label)
Expand All @@ -102,22 +108,37 @@ def plot_results(simulations, ndraws, kind="ecdf", var_names=None, figsize=None,

return fig, axes


def hist(ary, color, ax):
    """Draw a histogram of `ary` on `ax` with a binomial reference band.

    The bar heights are the counts from `np.histogram(ary, bins="auto")`.
    A dashed horizontal line marks the 0.5 quantile, and a shaded band spans
    the 0.025 to 0.975 quantiles, of the count expected in one bin when each
    of the `len(ary)` values falls into any bin with equal probability.

    Parameters
    ----------
    ary : array-like
        Values to histogram.
    color : str
        Color of the histogram bars.
    ax : matplotlib axes
        Axes to draw on.
    """
    # `counts` rather than `hist` so the local does not shadow this function.
    counts, bins = np.histogram(ary, bins="auto")
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    max_rank = np.ceil(bins[-1])
    len_bins = len(bins)
    # `bins` holds the edges, so there is one fewer bin than edges; the
    # per-bin probability must use the number of bins, not of edges.
    n_bins = len(counts)
    n_sims = len(ary)

    band = np.ceil(bdtrik([0.025, 0.5, 0.975], n_sims, 1 / n_bins))
    ax.bar(
        bin_centers,
        counts,
        width=bins[1] - bins[0],
        color=color,
        edgecolor="black",
    )
    ax.axhline(band[1], color="0.5", ls="--")
    ax.fill_between(
        np.linspace(0, max_rank, len_bins),
        band[0],
        band[2],
        color="0.5",
        alpha=0.5,
    )


class UniformCDF:
    """Callable CDF of a uniform distribution on ``[0, upper_bound]``."""

    def __init__(self, upper_bound):
        # Right end of the support; used as the divisor in `__call__`.
        self.upper_bound = upper_bound

    def __call__(self, x):
        """Evaluate the CDF at `x`.

        Returns 0 where ``x < 0``, 1 where ``x > upper_bound`` and
        ``x / upper_bound`` elsewhere.
        """
        # Coerce so that plain lists and scalars are accepted as well as
        # numpy arrays.
        x = np.asarray(x)
        return np.where(
            x < 0, 0, np.where(x > self.upper_bound, 1, x / self.upper_bound)
        )
Loading

0 comments on commit fb7b16d

Please sign in to comment.