Skip to content

Commit

Permalink
Merge pull request #7 from sede-open/thining_results
Browse files Browse the repository at this point in the history
Added ability to thin MCMC results.
  • Loading branch information
mattj89 authored Jun 19, 2024
2 parents 2f2ca16 + fe1a2be commit 523a5f4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
21 changes: 13 additions & 8 deletions src/openmcmc/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@ class MCMC:
"""Class for running Markov Chain Monte Carlo on a Model object to do parameter inference.
Args:
state (dict): initial state of sampler any parameters not
specified will be sampler from prior distributions
samplers (list): list of the samplers to be used for each parameter to be estimated
n_burn (int, optional): number of initial burn in these iterations are not stored, default 5000
state (dict): initial state of sampler. Any parameters not specified will be sampled from prior distributions.
samplers (list): list of the samplers to be used for each parameter to be estimated.
n_burn (int, optional): number of initial burn-in iterations; these iterations are not stored, default 5000.
n_iter (int, optional): number of iterations which are stored in store, default 5000
n_thin (int, optional): number of iterations to thin by, default 1.
Attributes:
state (dict): initial state of sampler. Any parameters not
specified will be sampled from prior distributions.
samplers (list): list of the samplers to be used for each parameter to be estimated.
n_burn (int): number of initial burn-in iterations; these iterations are not stored.
n_iter (int): number of iterations which are stored in store.
n_thin (int): number of iterations to thin by.
store (dict): dictionary storing MCMC output as np.array for each inference parameter.
"""
Expand All @@ -42,10 +43,11 @@ class MCMC:
model: Model
n_burn: int = 5000
n_iter: int = 5000
n_thin: int = 1
store: dict = field(default_factory=dict, init=False)

def __post_init__(self):
"""Convert any state values to at least 2D np.arrays and sample any missing states from the prior distributions, and set up storage arrays for the sampled values.
"""Convert any state values to at least 2D np.arrays and sample any missing states from the prior distributions and set up storage arrays for the sampled values.
Ensures that all elements of the initial state are in an appropriate format for running
the sampler:
Expand Down Expand Up @@ -83,16 +85,19 @@ def __post_init__(self):
self.store["log_post"] = np.full(shape=(self.n_iter, 1), fill_value=np.nan)

def run_mcmc(self):
"""Runs MCMC routine for model specification loops for n_iter+ n_burn iterations sampling the state for each parameter and updating the parameter state.
"""Runs MCMC routine for the given model specification.
Numbers the iterations of the sampler from -self.n_burn to self.n_iter, storing every self.n_thin samples.
Runs a first loop over samplers, and generates a sample for all corresponding variables in the state. Then
stores the value of each of the sampled parameters in the self.store dictionary, as well as the data fitted
values and the log-posterior value.
"""
for i_it in tqdm(range(-self.n_burn, self.n_iter)):
for sampler in self.samplers:
self.state = sampler.sample(self.state)
for _ in range(self.n_thin):
for sampler in self.samplers:
self.state = sampler.sample(self.state)

if i_it < 0:
continue
Expand Down
44 changes: 28 additions & 16 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,31 @@ def fix_state(request):


@pytest.fixture(
params=[(0, 4000), (2000, 4000), (0, 6000), (2000, 6000)],
ids=["n_burn=0,n_iter=4000", "n_burn=non-zero,n_iter=4000", "n_burn=0,n_iter=6000", "n_burn=non-zero,n_iter=6000"],
name="nburn_niter",
params=[(0, 4000, 1), (2000, 4000, 5), (0, 6000, 10), (2000, 6000, 1)],
ids=[
"n_burn=0,n_iter=4000, n_thin=1",
"n_burn=non-zero,n_iter=4000,n_thin=5",
"n_burn=0,n_iter=6000, n_thin=10",
"n_burn=non-zero,n_iter=6000,n_thin=1",
],
name="mcmc_settings",
)
def fix_nburn_niter(request):
def fix_mcmc_settings(request):
"""Define the MCMC settings used by the tests."""
[n_burn, n_iter] = request.param
nburn_niter = {"nburn": n_burn, "niter": n_iter}
[n_burn, n_iter, n_thin] = request.param
fix_mcmc_settings = {"nburn": n_burn, "niter": n_iter, "nthin": n_thin}

return nburn_niter
return fix_mcmc_settings


def test_run_mcmc(state: dict, sampler: list, model: Model, nburn_niter: dict, monkeypatch):
def test_run_mcmc(state: dict, sampler: list, model: Model, mcmc_settings: dict, monkeypatch):
"""Test run_mcmc function. Checks size is correct for the output parameters of the function (state and store) based
on the number of iterations (n_iter) and number of burn-in iterations (n_burn).
Args:
state: dictionary
model: Model input
nburn_niter: dictionary of mcmc settings
mcmc_settings: dictionary of mcmc settings
monkeypatch: object for avoiding computationally expensive mcmc sampler.
"""
Expand All @@ -105,32 +110,39 @@ def mock_log_p(self, current_state):
monkeypatch.setattr(NormalGamma, "store", mock_store)
monkeypatch.setattr(Model, "log_p", mock_log_p)

M = MCMC(state, sampler, model, n_burn=nburn_niter["nburn"], n_iter=nburn_niter["niter"])
M = MCMC(
state,
sampler,
model,
n_burn=mcmc_settings["nburn"],
n_iter=mcmc_settings["niter"],
n_thin=mcmc_settings["nthin"],
)
M.store["count"] = 0
M.run_mcmc()
assert M.state["count"] == (M.n_iter + M.n_burn) * len(sampler)
assert M.state["count"] == (M.n_iter + M.n_burn) * len(sampler) * M.n_thin
assert M.store["count"] == M.n_iter * len(sampler)


def test_post_init(state: dict, sampler: list, model: Model, nburn_niter: dict):
def test_post_init(state: dict, sampler: list, model: Model, mcmc_settings: dict):
"""This function tests the __post_init__ function to check that the returned store and state parameters are np.array of the
dimension n * 1.
Args:
state: dictionary
nburn_niter: integer
mcmc_settings: dictionary of mcmc settings
model:
"""
M = MCMC(state, sampler, model, n_iter=nburn_niter["niter"])
M = MCMC(state, sampler, model, n_iter=mcmc_settings["niter"])

assert isinstance(M.state["count"], np.ndarray)
assert M.state["count"].ndim == 2

assert isinstance(M.state["beta"], np.ndarray)
assert M.state["beta"].ndim == 2
assert (len(M.store) - 1) * (M.store["beta"]).shape[1] == len(sampler) * nburn_niter["niter"]
assert M.store["log_post"].size == nburn_niter["niter"]
assert (len(M.store) - 1) * (M.store["beta"]).shape[1] == len(sampler) * mcmc_settings["niter"]
assert M.store["log_post"].size == mcmc_settings["niter"]

if len(sampler) > 1:
assert isinstance(M.state["tau"], np.ndarray)
Expand Down

0 comments on commit 523a5f4

Please sign in to comment.