Skip to content

Commit

Permalink
renaming the fixture function for iteration. burn and thinning
Browse files Browse the repository at this point in the history
  • Loading branch information
TannazH committed Apr 30, 2024
1 parent fca8594 commit 9e8b9ad
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,24 @@ def fix_state(request):
"n_burn=non-zero,n_iter=4000,n_thin=5",
"n_burn=0,n_iter=6000, n_thin=10",
"n_burn=non-zero,n_iter=6000,n_thin=1"],
name="nburn_niter",
name="mcmc_settings",
)
def fix_nburn_niter_nthin(request):
def fix_mcmc_settings(request):
"""Define the initial state for the MCMC."""
[n_burn, n_iter, n_thin] = request.param
fix_nburn_niter_nthin = {"nburn": n_burn, "niter": n_iter, "nthin": n_thin}
fix_mcmc_settings = {"nburn": n_burn, "niter": n_iter, "nthin": n_thin}

return fix_nburn_niter_nthin
return fix_mcmc_settings


def test_run_mcmc(state: dict, sampler: list, model: Model, nburn_niter: dict, monkeypatch):
def test_run_mcmc(state: dict, sampler: list, model: Model, mcmc_settings: dict, monkeypatch):
"""Test run_mcmc function Checks size is correct for the output parameters of the function (state and store) based
on the number of iterations (n_iter) and number of burn (n_burn), i.e.,
Args:
state: dictionary
model: Model input
nburn_niter: dictionary of mcmc settings
mcmc_settings: dictionary of mcmc settings
monkeypatch object for avoiding computationally expensive mcmc sampler.
"""
Expand All @@ -109,34 +109,34 @@ def mock_log_p(self, current_state):
monkeypatch.setattr(Model, "log_p", mock_log_p)

M = MCMC(state, sampler, model,
n_burn=nburn_niter["nburn"],
n_iter=nburn_niter["niter"],
n_thin=nburn_niter["nthin"])
n_burn=mcmc_settings["nburn"],
n_iter=mcmc_settings["niter"],
n_thin=mcmc_settings["nthin"])
M.store["count"] = 0
M.run_mcmc()
assert M.state["count"] == (M.n_iter + M.n_burn) * len(sampler) * M.n_thin
assert M.store["count"] == M.n_iter * len(sampler)


def test_post_init(state: dict, sampler: list, model: Model, nburn_niter: dict):
def test_post_init(state: dict, sampler: list, model: Model, mcmc_settings: dict):
"""This function test __pos__init function to check returned store and state parameters are np.array of the
dimension n * 1
Args:
state: dictionary
nburn_niter: integer
mcmc_settings: integer
model:
"""
M = MCMC(state, sampler, model, n_iter=nburn_niter["niter"])
M = MCMC(state, sampler, model, n_iter=mcmc_settings["niter"])

assert isinstance(M.state["count"], np.ndarray)
assert M.state["count"].ndim == 2

assert isinstance(M.state["beta"], np.ndarray)
assert M.state["beta"].ndim == 2
assert (len(M.store) - 1) * (M.store["beta"]).shape[1] == len(sampler) * nburn_niter["niter"]
assert M.store["log_post"].size == nburn_niter["niter"]
assert (len(M.store) - 1) * (M.store["beta"]).shape[1] == len(sampler) * mcmc_settings["niter"]
assert M.store["log_post"].size == mcmc_settings["niter"]

if len(sampler) > 1:
assert isinstance(M.state["tau"], np.ndarray)
Expand Down

0 comments on commit 9e8b9ad

Please sign in to comment.