Skip to content

Commit

Permalink
add inspection mode
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Feb 14, 2025
1 parent 9efc9bb commit 250e49d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
13 changes: 12 additions & 1 deletion forecasttools/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
num_simulations=10,
sample_kwargs=None,
seed=None,
inspection_mode=False,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -71,13 +72,18 @@ def __init__(

self.num_simulations = num_simulations
self.sample_kwargs = sample_kwargs

# Initialize the simulations and random seeds
self.simulations = {}
self._simulations_complete = 0
prior_pred_rng, sampler_rng = random.split(seed)
self._prior_pred_rng = prior_pred_rng
self._sampler_rng = sampler_rng
self.num_samples = None
# Set the inspection mode
# if in inspection mode, store all idata objects from fitting
self.inspection_mode = inspection_mode
if inspection_mode:
self.idatas = []

def _get_prior_predictive_samples(
self,
Expand Down Expand Up @@ -185,6 +191,9 @@ def run_simulations(self) -> None:
initial=self._simulations_complete,
total=self.num_simulations,
)
if self.inspection_mode:
self.prior = prior
self.prior_pred = prior_pred
try:
while self._simulations_complete < self.num_simulations:
idx = self._simulations_complete
Expand All @@ -195,6 +204,8 @@ def run_simulations(self) -> None:
idata = self._get_posterior_samples(
sampler_seeds[idx], prior_predictive_draw
)
if self.inspection_mode:
self.idatas.append(idata)
self._increment_rank_statistics(prior_draw, idata["posterior"])
self._simulations_complete += 1
progress.update()
Expand Down
17 changes: 17 additions & 0 deletions tests/test_sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def sbc_instance(mcmc_kernel, observed_vars):
return SBC(mcmc_kernel, y=None, observed_vars=observed_vars)


@pytest.fixture
def sbc_instance_inspection_on(mcmc_kernel, observed_vars):
return SBC(
mcmc_kernel, y=None, observed_vars=observed_vars, inspection_mode=True
)


def test_sbc_initialization(sbc_instance, mcmc_kernel, observed_vars):
"""
Test that the SBC class is initialized correctly.
Expand Down Expand Up @@ -98,6 +105,16 @@ def test_run_simulations(sbc_instance):
assert "mu" in sbc_instance.simulations


def test_run_simulations_with_inspection(sbc_instance_inspection_on):
"""
Test that the simulations for SBC are run correctly.
"""
sbc_instance_inspection_on.run_simulations()
assert isinstance(sbc_instance_inspection_on.idatas, list)
assert isinstance(sbc_instance_inspection_on.prior, dict)
assert isinstance(sbc_instance_inspection_on.prior_pred, dict)


def test_plot_results(sbc_instance):
"""
Test that the results are plotted.
Expand Down

0 comments on commit 250e49d

Please sign in to comment.