Skip to content

Commit

Permalink
type hints for _get_posterior_samples
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Feb 12, 2025
1 parent 6da1696 commit d266670
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion forecasttools/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def _get_prior_predictive_samples(
Generate samples to use for the simulations by prior predictive
sampling. Then splits between observed and unobserved variables based
on the `observed_vars` attribute.
Returns
-------
tuple[dict[str, any], dict[str, any]]
The prior and prior predictive samples.
"""
prior_predictive_fn = numpyro.infer.Predictive(
self.mcmc_kernel.model, num_samples=self.num_simulations
Expand All @@ -103,12 +108,27 @@ def _get_prior_predictive_samples(
}
return prior, prior_pred

def _get_posterior_samples(self, seed, prior_predictive_draw):
def _get_posterior_samples(
self, seed: random.PRNGKey, prior_predictive_draw: dict[str, any]
) -> tuple[az.InferenceData, int]:
"""
Generate posterior samples conditioned to a prior predictive sample.
This returns the posterior samples and the number of samples. The
number of samples are used in scaling plotting and checking that each
inference draw has the same number of samples.
Parameters
----------
seed : random.PRNGKey
Random seed for MCMC sampling.
prior_predictive_draw : dict[str, any]
Prior predictive samples.
Returns
-------
tuple[az.InferenceData, int]
Posterior samples as an arviz InferenceData object, with the count
of posterior samples.
"""
mcmc = MCMC(self.mcmc_kernel, **self.sample_kwargs)
obs_vars = {**self.kwargs, **prior_predictive_draw}
Expand Down

0 comments on commit d266670

Please sign in to comment.