From 034b9a417ca1993f7f0969be59c412be9a302d38 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 29 Mar 2024 17:43:26 +0100 Subject: [PATCH] Make default STEP_METHODS a list that can be modified --- pymc/step_methods/__init__.py | 7 ++++--- tests/sampling/test_mcmc.py | 16 +++++++++++----- tests/step_methods/test_compound.py | 6 +----- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pymc/step_methods/__init__.py b/pymc/step_methods/__init__.py index 3413609514..5f44acc728 100644 --- a/pymc/step_methods/__init__.py +++ b/pymc/step_methods/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pymc.step_methods.compound import CompoundStep +from pymc.step_methods.compound import BlockedStep, CompoundStep from pymc.step_methods.hmc import NUTS, HamiltonianMC from pymc.step_methods.metropolis import ( BinaryGibbsMetropolis, @@ -30,7 +30,8 @@ ) from pymc.step_methods.slicer import Slice -STEP_METHODS = ( +# Other step methods can be added by appending to this list +STEP_METHODS: list[type[BlockedStep]] = [ NUTS, HamiltonianMC, Metropolis, @@ -38,4 +39,4 @@ BinaryGibbsMetropolis, Slice, CategoricalGibbsMetropolis, -) +] diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index a18430818d..3f676d0846 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -762,12 +762,18 @@ def kill_grad(x): steps = assign_step_methods(model, []) assert isinstance(steps, Slice) - def test_modify_step_methods(self): + @pytest.fixture + def step_methods(self): + """Make sure we reset the STEP_METHODS after the test is done.""" + methods_copy = pm.STEP_METHODS.copy() + yield pm.STEP_METHODS + pm.STEP_METHODS.clear() + for method in methods_copy: + pm.STEP_METHODS.append(method) + + def test_modify_step_methods(self, step_methods): """Test step methods can be changed""" - # remove nuts from step_methods - step_methods = list(pm.STEP_METHODS) step_methods.remove(NUTS) - pm.STEP_METHODS = step_methods with pm.Model() as model: pm.Normal("x", 0, 1) @@ -776,7 +782,7 @@ def test_modify_step_methods(self): assert not isinstance(steps, NUTS) # add back nuts - pm.STEP_METHODS = [*step_methods, NUTS] + step_methods.append(NUTS) with pm.Model() as model: pm.Normal("x", 0, 1) diff --git a/tests/step_methods/test_compound.py b/tests/step_methods/test_compound.py index 6c9f771a7a..4a83ec593c 100644 --- a/tests/step_methods/test_compound.py +++ b/tests/step_methods/test_compound.py @@ -26,7 +26,6 @@ Slice, ) from pymc.step_methods.compound import ( - BlockedStep, StatsBijection, flatten_steps, get_stats_dtypes_shapes_from_steps, @@ -38,10 +37,7 @@ def test_all_stepmethods_emit_tune_stat(): - attrs = [getattr(pm.step_methods, n) for n in dir(pm.step_methods)] - step_types = [ - attr for attr in attrs if isinstance(attr, type) and issubclass(attr, BlockedStep) - ] + step_types = pm.step_methods.STEP_METHODS assert len(step_types) > 5 for cls in step_types: assert "tune" in cls.stats_dtypes_shapes