diff --git a/src/openmcmc/mcmc.py b/src/openmcmc/mcmc.py index d77cc57..445ee2b 100644 --- a/src/openmcmc/mcmc.py +++ b/src/openmcmc/mcmc.py @@ -21,11 +21,11 @@ class MCMC: """Class for running Markov Chain Monte Carlo on a Model object to do parameter inference. Args: - state (dict): initial state of sampler any parameters not - specified will be sampler from prior distributions - samplers (list): list of the samplers to be used for each parameter to be estimated - n_burn (int, optional): number of initial burn in these iterations are not stored, default 5000 + state (dict): initial state of sampler. Any parameters not specified will be sampled from prior distributions. + samplers (list): list of the samplers to be used for each parameter to be estimated. + n_burn (int, optional): number of initial burn in these iterations are not stored, default 5000. n_iter (int, optional): number of iterations which are stored in store, default 5000 + n_thin (int, optional): number of iterations to thin by, default 1. Attributes: state (dict): initial state of sampler any parameters not @@ -33,6 +33,7 @@ class MCMC: samplers (list): list of the samplers to be used for each parameter to be estimated. n_burn (int): number of initial burn in these iterations are not stored. n_iter (int): number of iterations which are stored in store. + n_thin (int): number of iterations to thin by. store (dict): dictionary storing MCMC output as np.array for each inference parameter. """ @@ -42,10 +43,11 @@ class MCMC: model: Model n_burn: int = 5000 n_iter: int = 5000 + n_thin: int = 1 store: dict = field(default_factory=dict, init=False) def __post_init__(self): - """Convert any state values to at least 2D np.arrays and sample any missing states from the prior distributions, and set up storage arrays for the sampled values. + """Convert any state values to at least 2D np.arrays and sample any missing states from the prior distributions and set up storage arrays for the sampled values. Ensures that all elements of the initial state are in an appropriate format for running the sampler: @@ -83,7 +85,9 @@ def __post_init__(self): self.store["log_post"] = np.full(shape=(self.n_iter, 1), fill_value=np.nan) def run_mcmc(self): - """Runs MCMC routine for model specification loops for n_iter+ n_burn iterations sampling the state for each parameter and updating the parameter state. + """Runs MCMC routine for the given model specification. + + Numbers the iteratins of the sampler from -self.n_burn to self.n_iter, sotring every self.n_thin samples. Runs a first loop over samplers, and generates a sample for all corresponding variables in the state. Then stores the value of each of the sampled parameters in the self.store dictionary, as well as the data fitted @@ -91,8 +95,9 @@ def run_mcmc(self): """ for i_it in tqdm(range(-self.n_burn, self.n_iter)): - for sampler in self.samplers: - self.state = sampler.sample(self.state) + for _ in range(self.n_thin): + for sampler in self.samplers: + self.state = sampler.sample(self.state) if i_it < 0: continue diff --git a/tests/test_mcmc.py b/tests/test_mcmc.py index 357db55..b980ff0 100644 --- a/tests/test_mcmc.py +++ b/tests/test_mcmc.py @@ -63,26 +63,31 @@ def fix_state(request): @pytest.fixture( - params=[(0, 4000), (2000, 4000), (0, 6000), (2000, 6000)], - ids=["n_burn=0,n_iter=4000", "n_burn=non-zero,n_iter=4000", "n_burn=0,n_iter=6000", "n_burn=non-zero,n_iter=6000"], - name="nburn_niter", + params=[(0, 4000, 1), (2000, 4000, 5), (0, 6000, 10), (2000, 6000, 1)], + ids=[ + "n_burn=0,n_iter=4000, n_thin=1", + "n_burn=non-zero,n_iter=4000,n_thin=5", + "n_burn=0,n_iter=6000, n_thin=10", + "n_burn=non-zero,n_iter=6000,n_thin=1", + ], + name="mcmc_settings", ) -def fix_nburn_niter(request): +def fix_mcmc_settings(request): """Define the initial state for the MCMC.""" - [n_burn, n_iter] = request.param - nburn_niter = {"nburn": n_burn, "niter": n_iter} + [n_burn, n_iter, n_thin] = request.param + fix_mcmc_settings = {"nburn": n_burn, "niter": n_iter, "nthin": n_thin} - return nburn_niter + return fix_mcmc_settings -def test_run_mcmc(state: dict, sampler: list, model: Model, nburn_niter: dict, monkeypatch): +def test_run_mcmc(state: dict, sampler: list, model: Model, mcmc_settings: dict, monkeypatch): """Test run_mcmc function Checks size is correct for the output parameters of the function (state and store) based on the number of iterations (n_iter) and number of burn (n_burn), i.e., Args: state: dictionary model: Model input - nburn_niter: dictionary of mcmc settings + mcmc_settings: dictionary of mcmc settings monkeypatch object for avoiding computationally expensive mcmc sampler. """ @@ -105,32 +110,39 @@ def mock_log_p(self, current_state): monkeypatch.setattr(NormalGamma, "store", mock_store) monkeypatch.setattr(Model, "log_p", mock_log_p) - M = MCMC(state, sampler, model, n_burn=nburn_niter["nburn"], n_iter=nburn_niter["niter"]) + M = MCMC( + state, + sampler, + model, + n_burn=mcmc_settings["nburn"], + n_iter=mcmc_settings["niter"], + n_thin=mcmc_settings["nthin"], + ) M.store["count"] = 0 M.run_mcmc() - assert M.state["count"] == (M.n_iter + M.n_burn) * len(sampler) + assert M.state["count"] == (M.n_iter + M.n_burn) * len(sampler) * M.n_thin assert M.store["count"] == M.n_iter * len(sampler) -def test_post_init(state: dict, sampler: list, model: Model, nburn_niter: dict): +def test_post_init(state: dict, sampler: list, model: Model, mcmc_settings: dict): """This function test __pos__init function to check returned store and state parameters are np.array of the dimension n * 1 Args: state: dictionary - nburn_niter: integer + mcmc_settings: integer model: """ - M = MCMC(state, sampler, model, n_iter=nburn_niter["niter"]) + M = MCMC(state, sampler, model, n_iter=mcmc_settings["niter"]) assert isinstance(M.state["count"], np.ndarray) assert M.state["count"].ndim == 2 assert isinstance(M.state["beta"], np.ndarray) assert M.state["beta"].ndim == 2 - assert (len(M.store) - 1) * (M.store["beta"]).shape[1] == len(sampler) * nburn_niter["niter"] - assert M.store["log_post"].size == nburn_niter["niter"] + assert (len(M.store) - 1) * (M.store["beta"]).shape[1] == len(sampler) * mcmc_settings["niter"] + assert M.store["log_post"].size == mcmc_settings["niter"] if len(sampler) > 1: assert isinstance(M.state["tau"], np.ndarray)