From 44966d3b196da5871b73d7d5e85287ff926bfda8 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sun, 4 Feb 2024 15:13:09 +0100 Subject: [PATCH 01/34] use bayeux to access a wide range of samplers --- bambi/backend/pymc.py | 134 +++++++++++++++++++++++------------------- 1 file changed, 74 insertions(+), 60 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 354f08ce3..f1913e28e 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -95,7 +95,7 @@ def run( """Run PyMC sampler.""" inference_method = inference_method.lower() # NOTE: Methods return different types of objects (idata, approximation, and dictionary) - if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax"]: + if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax", "bayeux_blackjax_hmc"]: result = self._run_mcmc( draws, tune, @@ -169,74 +169,88 @@ def _run_mcmc( sampler_backend="mcmc", **kwargs, ): - with self.model: - if sampler_backend == "mcmc": - try: - idata = pm.sample( - draws=draws, - tune=tune, - discard_tuned_samples=discard_tuned_samples, - init=init, - n_init=n_init, - chains=chains, - cores=cores, - random_seed=random_seed, - **kwargs, - ) - except (RuntimeError, ValueError): - if ( - "ValueError: Mass matrix contains" in traceback.format_exc() - and init == "auto" - ): - _log.info( - "\nThe default initialization using init='auto' has failed, trying to " - "recover by switching to init='adapt_diag'", - ) + # bayeux does not want the PyMC model to be within a context manager + if sampler_backend == "bayeux_blackjax_hmc": + import bayeux as bx + import jax + + # TODO: Think how to better map bayeux samplers to `inference_method` arg. 
+ + bx_model = bx.Model.from_pymc(self.model) + idata = bx_model.mcmc.blackjax_hmc(seed=jax.random.key(0)) + else: + with self.model: + if sampler_backend == "mcmc": + try: idata = pm.sample( draws=draws, tune=tune, discard_tuned_samples=discard_tuned_samples, - init="adapt_diag", + init=init, n_init=n_init, chains=chains, cores=cores, random_seed=random_seed, **kwargs, ) - else: - raise - elif sampler_backend == "nuts_numpyro": - import pymc.sampling_jax # pylint: disable=import-outside-toplevel - - if not chains: - # sample_numpyro_nuts does not handle chains = None like pm.sample does - chains = 4 - idata = pymc.sampling_jax.sample_numpyro_nuts( - draws=draws, - tune=tune, - chains=chains, - random_seed=random_seed, - **kwargs, - ) - elif sampler_backend == "nuts_blackjax": - import pymc.sampling_jax # pylint: disable=import-outside-toplevel - - # sample_blackjax_nuts does not handle chains = None like pm.sample does - if not chains: - chains = 4 - idata = pymc.sampling_jax.sample_blackjax_nuts( - draws=draws, - tune=tune, - chains=chains, - random_seed=random_seed, - **kwargs, - ) - else: - raise ValueError( - f"sampler_backend value {sampler_backend} is not valid. 
Please choose one of" - f"'mcmc', 'nuts_numpyro' or 'nuts_blackjax'" - ) - idata = self._clean_results(idata, omit_offsets, include_mean) + except (RuntimeError, ValueError): + if ( + "ValueError: Mass matrix contains" in traceback.format_exc() + and init == "auto" + ): + _log.info( + "\nThe default initialization using init='auto' has failed, trying to " + "recover by switching to init='adapt_diag'", + ) + idata = pm.sample( + draws=draws, + tune=tune, + discard_tuned_samples=discard_tuned_samples, + init="adapt_diag", + n_init=n_init, + chains=chains, + cores=cores, + random_seed=random_seed, + **kwargs, + ) + else: + raise + elif sampler_backend == "nuts_numpyro": + import pymc.sampling_jax # pylint: disable=import-outside-toplevel + + if not chains: + # sample_numpyro_nuts does not handle chains = None like pm.sample does + chains = 4 + idata = pymc.sampling_jax.sample_numpyro_nuts( + draws=draws, + tune=tune, + chains=chains, + random_seed=random_seed, + **kwargs, + ) + elif sampler_backend == "nuts_blackjax": + import pymc.sampling_jax # pylint: disable=import-outside-toplevel + + # sample_blackjax_nuts does not handle chains = None like pm.sample does + if not chains: + chains = 4 + idata = pymc.sampling_jax.sample_blackjax_nuts( + draws=draws, + tune=tune, + chains=chains, + random_seed=random_seed, + **kwargs, + ) + else: + raise ValueError( + f"sampler_backend value {sampler_backend} is not valid. 
Please choose one of" + f"'mcmc', 'nuts_numpyro' or 'nuts_blackjax'" + ) + + # TODO: xarray does not like the InferenceData object returned by bayeux + if "bayeux" not in sampler_backend: + idata = self._clean_results(idata, omit_offsets, include_mean) + return idata def _clean_results(self, idata, omit_offsets, include_mean): @@ -306,7 +320,7 @@ def _clean_results(self, idata, omit_offsets, include_mean): self.spec.predict(idata) return idata - + def _run_vi(self, **kwargs): with self.model: self.vi_approx = pm.fit(**kwargs) From 061a1b035415fe34976525ba71485cba715aa077 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sun, 4 Feb 2024 15:30:20 +0100 Subject: [PATCH 02/34] use bayeux to access a wide range of samplers --- bambi/backend/pymc.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index f1913e28e..b61c69512 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -247,10 +247,6 @@ def _run_mcmc( f"'mcmc', 'nuts_numpyro' or 'nuts_blackjax'" ) - # TODO: xarray does not like the InferenceData object returned by bayeux - if "bayeux" not in sampler_backend: - idata = self._clean_results(idata, omit_offsets, include_mean) - return idata def _clean_results(self, idata, omit_offsets, include_mean): From 8afe534cb9fb98b273e04194dc7b55475951c1c6 Mon Sep 17 00:00:00 2001 From: Gabriel Stechschulte <63432018+GStechschulte@users.noreply.github.com> Date: Sun, 4 Feb 2024 16:21:44 +0100 Subject: [PATCH 03/34] add notebook links to family table (#774) --- docs/notebooks/getting_started.ipynb | 58 ++++++++++++++-------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/notebooks/getting_started.ipynb b/docs/notebooks/getting_started.ipynb index ce5954227..adf071611 100644 --- a/docs/notebooks/getting_started.ipynb +++ b/docs/notebooks/getting_started.ipynb @@ -970,35 +970,35 @@ "\n", "
\n", "\n", - "|Family name |Response distribution | Default link |\n", - "|:----------------------------- |:------------------------------- |:--------------- |\n", - "asymmetriclaplace | AsymmetricLaplace | identity |\n", - "bernoulli | Bernoulli | logit |\n", - "beta | Beta | logit |\n", - "beta_binomial | BetaBinomial | logit |\n", - "binomial | Binomial | logit | \n", - "categorical | Categorical | softmax | \n", - "cumulative | Cumulative | logit | \n", - "dirichlet_multinomial | DirichletMultinomial | logit |\n", - "exponential | Exponential | log | \n", - "gamma | Gamma | inverse |\n", - "gaussian | Normal | identity |\n", - "hurdle_gamma | HurdleGamma | log |\n", - "hurdle_lognormal | HurdleLogNormal | identity |\n", - "hurdle_negativebinomial | HurdleNegativeBinomial | log |\n", - "hurdle_poisson | HurdlePoisson | log |\n", - "multinomial | Multinomial | softmax |\n", - "negativebinomial | NegativeBinomial | log |\n", - "laplace | Laplace | identity |\n", - "poisson | Poisson | log |\n", - "sratio | StoppingRatio | logit |\n", - "t | StudentT | identity |\n", - "vonmises | VonMises | tan(x / 2) |\n", - "wald | InverseGaussian | inverse squared |\n", - "weibull | Weibull | log |\n", - "zero_inflated_binomial | ZeroInflatedBinomial | logit |\n", - "zero_inflated_negativebinomial | ZeroInflatedNegativeBinomial | log |\n", - "zero_inflated_poisson | ZeroInflatedPoisson | log |\n", + "|Family name |Response distribution | Default link | Example notebook |\n", + "|:----------------------------- |:------------------------------- |:--------------- |:-----------------|\n", + "asymmetriclaplace | AsymmetricLaplace | identity | [Quantile Regression](https://bambinos.github.io/bambi/notebooks/quantile_regression.html#quantile-regression) |\n", + "bernoulli | Bernoulli | logit | [Logistic Regression](https://bambinos.github.io/bambi/notebooks/logistic_regression.html) |\n", + "beta | Beta | logit | [Beta 
Regression](https://bambinos.github.io/bambi/notebooks/beta_regression.html) |\n", + "beta_binomial | BetaBinomial | logit | _To be added_ |\n", + "binomial | Binomial | logit | [Hierarchical Logistic Regression](https://bambinos.github.io/bambi/notebooks/hierarchical_binomial_bambi.html) | \n", + "categorical | Categorical | softmax | [Categorical Regression](https://bambinos.github.io/bambi/notebooks/categorical_regression.html) | \n", + "cumulative | Cumulative | logit | [Ordinal Models](https://bambinos.github.io/bambi/notebooks/ordinal_regression.html#cumulative-model) | \n", + "dirichlet_multinomial | DirichletMultinomial | logit | _To be added_ |\n", + "exponential | Exponential | log | [Survival Models](https://bambinos.github.io/bambi/notebooks/survival_model.html#survival-models) | \n", + "gamma | Gamma | inverse | [Gamma Regression](https://bambinos.github.io/bambi/notebooks/wald_gamma_glm.html) |\n", + "gaussian | Normal | identity | [Multiple Linear Regression](https://bambinos.github.io/bambi/notebooks/ESCS_multiple_regression.html) |\n", + "hurdle_gamma | HurdleGamma | log | _To be added_ |\n", + "hurdle_lognormal | HurdleLogNormal | identity | _To be added_ |\n", + "hurdle_negativebinomial | HurdleNegativeBinomial | log | _To be added_ |\n", + "hurdle_poisson | HurdlePoisson | log | [Hurdle Poisson Regression](https://bambinos.github.io/bambi/notebooks/zero_inflated_regression.html#hurdle-poisson) |\n", + "multinomial | Multinomial | softmax | _To be added_ |\n", + "negativebinomial | NegativeBinomial | log | [Negative Binomial Regression](https://bambinos.github.io/bambi/notebooks/negative_binomial.html) |\n", + "laplace | Laplace | identity | _To be added_ |\n", + "poisson | Poisson | log | [Gaussian Processes with a Poisson likelihood](https://bambinos.github.io/bambi/notebooks/hsgp_2d.html#a-more-complex-example-poisson-likelihood-with-group-specific-effects) |\n", + "sratio | StoppingRatio | logit | [Ordinal 
Models](https://bambinos.github.io/bambi/notebooks/ordinal_regression.html#sequential-model) |\n", + "t | StudentT | identity | [Robust Linear Regression](https://bambinos.github.io/bambi/notebooks/t_regression.html) |\n", + "vonmises | VonMises | tan(x / 2) | [Circular Regression](https://bambinos.github.io/bambi/notebooks/circular_regression.html#circular-regression) |\n", + "wald | InverseGaussian | inverse squared | [Wald Regression](https://bambinos.github.io/bambi/notebooks/wald_gamma_glm.html) |\n", + "weibull | Weibull | log | _To be added_ |\n", + "zero_inflated_binomial | ZeroInflatedBinomial | logit | _To be added_ |\n", + "zero_inflated_negativebinomial | ZeroInflatedNegativeBinomial | log | _To be added_ |\n", + "zero_inflated_poisson | ZeroInflatedPoisson | log | [Zero Inflated Poisson Regression](https://bambinos.github.io/bambi/notebooks/zero_inflated_regression.html#zero-inflated-poisson)|\n", "\n", "\n", "
\n", From 9f1d9d179071abbb4cc6255242132829aae80faf Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Mon, 5 Feb 2024 22:02:52 +0100 Subject: [PATCH 04/34] access methods programatically --- bambi/backend/pymc.py | 125 +++++++++++++++++++----------------------- 1 file changed, 57 insertions(+), 68 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index b61c69512..bdb018b03 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -23,6 +23,23 @@ __version__ = version("bambi") +PYMC_SAMPLERS = ["mcmc"] +BAYEUX_SAMPLERS = [ + "blackjax_hmc", + "blackjax_chees_hmc", + "blackjax_meads_hmc", + "blackjax_nuts", + "blackjax_hmc_pathfinder", + "blackjax_nuts_pathfinder", + "flowmc_rqspline_hmc", + "flowmc_rqspline_mala", + "flowmc_realnvp_hmc", + "flowmc_realnvp_mala", + "numpyro_hmc", + "numpyro_nuts", +] + + class PyMCModel: """PyMC model-fitting backend.""" @@ -95,7 +112,7 @@ def run( """Run PyMC sampler.""" inference_method = inference_method.lower() # NOTE: Methods return different types of objects (idata, approximation, and dictionary) - if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax", "bayeux_blackjax_hmc"]: + if inference_method in (PYMC_SAMPLERS + BAYEUX_SAMPLERS): result = self._run_mcmc( draws, tune, @@ -169,84 +186,56 @@ def _run_mcmc( sampler_backend="mcmc", **kwargs, ): - # bayeux does not want the PyMC model to be within a context manager - if sampler_backend == "bayeux_blackjax_hmc": - import bayeux as bx - import jax - - # TODO: Think how to better map bayeux samplers to `inference_method` arg. 
- - bx_model = bx.Model.from_pymc(self.model) - idata = bx_model.mcmc.blackjax_hmc(seed=jax.random.key(0)) - else: + if sampler_backend in PYMC_SAMPLERS: with self.model: - if sampler_backend == "mcmc": - try: + try: + idata = pm.sample( + draws=draws, + tune=tune, + discard_tuned_samples=discard_tuned_samples, + init=init, + n_init=n_init, + chains=chains, + cores=cores, + random_seed=random_seed, + **kwargs, + ) + except (RuntimeError, ValueError): + if ( + "ValueError: Mass matrix contains" in traceback.format_exc() + and init == "auto" + ): + _log.info( + "\nThe default initialization using init='auto' has failed, trying to " + "recover by switching to init='adapt_diag'", + ) idata = pm.sample( draws=draws, tune=tune, discard_tuned_samples=discard_tuned_samples, - init=init, + init="adapt_diag", n_init=n_init, chains=chains, cores=cores, random_seed=random_seed, **kwargs, ) - except (RuntimeError, ValueError): - if ( - "ValueError: Mass matrix contains" in traceback.format_exc() - and init == "auto" - ): - _log.info( - "\nThe default initialization using init='auto' has failed, trying to " - "recover by switching to init='adapt_diag'", - ) - idata = pm.sample( - draws=draws, - tune=tune, - discard_tuned_samples=discard_tuned_samples, - init="adapt_diag", - n_init=n_init, - chains=chains, - cores=cores, - random_seed=random_seed, - **kwargs, - ) - else: - raise - elif sampler_backend == "nuts_numpyro": - import pymc.sampling_jax # pylint: disable=import-outside-toplevel - - if not chains: - # sample_numpyro_nuts does not handle chains = None like pm.sample does - chains = 4 - idata = pymc.sampling_jax.sample_numpyro_nuts( - draws=draws, - tune=tune, - chains=chains, - random_seed=random_seed, - **kwargs, - ) - elif sampler_backend == "nuts_blackjax": - import pymc.sampling_jax # pylint: disable=import-outside-toplevel + else: + raise + elif sampler_backend in BAYEUX_SAMPLERS: + import bayeux as bx + import jax - # sample_blackjax_nuts does not handle chains = 
None like pm.sample does - if not chains: - chains = 4 - idata = pymc.sampling_jax.sample_blackjax_nuts( - draws=draws, - tune=tune, - chains=chains, - random_seed=random_seed, - **kwargs, - ) - else: - raise ValueError( - f"sampler_backend value {sampler_backend} is not valid. Please choose one of" - f"'mcmc', 'nuts_numpyro' or 'nuts_blackjax'" - ) - + bx_model = bx.Model.from_pymc(self.model) + bx_sampler = getattr(bx_model.mcmc, sampler_backend) + idata = bx_sampler(seed=jax.random.key(0), **kwargs) + else: + raise ValueError( + f"sampler_backend value {sampler_backend} is not valid. Please choose one of" + f"{PYMC_SAMPLERS + BAYEUX_SAMPLERS}" + ) + + idata = self._clean_results(idata, omit_offsets, include_mean) return idata def _clean_results(self, idata, omit_offsets, include_mean): @@ -316,7 +305,7 @@ def _clean_results(self, idata, omit_offsets, include_mean): self.spec.predict(idata) return idata - + def _run_vi(self, **kwargs): with self.model: self.vi_approx = pm.fit(**kwargs) From 9b42fc299d5e91266f67d597cb9159a54be2f40f Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sat, 10 Feb 2024 09:49:52 +0100 Subject: [PATCH 05/34] clean bayeux idata to be consistent with pymc model coords --- bambi/backend/pymc.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index bdb018b03..1b79eb79d 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -1,5 +1,6 @@ import functools import logging +import re import traceback @@ -235,7 +236,7 @@ def _run_mcmc( f"{PYMC_SAMPLERS + BAYEUX_SAMPLERS}" ) - idata = self._clean_results(idata, omit_offsets, include_mean) + # idata = self._clean_results(idata, omit_offsets, include_mean) return idata def _clean_results(self, idata, omit_offsets, include_mean): @@ -257,6 +258,17 @@ def _clean_results(self, idata, omit_offsets, include_mean): dims_original = list(self.model.coords) + # Identify bayeux idata and use regex to remove the trailing 
numeric suffix from the dims + # TODO: Will "_0" always be the minimum dim suffix, i.e. "_1", "_2", ...? + if [dim for dim in idata.posterior.dims if dim.endswith("_0")]: + bayeux_orig_dims = [ + dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw")) + ] + bayeux_cleaned_dims = [re.sub(r"_\d", "", element) for element in bayeux_orig_dims] + + for orig, renamed in zip(bayeux_orig_dims, bayeux_cleaned_dims): + idata.posterior = idata.posterior.rename_dims({orig: renamed}) + # Discard dims that are in the model but unused in the posterior dims_original = [dim for dim in dims_original if dim in idata.posterior.dims] From 91ce2a011e178471cc0c0a0a8c6e9cecfc381915 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sat, 10 Feb 2024 09:50:40 +0100 Subject: [PATCH 06/34] rename alternative sampler args in tests --- tests/test_alternative_samplers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_alternative_samplers.py b/tests/test_alternative_samplers.py index a16134762..816755648 100644 --- a/tests/test_alternative_samplers.py +++ b/tests/test_alternative_samplers.py @@ -56,8 +56,8 @@ def test_vi(): "args", [ ("mcmc", {}), - ("nuts_numpyro", {"chain_method": "vectorized"}), - ("nuts_blackjax", {"chain_method": "vectorized"}), + ("numpyro_nuts", {"chain_method": "vectorized"}), + ("blackjax_nuts", {"chain_method": "vectorized"}), ], ) def test_logistic_regression_categoric_alternative_samplers(data_n100, args): @@ -69,8 +69,8 @@ def test_logistic_regression_categoric_alternative_samplers(data_n100, args): "args", [ ("mcmc", {}), - ("nuts_numpyro", {"chain_method": "vectorized"}), - ("nuts_blackjax", {"chain_method": "vectorized"}), + ("numpyro_nuts", {"chain_method": "vectorized"}), + ("blackjax_nuts", {"chain_method": "vectorized"}), ], ) def test_regression_alternative_samplers(data_n100, args): From 89a2aeeac463fa75c2e3ad58c7ae2c41a8651391 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sat, 10 Feb 2024 
10:17:17 +0100 Subject: [PATCH 07/34] change docstring to reflect bayeux sampler names --- bambi/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bambi/models.py b/bambi/models.py index 5d5eafe09..5dac88881 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -266,7 +266,7 @@ def fit( using the ``fit`` function. Finally, ``"laplace"``, in which case a Laplace approximation is used and is not recommended other than for pedagogical use. - To use the PyMC numpyro and blackjax samplers, use ``nuts_numpyro`` or ``nuts_blackjax`` + To use the PyMC numpyro and blackjax samplers, use ``numpyro_nuts`` or ``blackjax_nuts`` respectively. Both methods will only work if you can use NUTS sampling, so your model must be differentiable. init : str @@ -306,7 +306,7 @@ def fit( Returns ------- An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default), - "nuts_numpyro", "nuts_blackjax" or "laplace". + "numpyro_nuts", "blackjax_nuts" or "laplace". An ``Approximation`` object if ``"vi"``. 
""" method = kwargs.pop("method", None) From d6058ad1b3a90e69cc2beadc8949b5a50b97dbce Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sat, 10 Feb 2024 10:17:59 +0100 Subject: [PATCH 08/34] bayeux dependencies are numpyro/jax/jaxlib/blackjax --- pyproject.toml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 058d94b04..460bc4914 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,12 +35,7 @@ dev = [ "quartodoc==0.6.1", "seaborn>=0.9.0", ] -jax = [ - "blackjax>=1.0.0", - "jax>=0.3.1", - "jaxlib>=0.3.1", - "numpyro>=0.9.0", -] +jax = ["bayeux>=0.1.6",] [project.urls] homepage = "https://bambinos.github.io/bambi" From 722c8b5fffbf445a84181abf876878b05cce3e20 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Mon, 19 Feb 2024 20:20:23 +0100 Subject: [PATCH 09/34] rename idata coords and dims to PyMC model --- bambi/backend/pymc.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 1b79eb79d..e6b4fbc13 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -221,6 +221,7 @@ def _run_mcmc( random_seed=random_seed, **kwargs, ) + idata_from = "pymc" else: raise elif sampler_backend in BAYEUX_SAMPLERS: @@ -230,16 +231,17 @@ def _run_mcmc( bx_model = bx.Model.from_pymc(self.model) bx_sampler = getattr(bx_model.mcmc, sampler_backend) idata = bx_sampler(seed=jax.random.key(0), **kwargs) + idata_from = "bayeux" else: raise ValueError( f"sampler_backend value {sampler_backend} is not valid. 
Please choose one of" f"{PYMC_SAMPLERS + BAYEUX_SAMPLERS}" ) - # idata = self._clean_results(idata, omit_offsets, include_mean) + idata = self._clean_results(idata, omit_offsets, include_mean, idata_from) return idata - def _clean_results(self, idata, omit_offsets, include_mean): + def _clean_results(self, idata, omit_offsets, include_mean, idata_from): for group in idata.groups(): getattr(idata, group).attrs["modeling_interface"] = "bambi" @@ -258,16 +260,12 @@ def _clean_results(self, idata, omit_offsets, include_mean): dims_original = list(self.model.coords) - # Identify bayeux idata and use regex to remove the trailing numeric suffix from the dims - # TODO: Will "_0" always be the minimum dim suffix, i.e. "_1", "_2", ...? - if [dim for dim in idata.posterior.dims if dim.endswith("_0")]: - bayeux_orig_dims = [ - dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw")) - ] - bayeux_cleaned_dims = [re.sub(r"_\d", "", element) for element in bayeux_orig_dims] - - for orig, renamed in zip(bayeux_orig_dims, bayeux_cleaned_dims): - idata.posterior = idata.posterior.rename_dims({orig: renamed}) + # Identify bayeux idata and rename dims and coordinates to match PyMC model + if idata_from == "bayeux": + pymc_model_dims = [dim for dim in dims_original if "_obs" not in dim] + bayeux_dims = [dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw"))] + cleaned_dims = dict(zip(bayeux_dims, pymc_model_dims)) + idata = idata.rename(cleaned_dims) # Discard dims that are in the model but unused in the posterior dims_original = [dim for dim in dims_original if dim in idata.posterior.dims] From ccc28776168c5dd904e56300efdcb00591d3cfbf Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Mon, 19 Feb 2024 20:37:32 +0100 Subject: [PATCH 10/34] add JAX based sampler dependencies --- pyproject.toml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 460bc4914..1203e3fdc 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,14 @@ dev = [ "quartodoc==0.6.1", "seaborn>=0.9.0", ] -jax = ["bayeux>=0.1.6",] +jax = [ + "bayeux>=0.1.6", + "blackjax>=1.0.0", + "jax>=0.3.1", + "jaxlib>=0.3.1", + "numpyro>=0.9.0", + "flowMC>=0.2.4", +] [project.urls] homepage = "https://bambinos.github.io/bambi" From 74b4e8b4a2bc412ee847bfaea5d2680734f98ed1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Wed, 21 Feb 2024 15:16:56 -0300 Subject: [PATCH 11/34] Update code of conduct (#783) * Update code of conduct * update changelog --- CHANGELOG.md | 2 ++ CODE_OF_CONDUCT.md | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f7f7bb22..79538d5db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ ### Documentation +* Our Code of Conduct now includes how to send a report (#783) + ### Deprecation ## 0.13.0 diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index fb45a77ff..930baed28 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,5 +1,9 @@ # Bambi Community Code of Conduct +Bambi adopts the NumFOCUS Code of Conduct directly. In other words, we expect our community to treat others with kindness and understanding. + +# The short version + Be kind to others. Do not insult or put down others. Behave professionally. Remember that harassment and sexist, racist, or exclusionary jokes are not appropriate. @@ -15,3 +19,31 @@ or religion. We do not tolerate harassment of community members in any form. Thank you for helping make this a welcoming, friendly community for all. + +# How to Submit a Report + +If you feel that there has been a Code of Conduct violation an anonymous +reporting form is available. + +**If you feel your safety is in jeopardy or the situation is an +emergency, we urge you to contact local law enforcement before making +a report. (In the U.S., dial 911.)** + +We are committed to promptly addressing any reported issues. 
+If you have experienced or witnessed behavior that violates this +Code of Conduct, please complete the form below to +make a report. + +**REPORTING FORM:** https://numfocus.typeform.com/to/ynjGdT + +Reports are sent to the NumFOCUS Code of Conduct Enforcement Team +(see below). + +You can view the Privacy Policy and Terms of Service for TypeForm here. +The NumFOCUS Privacy Policy is here: +https://www.numfocus.org/privacy-policy + +# Full Code of Conduct + +The full text of the NumFOCUS/Bambi Code of Conduct can be found on +NumFOCUS's website https://numfocus.org/code-of-conduct \ No newline at end of file From 47bb161e77bcb38e72d18be82d40730ef532e265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Thu, 29 Feb 2024 12:08:13 -0300 Subject: [PATCH 12/34] [WIP] Fix HSGP predictions (#780) * Delete all HSGP slices at the same time * Make interpret consider kwargs in function calls * Update code of conduct (#783) * Update code of conduct * update changelog * Update formulae to >=0.5.3 * start a test for the hsgp and 'by' * update changelog --- CHANGELOG.md | 2 ++ bambi/interpret/utils.py | 11 ++++++++++- bambi/model_components.py | 9 ++++++++- pyproject.toml | 2 +- tests/test_hsgp.py | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 53 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 79538d5db..566ead47a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ ### Maintenance and fixes +* Fix bug in predictions with models using HSGP (#780) + ### Documentation * Our Code of Conduct now includes how to send a report (#783) diff --git a/bambi/interpret/utils.py b/bambi/interpret/utils.py index a56e23560..47bc02864 100644 --- a/bambi/interpret/utils.py +++ b/bambi/interpret/utils.py @@ -236,11 +236,20 @@ def get_model_covariates(model: Model) -> np.ndarray: for term in terms.values(): if hasattr(term, "components"): for component in term.components: - # if the component is a function call, use the argument 
names + # if the component is a function call, look for relevant argument names if isinstance(component, Call): + # Add variable names passed as unnamed arguments covariates.append( [arg.name for arg in component.call.args if isinstance(arg, LazyVariable)] ) + # Add variable names passed as named arguments + covariates.append( + [ + kwarg_value.name + for kwarg_value in component.call.kwargs.values() + if isinstance(kwarg_value, LazyVariable) + ] + ) else: covariates.append([component.name]) elif hasattr(term, "factor"): diff --git a/bambi/model_components.py b/bambi/model_components.py index f4691e5e2..44c781127 100644 --- a/bambi/model_components.py +++ b/bambi/model_components.py @@ -239,11 +239,12 @@ def predict_common( X = np.delete(X, term_slice, axis=1) # Add HSGP components contribution to the linear predictor + hsgp_slices = [] for term_name, term in self.hsgp_terms.items(): # Extract data for the HSGP component from the design matrix term_slice = self.design.common.slices[term_name] x_slice = X[:, term_slice] - X = np.delete(X, term_slice, axis=1) + hsgp_slices.append(term_slice) term_aliased_name = get_aliased_name(term) hsgp_to_stack_dims = (f"{term_aliased_name}_weights_dim",) @@ -288,6 +289,12 @@ def predict_common( # Add contribution to the linear predictor linear_predictor += hsgp_contribution + # Remove columns of X that are associated with HSGP contributions + # All the slices _must be_ deleted at the same time. 
Otherwise the slice objects don't + # reflect the right columns of X at the time they're used + if hsgp_slices: + X = np.delete(X, np.r_[tuple(hsgp_slices)], axis=1) + if self.common_terms or self.intercept_term: # Create DataArray X_terms = [get_aliased_name(term) for term in self.common_terms.values()] diff --git a/pyproject.toml b/pyproject.toml index 1203e3fdc..c7620d891 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ maintainers = [ dependencies = [ "arviz>=0.12.0", - "formulae>=0.5.0", + "formulae>=0.5.3", "graphviz", "pandas>=1.0.0", "pymc>=5.5.0", diff --git a/tests/test_hsgp.py b/tests/test_hsgp.py index 30bf5ce1c..770c70cc5 100644 --- a/tests/test_hsgp.py +++ b/tests/test_hsgp.py @@ -300,3 +300,35 @@ def test_minimal_1d_predicts(data_1d_single_group): new_idata = model.predict(idata, data=new_data, kind="pps", inplace=False) assert new_idata.posterior_predictive["y"].dims == ("chain", "draw", "y_obs") assert new_idata.posterior_predictive["y"].to_numpy().shape == (2, 500, 10) + + +def test_multiple_hsgp_and_by(data_1d_multiple_groups): + rng = np.random.default_rng(1234) + df = data_1d_multiple_groups.copy() + df["fac2"] = rng.choice(["a", "b", "c"], size=df.shape[0]) + + formula = "y ~ 1 + x0 + hsgp(x1, by=fac, m=10, c=2) + hsgp(x1, by=fac2, m=10, c=2)" + model = bmb.Model( + formula=formula, + data=df, + categorical=["fac"], + ) + idata = model.fit(tune=400, draws=200, target_accept=0.9) + + bmb.interpret.plot_predictions( + model, + idata, + conditional="x1", + subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"}, + ); + + bmb.interpret.plot_predictions( + model, + idata, + conditional={ + "x1": np.linspace(0, 1, num=100), + "fac2": ["a", "b", "c"] + }, + legend=False, + subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"}, + ); \ No newline at end of file From 9f6fc2aa3ae6f9cf9661e07df4093c7c93e6e66f Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Fri, 1 Mar 2024 11:18:00 +0100 Subject: [PATCH 13/34] bayeux 
0.1.9 updates --- bambi/backend/pymc.py | 129 ++++++++++++++++++++++++++++++------------ 1 file changed, 92 insertions(+), 37 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index e6b4fbc13..feb70141a 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -1,9 +1,9 @@ import functools +import importlib import logging -import re +import operator import traceback - from copy import deepcopy from importlib.metadata import version @@ -13,7 +13,6 @@ import pytensor.tensor as pt from pytensor.tensor.special import softmax - from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2 from bambi.backend.model_components import ConstantComponent, DistributionalComponent from bambi.utils import get_aliased_name @@ -24,23 +23,6 @@ __version__ = version("bambi") -PYMC_SAMPLERS = ["mcmc"] -BAYEUX_SAMPLERS = [ - "blackjax_hmc", - "blackjax_chees_hmc", - "blackjax_meads_hmc", - "blackjax_nuts", - "blackjax_hmc_pathfinder", - "blackjax_nuts_pathfinder", - "flowmc_rqspline_hmc", - "flowmc_rqspline_mala", - "flowmc_realnvp_hmc", - "flowmc_realnvp_mala", - "numpyro_hmc", - "numpyro_nuts", -] - - class PyMCModel: """PyMC model-fitting backend.""" @@ -64,6 +46,8 @@ def __init__(self): self.model = None self.spec = None self.components = {} + self.bayeux_methods = _get_bayeux_methods() + self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]} def build(self, spec): """Compile the PyMC model from an abstract model specification. 
@@ -113,7 +97,7 @@ def run( """Run PyMC sampler.""" inference_method = inference_method.lower() # NOTE: Methods return different types of objects (idata, approximation, and dictionary) - if inference_method in (PYMC_SAMPLERS + BAYEUX_SAMPLERS): + if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]): result = self._run_mcmc( draws, tune, @@ -128,10 +112,12 @@ def run( inference_method, **kwargs, ) - elif inference_method == "vi": - result = self._run_vi(**kwargs) + elif inference_method in (self.pymc_methods["vi"] + self.bayeux_methods["vi"]): + result = self._run_vi(inference_method, random_seed, **kwargs) elif inference_method == "laplace": result = self._run_laplace(draws, omit_offsets, include_mean) + elif inference_method in self.bayeux_methods["optimize"]: + result = self._optimize(inference_method, random_seed, **kwargs) else: raise NotImplementedError(f"'{inference_method}' method has not been implemented") @@ -187,7 +173,7 @@ def _run_mcmc( sampler_backend="mcmc", **kwargs, ): - if sampler_backend in PYMC_SAMPLERS: + if sampler_backend in self.pymc_methods["mcmc"]: with self.model: try: idata = pm.sample( @@ -221,21 +207,25 @@ def _run_mcmc( random_seed=random_seed, **kwargs, ) - idata_from = "pymc" else: raise - elif sampler_backend in BAYEUX_SAMPLERS: + idata_from = "pymc" + elif sampler_backend in self.bayeux_methods["mcmc"]: import bayeux as bx import jax + + # Seed is required for bayeux + if random_seed is None: + random_seed = 0 bx_model = bx.Model.from_pymc(self.model) - bx_sampler = getattr(bx_model.mcmc, sampler_backend) - idata = bx_sampler(seed=jax.random.key(0), **kwargs) + bx_sampler = operator.attrgetter(sampler_backend)(bx_model.mcmc) + idata = bx_sampler(seed=jax.random.key(random_seed), **kwargs) idata_from = "bayeux" else: raise ValueError( f"sampler_backend value {sampler_backend} is not valid. 
Please choose one of" - f"{PYMC_SAMPLERS + BAYEUX_SAMPLERS}" + f" {self.pymc_methods['mcmc'] + self.bayeux_methods['mcmc']}" ) idata = self._clean_results(idata, omit_offsets, include_mean, idata_from) @@ -263,7 +253,9 @@ def _clean_results(self, idata, omit_offsets, include_mean, idata_from): # Identify bayeux idata and rename dims and coordinates to match PyMC model if idata_from == "bayeux": pymc_model_dims = [dim for dim in dims_original if "_obs" not in dim] - bayeux_dims = [dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw"))] + bayeux_dims = [ + dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw")) + ] cleaned_dims = dict(zip(bayeux_dims, pymc_model_dims)) idata = idata.rename(cleaned_dims) @@ -281,7 +273,6 @@ def _clean_results(self, idata, omit_offsets, include_mean, idata_from): idata.posterior = idata.posterior.transpose(*dims_new) # Compute the actual intercept in all distributional components that have an intercept - for pymc_component in self.distributional_components.values(): bambi_component = pymc_component.component if ( @@ -316,17 +307,35 @@ def _clean_results(self, idata, omit_offsets, include_mean, idata_from): return idata - def _run_vi(self, **kwargs): - with self.model: - self.vi_approx = pm.fit(**kwargs) - return self.vi_approx + def _run_vi(self, inference_method, random_seed, **kwargs): + if inference_method in self.pymc_methods["vi"]: + with self.model: + self.vi_approx = pm.fit(**kwargs) + return self.vi_approx + elif inference_method in self.bayeux_methods["vi"]: + import bayeux as bx + import jax + + # Seed is required for bayeux + if random_seed is None: + random_seed = 0 + + bx_model = bx.Model.from_pymc(self.model) + bx_vi = operator.attrgetter(inference_method)(bx_model.vi) + idata = bx_vi(seed=jax.random.key(random_seed), **kwargs) + return idata + else: + raise ValueError( + f"inference_method value {inference_method} is not valid. 
Please choose one of" + f" {self.pymc_methods['vi'] + self.bayeux_methods['vi']}" + ) def _run_laplace(self, draws, omit_offsets, include_mean): """Fit a model using a Laplace approximation. Mainly for pedagogical use, provides reasonable results for approximately Gaussian posteriors. The approximation can be very poor for some models - like hierarchical ones. Use ``mcmc``, ``nuts_numpyro``, ``nuts_blackjax`` + like hierarchical ones. Use ``mcmc``, ``numpyro_nuts``, ``blackjax_nuts`` or ``vi`` for better approximations. Parameters @@ -361,9 +370,28 @@ def _run_laplace(self, draws, omit_offsets, include_mean): samples = np.random.multivariate_normal(modes, cov, size=draws) idata = _posterior_samples_to_idata(samples, self.model) - idata = self._clean_results(idata, omit_offsets, include_mean) + idata = self._clean_results(idata, omit_offsets, include_mean, idata_from="pymc") return idata + def _optimize(self, inference_method, random_seed, **kwargs): + if inference_method in self.bayeux_methods["optimize"]: + import bayeux as bx + import jax + + # Seed is required for bayeux + if random_seed is None: + random_seed = 0 + + bx_model = bx.Model.from_pymc(self.model) + bx_optimize = operator.attrgetter(inference_method)(bx_model.optimize) + opt_results = bx_optimize(seed=jax.random.key(random_seed), **kwargs) + return opt_results + else: + raise ValueError( + f"inference_method value {inference_method} is not valid. Please choose one of" + f" {self.bayeux_methods['optimize']}" + ) + @property def response_component(self): return self.components[self.spec.response_name] @@ -415,3 +443,30 @@ def _posterior_samples_to_idata(samples, model): idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model) return idata + + +def _get_bayeux_methods(): + """Gets a dictionary of usable bayeux methods if the bayeux package is installed + within the user's environment. 
+ + Returns + ------- + dict + A dict where the keys are the module names and the values are the methods + available in that module. + """ + bx_methods = {} + if importlib.util.find_spec("bayeux") is None: + return bx_methods + + import bayeux as bx + + bx_methods = {} + for module in bx._src.bayeux._MODULES: + mname = module.__name__.split(".")[-1] + bx_modules = [] + for k in module.__all__: + bx_modules.append(getattr(module, k).name) + bx_methods[mname] = bx_modules + + return bx_methods From 10bb5089c4192618005ae28d020461d28763fac0 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Fri, 1 Mar 2024 11:18:14 +0100 Subject: [PATCH 14/34] bump bayeux version --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7620d891..21d977e54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,12 +36,11 @@ dev = [ "seaborn>=0.9.0", ] jax = [ - "bayeux>=0.1.6", + "bayeux>=0.1.9", "blackjax>=1.0.0", "jax>=0.3.1", "jaxlib>=0.3.1", "numpyro>=0.9.0", - "flowMC>=0.2.4", ] [project.urls] From f7bf97f65bf53fb05560ea7d5fa0eb26b289d264 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Fri, 1 Mar 2024 16:16:04 +0100 Subject: [PATCH 15/34] remove TFP methods, optimizers, and resolve pylint errors --- bambi/backend/pymc.py | 78 ++++++++++++++----------------------------- 1 file changed, 25 insertions(+), 53 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index feb70141a..8bb6439ac 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -47,7 +47,7 @@ def __init__(self): self.spec = None self.components = {} self.bayeux_methods = _get_bayeux_methods() - self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]} + self.pymc_methods = {"mcmc": ["mcmc"]} def build(self, spec): """Compile the PyMC model from an abstract model specification. 
@@ -112,12 +112,10 @@ def run( inference_method, **kwargs, ) - elif inference_method in (self.pymc_methods["vi"] + self.bayeux_methods["vi"]): - result = self._run_vi(inference_method, random_seed, **kwargs) + elif inference_method == "vi": + result = self._run_vi(**kwargs) elif inference_method == "laplace": result = self._run_laplace(draws, omit_offsets, include_mean) - elif inference_method in self.bayeux_methods["optimize"]: - result = self._optimize(inference_method, random_seed, **kwargs) else: raise NotImplementedError(f"'{inference_method}' method has not been implemented") @@ -211,15 +209,17 @@ def _run_mcmc( raise idata_from = "pymc" elif sampler_backend in self.bayeux_methods["mcmc"]: - import bayeux as bx - import jax - + import bayeux as bx # pylint: disable=import-outside-toplevel + import jax # pylint: disable=import-outside-toplevel + # Seed is required for bayeux if random_seed is None: random_seed = 0 bx_model = bx.Model.from_pymc(self.model) - bx_sampler = operator.attrgetter(sampler_backend)(bx_model.mcmc) + bx_sampler = operator.attrgetter(sampler_backend)( + bx_model.mcmc + ) # pylint: disable=no-member idata = bx_sampler(seed=jax.random.key(random_seed), **kwargs) idata_from = "bayeux" else: @@ -307,28 +307,10 @@ def _clean_results(self, idata, omit_offsets, include_mean, idata_from): return idata - def _run_vi(self, inference_method, random_seed, **kwargs): - if inference_method in self.pymc_methods["vi"]: - with self.model: - self.vi_approx = pm.fit(**kwargs) - return self.vi_approx - elif inference_method in self.bayeux_methods["vi"]: - import bayeux as bx - import jax - - # Seed is required for bayeux - if random_seed is None: - random_seed = 0 - - bx_model = bx.Model.from_pymc(self.model) - bx_vi = operator.attrgetter(inference_method)(bx_model.vi) - idata = bx_vi(seed=jax.random.key(random_seed), **kwargs) - return idata - else: - raise ValueError( - f"inference_method value {inference_method} is not valid. 
Please choose one of" - f" {self.pymc_methods['vi'] + self.bayeux_methods['vi']}" - ) + def _run_vi(self, **kwargs): + with self.model: + self.vi_approx = pm.fit(**kwargs) + return self.vi_approx def _run_laplace(self, draws, omit_offsets, include_mean): """Fit a model using a Laplace approximation. @@ -373,25 +355,6 @@ def _run_laplace(self, draws, omit_offsets, include_mean): idata = self._clean_results(idata, omit_offsets, include_mean, idata_from="pymc") return idata - def _optimize(self, inference_method, random_seed, **kwargs): - if inference_method in self.bayeux_methods["optimize"]: - import bayeux as bx - import jax - - # Seed is required for bayeux - if random_seed is None: - random_seed = 0 - - bx_model = bx.Model.from_pymc(self.model) - bx_optimize = operator.attrgetter(inference_method)(bx_model.optimize) - opt_results = bx_optimize(seed=jax.random.key(random_seed), **kwargs) - return opt_results - else: - raise ValueError( - f"inference_method value {inference_method} is not valid. 
Please choose one of" - f" {self.bayeux_methods['optimize']}" - ) - @property def response_component(self): return self.components[self.spec.response_name] @@ -459,14 +422,23 @@ def _get_bayeux_methods(): if importlib.util.find_spec("bayeux") is None: return bx_methods - import bayeux as bx + import bayeux as bx # pylint: disable=import-outside-toplevel bx_methods = {} - for module in bx._src.bayeux._MODULES: - mname = module.__name__.split(".")[-1] + for module in bx._src.bayeux._MODULES: # pylint: disable=protected-access + mname = module.__name__.rsplit(".", maxsplit=1)[-1] bx_modules = [] for k in module.__all__: bx_modules.append(getattr(module, k).name) bx_methods[mname] = bx_modules + # TFP based methods do not work with Bambi models yet + tfp_mcmc = ["tfp_hmc", "tfp_nuts", "tfp_snaper_hmc"] + for method in tfp_mcmc: + bx_methods["mcmc"].remove(method) + + tfp_vi = ["tfp_factored_surrogate_posterior"] + for method in tfp_vi: + bx_methods["vi"].remove(method) + return bx_methods From 1147d96e8d2a12a26a57723f803ddd1ecde29de1 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Fri, 1 Mar 2024 16:16:30 +0100 Subject: [PATCH 16/34] alternative backends docs --- docs/_quarto.yml | 3 + docs/notebooks/alternative_samplers.ipynb | 5496 +++++++++++++++++++++ docs/notebooks/gallery.yml | 8 +- 3 files changed, 5506 insertions(+), 1 deletion(-) create mode 100644 docs/notebooks/alternative_samplers.ipynb diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 0d141baf0..87ba502de 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -89,6 +89,9 @@ website: - notebooks/plot_comparisons.ipynb - notebooks/plot_slopes.ipynb - notebooks/interpret_advanced_usage.ipynb + - section: Alternative sampling backends + contents: + - notebooks/alternative_samplers.ipynb quartodoc: style: pkgdown diff --git a/docs/notebooks/alternative_samplers.ipynb b/docs/notebooks/alternative_samplers.ipynb new file mode 100644 index 000000000..3011fe966 --- /dev/null +++ 
b/docs/notebooks/alternative_samplers.ipynb @@ -0,0 +1,5496 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Alternative sampling backends\n", + "\n", + "In Bambi, the sampler used is automatically selected given the type of variables used in the model. For inference, Bambi supports both MCMC and variational inference. By default, Bambi uses PyMC's implementation of the adaptive Hamiltonian Monte Carlo (HMC) algorithm, also known as the No-U-Turn Sampler (NUTS), for sampling. This sampler is a good choice for many models. However, it is not the only sampling method, nor is PyMC the only library implementing NUTS. \n", + "\n", + "To this end, Bambi supports multiple backends for MCMC sampling such as NumPyro and Blackjax. This notebook will cover how to use such alternatives in Bambi.\n", + "\n", + "_Note_: Bambi utilizes [bayeux](https://github.com/jax-ml/bayeux) to access a variety of sampling backends. Thus, you will need to install the optional dependencies in the Bambi [pyproject.toml](https://github.com/bambinos/bambi/blob/main/pyproject.toml) file to use these backends." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import arviz as az\n", + "import bambi as bmb\n", + "import bayeux as bx\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## bayeux\n", + "\n", + "Bambi leverages `bayeux` to access different sampling backends. In short, `bayeux` lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. \n", + "\n", + "Since the underlying Bambi model is a PyMC model, this PyMC model can be \"given\" to `bayeux`. Then, we can choose from a variety of MCMC methods to perform inference. Below, the list of alternative MCMC methods to use in Bambi is shown."
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['blackjax_hmc',\n", + " 'blackjax_chees_hmc',\n", + " 'blackjax_meads_hmc',\n", + " 'blackjax_nuts',\n", + " 'blackjax_hmc_pathfinder',\n", + " 'blackjax_nuts_pathfinder',\n", + " 'flowmc_rqspline_hmc',\n", + " 'flowmc_rqspline_mala',\n", + " 'flowmc_realnvp_hmc',\n", + " 'flowmc_realnvp_mala',\n", + " 'numpyro_hmc',\n", + " 'numpyro_nuts']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Tensorflow probability based methods are currently not supported\n", + "mcmc_methods = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__ if \"tfp\" not in getattr(bx.mcmc, k).name ]\n", + "mcmc_methods" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`bayeux` lets us have access to Blackjax, FlowMC, and NumPyro backends. In the section below, we will show how to use these backends in Bambi." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specifying an `inference_method`\n", + "\n", + "First, we simulate some data to use in the examples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "num_samples = 100\n", + "num_features = 1\n", + "noise_std = 1.0\n", + "random_seed = 42\n", + "\n", + "np.random.seed(random_seed)\n", + "\n", + "coefficients = np.random.randn(num_features)\n", + "X = np.random.randn(num_samples, num_features)\n", + "error = np.random.normal(scale=noise_std, size=num_samples)\n", + "y = X @ coefficients + error\n", + "\n", + "data = pd.DataFrame({\"y\": y, \"x\": X.flatten()})" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "model = bmb.Model(\"y ~ x\", data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use a different backend, we pass the name of the `bayeux` MCMC inference method to the `inference_method` parameter of the `fit` method." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Blackjax" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (chain: 8, draw: 500)\n",
      +       "Coordinates:\n",
      +       "  * chain      (chain) int64 0 1 2 3 4 5 6 7\n",
      +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "Data variables:\n",
      +       "    y_sigma    (chain, draw) float64 0.9098 0.9531 0.9082 ... 0.878 0.9928\n",
      +       "    Intercept  (chain, draw) float64 -0.02762 0.07108 ... 0.01903 0.0508\n",
      +       "    x          (chain, draw) float64 0.3369 0.395 0.3184 ... 0.2896 0.4378\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:56:58.167509\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:          (chain: 8, draw: 500)\n",
      +       "Coordinates:\n",
      +       "  * chain            (chain) int64 0 1 2 3 4 5 6 7\n",
      +       "  * draw             (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499\n",
      +       "Data variables:\n",
      +       "    lp               (chain, draw) float64 -139.5 -139.5 ... -139.9 -139.9\n",
      +       "    step_size        (chain, draw) float64 0.7258 0.7258 ... 0.7077 0.7077\n",
      +       "    diverging        (chain, draw) bool False False False ... False False False\n",
      +       "    energy           (chain, draw) float64 141.2 139.9 139.7 ... 140.1 140.3\n",
      +       "    tree_depth       (chain, draw) int64 2 3 3 3 3 2 3 3 2 ... 2 3 3 3 2 2 3 3 3\n",
      +       "    n_steps          (chain, draw) int64 3 7 7 7 7 3 7 7 3 ... 3 7 7 7 3 3 7 7 7\n",
      +       "    acceptance_rate  (chain, draw) float64 0.9614 0.9746 ... 0.9652 0.9926\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:56:58.169388\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "blackjax_nuts_idata = model.fit(inference_method=\"blackjax_nuts\")\n", + "blackjax_nuts_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Different backends have different naming conventions for the parameters specific to that MCMC method. Thus, to specify backend-specific parameters, pass your own `kwargs` to the `fit` method.\n", + "\n", + "Each algorithm has a `.get_kwargs()` method that tells you how it will be called, and what functions are being called." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{ blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': .wrap_log_density..wrapped(args)>,\n", + " 'is_mass_matrix_diagonal': True,\n", + " 'initial_step_size': 1.0,\n", + " 'target_acceptance_rate': 0.8,\n", + " 'progress_bar': False,\n", + " 'algorithm': blackjax.mcmc.nuts.nuts},\n", + " 'adapt.run': {'num_steps': 500},\n", + " blackjax.mcmc.nuts.nuts: {'max_num_doublings': 10,\n", + " 'divergence_threshold': 1000,\n", + " 'integrator': .euclidean_integrator(logdensity_fn: Callable, kinetic_energy_fn: blackjax.mcmc.metrics.KineticEnergy) -> Callable[[blackjax.mcmc.integrators.IntegratorState, float], blackjax.mcmc.integrators.IntegratorState]>,\n", + " 'logdensity_fn': .wrap_log_density..wrapped(args)>,\n", + " 'step_size': 0.5},\n", + " 'extra_parameters': {'chain_method': 'vectorized',\n", + " 'num_chains': 8,\n", + " 'num_draws': 500,\n", + " 'num_adapt_draws': 500,\n", + " 'return_pytree': False}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bx.Model.from_pymc(model.backend.model).mcmc.blackjax_nuts.get_kwargs()" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "Now, we can identify the kwargs we would like to change and pass to the `fit` method." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (chain: 4, draw: 250)\n",
      +       "Coordinates:\n",
      +       "  * chain      (chain) int64 0 1 2 3\n",
      +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 242 243 244 245 246 247 248 249\n",
      +       "Data variables:\n",
      +       "    y_sigma    (chain, draw) float64 1.078 1.05 0.8647 ... 0.856 0.9391 0.9165\n",
      +       "    Intercept  (chain, draw) float64 -0.1116 -0.1474 ... -0.04961 0.0266\n",
      +       "    x          (chain, draw) float64 0.4042 0.3106 0.4226 ... 0.2611 0.3592\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:03.782531\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:          (chain: 4, draw: 250)\n",
      +       "Coordinates:\n",
      +       "  * chain            (chain) int64 0 1 2 3\n",
      +       "  * draw             (draw) int64 0 1 2 3 4 5 6 ... 243 244 245 246 247 248 249\n",
      +       "Data variables:\n",
      +       "    lp               (chain, draw) float64 -142.2 -141.9 ... -139.9 -139.3\n",
      +       "    step_size        (chain, draw) float64 0.9072 0.9072 ... 0.7606 0.7606\n",
      +       "    diverging        (chain, draw) bool False False False ... False False False\n",
      +       "    energy           (chain, draw) float64 144.6 143.1 142.2 ... 141.1 140.1\n",
      +       "    tree_depth       (chain, draw) int64 3 3 2 3 2 1 2 2 2 ... 3 3 2 2 3 3 2 3 2\n",
      +       "    n_steps          (chain, draw) int64 7 7 3 7 3 1 3 3 3 ... 7 7 3 3 7 7 3 7 3\n",
      +       "    acceptance_rate  (chain, draw) float64 1.0 0.9854 0.9968 ... 0.9882 0.9931\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:03.784254\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kwargs = {\n", + " \"adapt.run\": {\"num_steps\": 500},\n", + " \"num_chains\": 4,\n", + " \"num_draws\": 250,\n", + " \"num_adapt_draws\": 250\n", + "}\n", + "\n", + "blackjax_nuts_idata = model.fit(inference_method=\"blackjax_nuts\", **kwargs)\n", + "blackjax_nuts_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### NumPyro" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "sample: 100%|██████████| 1500/1500 [00:02<00:00, 551.76it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (chain: 8, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain      (chain) int64 0 1 2 3 4 5 6 7\n",
      +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      +       "Data variables:\n",
      +       "    Intercept  (chain, draw) float64 -0.02485 0.1376 -0.00766 ... 0.01202 0.0375\n",
      +       "    x          (chain, draw) float64 0.4336 0.4907 0.4996 ... 0.4032 0.3964\n",
      +       "    y_sigma    (chain, draw) float64 0.9225 1.015 0.9409 ... 0.8574 0.9083 0.822\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:07.292211\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    inference_library:           numpyro\n",
      +       "    inference_library_version:   0.13.2\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:          (chain: 8, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain            (chain) int64 0 1 2 3 4 5 6 7\n",
      +       "  * draw             (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999\n",
      +       "Data variables:\n",
      +       "    acceptance_rate  (chain, draw) float64 0.9973 0.6392 0.987 ... 0.9744 0.8087\n",
      +       "    step_size        (chain, draw) float64 0.7525 0.7525 ... 0.8295 0.8295\n",
      +       "    diverging        (chain, draw) bool False False False ... False False False\n",
      +       "    energy           (chain, draw) float64 140.6 143.8 141.9 ... 140.8 141.4\n",
      +       "    n_steps          (chain, draw) int64 3 3 3 3 1 1 3 7 7 ... 7 7 3 7 7 7 15 3\n",
      +       "    tree_depth       (chain, draw) int64 2 2 2 2 1 1 2 3 3 ... 2 3 3 2 3 3 3 4 2\n",
      +       "    lp               (chain, draw) float64 139.7 141.1 140.3 ... 139.4 141.0\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:07.316723\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    inference_library:           numpyro\n",
      +       "    inference_library_version:   0.13.2\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "numpyro_nuts_idata = model.fit(inference_method=\"numpyro_nuts\")\n", + "numpyro_nuts_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### flowMC" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No autotune found, use input sampler_params\n", + "Training normalizing flow\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Tuning global sampler: 100%|██████████| 5/5 [00:51<00:00, 10.23s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting Production run\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Production run: 100%|██████████| 5/5 [00:00<00:00, 9.38it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (chain: 20, draw: 500)\n",
      +       "Coordinates:\n",
      +       "  * chain      (chain) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19\n",
      +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "Data variables:\n",
      +       "    y_sigma    (chain, draw) float64 0.8082 1.024 1.024 ... 0.971 0.971 0.971\n",
      +       "    Intercept  (chain, draw) float64 0.09035 0.06867 0.06867 ... -0.1322 -0.1322\n",
      +       "    x          (chain, draw) float64 0.4452 0.503 0.503 ... 0.3238 0.3238 0.3238\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:59.802971\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "flowmc_idata = model.fit(inference_method=\"flowmc_realnvp_hmc\")\n", + "flowmc_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampler comparisons\n", + "\n", + "With ArviZ, we can compare the inference result summaries of the samplers. _Note:_ We can't use `az.compare` because not every inference data object includes the pointwise log-probabilities, so an error would be raised." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
y_sigma0.9450.0700.8191.0800.0020.0021044.0667.01.0
Intercept0.0180.089-0.1560.1850.0030.002844.0733.01.0
x0.3580.1050.1630.5540.0040.003829.0767.01.0
\n", + "
" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "y_sigma 0.945 0.070 0.819 1.080 0.002 0.002 1044.0 \n", + "Intercept 0.018 0.089 -0.156 0.185 0.003 0.002 844.0 \n", + "x 0.358 0.105 0.163 0.554 0.004 0.003 829.0 \n", + "\n", + " ess_tail r_hat \n", + "y_sigma 667.0 1.0 \n", + "Intercept 733.0 1.0 \n", + "x 767.0 1.0 " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(blackjax_nuts_idata)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
Intercept0.0220.097-0.1490.2170.0010.0017412.05758.01.0
x0.3590.1050.1590.5550.0010.0017406.05967.01.0
y_sigma0.9470.0690.8221.0790.0010.0017371.05405.01.0
\n", + "
" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "Intercept 0.022 0.097 -0.149 0.217 0.001 0.001 7412.0 \n", + "x 0.359 0.105 0.159 0.555 0.001 0.001 7406.0 \n", + "y_sigma 0.947 0.069 0.822 1.079 0.001 0.001 7371.0 \n", + "\n", + " ess_tail r_hat \n", + "Intercept 5758.0 1.0 \n", + "x 5967.0 1.0 \n", + "y_sigma 5405.0 1.0 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(numpyro_nuts_idata)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
y_sigma0.9460.0670.8251.0760.0010.0016260.05213.01.00
Intercept0.0130.093-0.1650.1900.0030.002924.01302.01.02
x0.3590.1030.1660.5560.0010.0015132.05790.01.00
\n", + "
" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "y_sigma 0.946 0.067 0.825 1.076 0.001 0.001 6260.0 \n", + "Intercept 0.013 0.093 -0.165 0.190 0.003 0.002 924.0 \n", + "x 0.359 0.103 0.166 0.556 0.001 0.001 5132.0 \n", + "\n", + " ess_tail r_hat \n", + "y_sigma 5213.0 1.00 \n", + "Intercept 1302.0 1.02 \n", + "x 5790.0 1.00 " + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(flowmc_idata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "Thanks to `bayeux`, we can use three different sampling backends and 10+ alternative MCMC methods in Bambi. Using these methods is as simple as passing the inference name to the `inference_method` of the `fit` method." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The watermark extension is already loaded. 
To reload it, use:\n", + " %reload_ext watermark\n", + "Last updated: Fri Mar 01 2024\n", + "\n", + "Python implementation: CPython\n", + "Python version : 3.11.7\n", + "IPython version : 8.21.0\n", + "\n", + "arviz : 0.17.0\n", + "bambi : 0.13.1.dev16+g9a1387a7.d20240204\n", + "bayeux : 0.1.9\n", + "numpy : 1.26.3\n", + "matplotlib: 3.8.2\n", + "pandas : 2.2.0\n", + "\n", + "Watermark: 2.4.3\n", + "\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%watermark -n -u -v -iv -w" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bayeux_bambi", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/notebooks/gallery.yml b/docs/notebooks/gallery.yml index e4bb954a8..cbce285e2 100644 --- a/docs/notebooks/gallery.yml +++ b/docs/notebooks/gallery.yml @@ -133,4 +133,10 @@ - title: Advanced interpret usage subtitle: Create data grids and compute complex quantities of interest href: interpret_advanced_usage.ipynb - thumbnail: thumbnails/advanced_interpret.png \ No newline at end of file + thumbnail: thumbnails/advanced_interpret.png +- category: Alternative sampling backends + description: "" + tiles: + - title: Using other samplers + subtitle: JAX based samplers + href: alternative_samplers.ipynb \ No newline at end of file From cdcf1041e6c655a93ac6cac733361098342866be Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Fri, 1 Mar 2024 16:17:01 +0100 Subject: [PATCH 17/34] tests for JAX based samplers except TFP --- tests/test_alternative_samplers.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/tests/test_alternative_samplers.py b/tests/test_alternative_samplers.py index 
816755648..a8cfc6805 100644 --- a/tests/test_alternative_samplers.py +++ b/tests/test_alternative_samplers.py @@ -1,10 +1,14 @@ import bambi as bmb +import bayeux as bx import numpy as np import pandas as pd import pytest +# Tensorflow probability based samplers do not work with Bambi models yet. +MCMC_METHODS = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__ if "tfp" not in getattr(bx.mcmc, k).name ] + @pytest.fixture(scope="module") def data_n100(): size = 100 @@ -52,27 +56,13 @@ def test_vi(): ) -@pytest.mark.parametrize( - "args", - [ - ("mcmc", {}), - ("numpyro_nuts", {"chain_method": "vectorized"}), - ("blackjax_nuts", {"chain_method": "vectorized"}), - ], -) -def test_logistic_regression_categoric_alternative_samplers(data_n100, args): +@pytest.mark.parametrize("sampler", MCMC_METHODS) +def test_logistic_regression_categoric_alternative_samplers(data_n100, sampler): model = bmb.Model("b1 ~ n1", data_n100, family="bernoulli") - model.fit(tune=50, draws=50, inference_method=args[0], **args[1]) + model.fit(inference_method=sampler) -@pytest.mark.parametrize( - "args", - [ - ("mcmc", {}), - ("numpyro_nuts", {"chain_method": "vectorized"}), - ("blackjax_nuts", {"chain_method": "vectorized"}), - ], -) -def test_regression_alternative_samplers(data_n100, args): +@pytest.mark.parametrize("sampler", MCMC_METHODS) +def test_regression_alternative_samplers(data_n100, sampler): model = bmb.Model("n1 ~ n2", data_n100) - model.fit(tune=50, draws=50, inference_method=args[0], **args[1]) + model.fit(inference_method=sampler) From bf1e478e6c1ca024cf8c7507bca846f1e4eeda7f Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Fri, 1 Mar 2024 17:00:43 +0100 Subject: [PATCH 18/34] add TFP backend example --- docs/notebooks/alternative_samplers.ipynb | 1777 ++++++++++++++++++--- 1 file changed, 1569 insertions(+), 208 deletions(-) diff --git a/docs/notebooks/alternative_samplers.ipynb b/docs/notebooks/alternative_samplers.ipynb index 3011fe966..164d40bee 100644 --- 
a/docs/notebooks/alternative_samplers.ipynb +++ b/docs/notebooks/alternative_samplers.ipynb @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -40,13 +40,16 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['blackjax_hmc',\n", + "['tfp_hmc',\n", + " 'tfp_nuts',\n", + " 'tfp_snaper_hmc',\n", + " 'blackjax_hmc',\n", " 'blackjax_chees_hmc',\n", " 'blackjax_meads_hmc',\n", " 'blackjax_nuts',\n", @@ -60,14 +63,14 @@ " 'numpyro_nuts']" ] }, - "execution_count": 5, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Tensorflow probability based methods are currently not supported\n", - "mcmc_methods = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__ if \"tfp\" not in getattr(bx.mcmc, k).name ]\n", + "mcmc_methods = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__]\n", "mcmc_methods" ] }, @@ -89,7 +92,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -110,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -3012,21 +3015,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### NumPyro" + "### Tensorflow probability" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 5, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "sample: 100%|██████████| 1500/1500 [00:02<00:00, 551.76it/s]\n" - ] - }, { "data": { "text/html": [ @@ -3038,8 +3034,8 @@ "
    \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -3412,62 +3408,60 @@ " * chain (chain) int64 0 1 2 3 4 5 6 7\n", " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n", "Data variables:\n", - " Intercept (chain, draw) float64 -0.02485 0.1376 -0.00766 ... 0.01202 0.0375\n", - " x (chain, draw) float64 0.4336 0.4907 0.4996 ... 0.4032 0.3964\n", - " y_sigma (chain, draw) float64 0.9225 1.015 0.9409 ... 0.8574 0.9083 0.822\n", + " y_sigma (chain, draw) float64 0.9946 0.8708 0.8651 ... 0.9908 0.9958\n", + " Intercept (chain, draw) float64 -0.09685 -0.01575 0.0419 ... 0.1091 0.1152\n", + " x (chain, draw) float64 0.4584 0.399 0.4485 ... 0.5167 0.4703\n", "Attributes:\n", - " created_at: 2024-03-01T14:57:07.292211\n", + " created_at: 2024-03-01T15:57:23.746257\n", " arviz_version: 0.17.0\n", - " inference_library: numpyro\n", - " inference_library_version: 0.13.2\n", " modeling_interface: bambi\n", - " modeling_interface_version: 0.13.1.dev16+g9a1387a7.d20240204
  • created_at :
    2024-03-01T15:57:23.746257
    arviz_version :
    0.17.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev16+g9a1387a7.d20240204

\n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -3840,89 +3834,81 @@ " * chain (chain) int64 0 1 2 3 4 5 6 7\n", " * draw (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999\n", "Data variables:\n", - " acceptance_rate (chain, draw) float64 0.9973 0.6392 0.987 ... 0.9744 0.8087\n", - " step_size (chain, draw) float64 0.7525 0.7525 ... 0.8295 0.8295\n", + " accept_ratio (chain, draw) float64 1.0 0.9889 0.9769 ... 1.0 0.9975 1.0\n", " diverging (chain, draw) bool False False False ... False False False\n", - " energy (chain, draw) float64 140.6 143.8 141.9 ... 140.8 141.4\n", - " n_steps (chain, draw) int64 3 3 3 3 1 1 3 7 7 ... 7 7 3 7 7 7 15 3\n", - " tree_depth (chain, draw) int64 2 2 2 2 1 1 2 3 3 ... 2 3 3 2 3 3 3 4 2\n", - " lp (chain, draw) float64 139.7 141.1 140.3 ... 139.4 141.0\n", + " is_accepted (chain, draw) bool True True True True ... True True True\n", + " n_steps (chain, draw) int32 7 3 7 7 3 7 3 7 7 ... 3 3 7 3 7 1 1 3 7\n", + " step_size (chain, draw) float64 0.5332 0.5332 0.5332 ... nan nan nan\n", + " target_log_prob (chain, draw) float64 -141.0 -139.9 ... -140.9 -140.5\n", + " tune (chain, draw) float64 0.0 0.0 0.0 0.0 ... nan nan nan nan\n", "Attributes:\n", - " created_at: 2024-03-01T14:57:07.316723\n", + " created_at: 2024-03-01T15:57:23.747950\n", " arviz_version: 0.17.0\n", - " inference_library: numpyro\n", - " inference_library_version: 0.13.2\n", " modeling_interface: bambi\n", - " modeling_interface_version: 0.13.1.dev16+g9a1387a7.d20240204
  • created_at :
    2024-03-01T15:57:23.747950
    arviz_version :
    0.17.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev16+g9a1387a7.d20240204

  • \n", " \n", " \n", " \n", @@ -4276,55 +4262,33 @@ "\t> sample_stats" ] }, - "execution_count": 11, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "numpyro_nuts_idata = model.fit(inference_method=\"numpyro_nuts\")\n", - "numpyro_nuts_idata" + "tfp_nuts_idata = model.fit(inference_method=\"tfp_nuts\")\n", + "tfp_nuts_idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### flowMC" + "### NumPyro" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "No autotune found, use input sampler_params\n", - "Training normalizing flow\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Tuning global sampler: 100%|██████████| 5/5 [00:51<00:00, 10.23s/it]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Starting Production run\n" - ] - }, { "name": "stderr", "output_type": "stream", "text": [ - "Production run: 100%|██████████| 5/5 [00:00<00:00, 9.38it/s]\n" + "sample: 100%|██████████| 1500/1500 [00:02<00:00, 551.76it/s]\n" ] }, { @@ -4338,8 +4302,8 @@ "
      \n", " \n", "
    • \n", - " \n", - " \n", + " \n", + " \n", "
      \n", "
      \n", "
        \n", @@ -4707,64 +4671,1364 @@ " fill: currentColor;\n", "}\n", "
        <xarray.Dataset>\n",
        -       "Dimensions:    (chain: 20, draw: 500)\n",
        +       "Dimensions:    (chain: 8, draw: 1000)\n",
                "Coordinates:\n",
        -       "  * chain      (chain) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19\n",
        -       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
        +       "  * chain      (chain) int64 0 1 2 3 4 5 6 7\n",
        +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
                "Data variables:\n",
        -       "    y_sigma    (chain, draw) float64 0.8082 1.024 1.024 ... 0.971 0.971 0.971\n",
        -       "    Intercept  (chain, draw) float64 0.09035 0.06867 0.06867 ... -0.1322 -0.1322\n",
        -       "    x          (chain, draw) float64 0.4452 0.503 0.503 ... 0.3238 0.3238 0.3238\n",
        +       "    Intercept  (chain, draw) float64 -0.02485 0.1376 -0.00766 ... 0.01202 0.0375\n",
        +       "    x          (chain, draw) float64 0.4336 0.4907 0.4996 ... 0.4032 0.3964\n",
        +       "    y_sigma    (chain, draw) float64 0.9225 1.015 0.9409 ... 0.8574 0.9083 0.822\n",
                "Attributes:\n",
        -       "    created_at:                  2024-03-01T14:57:59.802971\n",
        +       "    created_at:                  2024-03-01T14:57:07.292211\n",
                "    arviz_version:               0.17.0\n",
        +       "    inference_library:           numpyro\n",
        +       "    inference_library_version:   0.13.2\n",
                "    modeling_interface:          bambi\n",
        -       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204
    • created_at :
      2024-03-01T14:57:07.292211
      arviz_version :
      0.17.0
      inference_library :
      numpyro
      inference_library_version :
      0.13.2
      modeling_interface :
      bambi
      modeling_interface_version :
      0.13.1.dev16+g9a1387a7.d20240204

    \n", " \n", " \n", " \n", " \n", - " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:          (chain: 8, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain            (chain) int64 0 1 2 3 4 5 6 7\n",
      +       "  * draw             (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999\n",
      +       "Data variables:\n",
      +       "    acceptance_rate  (chain, draw) float64 0.9973 0.6392 0.987 ... 0.9744 0.8087\n",
      +       "    step_size        (chain, draw) float64 0.7525 0.7525 ... 0.8295 0.8295\n",
      +       "    diverging        (chain, draw) bool False False False ... False False False\n",
      +       "    energy           (chain, draw) float64 140.6 143.8 141.9 ... 140.8 141.4\n",
      +       "    n_steps          (chain, draw) int64 3 3 3 3 1 1 3 7 7 ... 7 7 3 7 7 7 15 3\n",
      +       "    tree_depth       (chain, draw) int64 2 2 2 2 1 1 2 3 3 ... 2 3 3 2 3 3 3 4 2\n",
      +       "    lp               (chain, draw) float64 139.7 141.1 140.3 ... 139.4 141.0\n",
      +       "Attributes:\n",
      +       "    created_at:                  2024-03-01T14:57:07.316723\n",
      +       "    arviz_version:               0.17.0\n",
      +       "    inference_library:           numpyro\n",
      +       "    inference_library_version:   0.13.2\n",
      +       "    modeling_interface:          bambi\n",
      +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + " \n", + " \n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "numpyro_nuts_idata = model.fit(inference_method=\"numpyro_nuts\")\n", + "numpyro_nuts_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### flowMC" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No autotune found, use input sampler_params\n", + "Training normalizing flow\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Tuning global sampler: 100%|██████████| 5/5 [00:51<00:00, 10.23s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting Production run\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Production run: 100%|██████████| 5/5 [00:00<00:00, 9.38it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
    \n", + "
    \n", + "
    arviz.InferenceData
    \n", + "
    \n", + "
      \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset>\n",
        +       "Dimensions:    (chain: 20, draw: 500)\n",
        +       "Coordinates:\n",
        +       "  * chain      (chain) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19\n",
        +       "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
        +       "Data variables:\n",
        +       "    y_sigma    (chain, draw) float64 0.8082 1.024 1.024 ... 0.971 0.971 0.971\n",
        +       "    Intercept  (chain, draw) float64 0.09035 0.06867 0.06867 ... -0.1322 -0.1322\n",
        +       "    x          (chain, draw) float64 0.4452 0.503 0.503 ... 0.3238 0.3238 0.3238\n",
        +       "Attributes:\n",
        +       "    created_at:                  2024-03-01T14:57:59.802971\n",
        +       "    arviz_version:               0.17.0\n",
        +       "    modeling_interface:          bambi\n",
        +       "    modeling_interface_version:  0.13.1.dev16+g9a1387a7.d20240204

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    \n", "
    \n", " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
    y_sigma0.9480.0670.8241.0730.0010.0018107.05585.01.0
    Intercept0.0250.095-0.1520.2000.0010.0016772.05624.01.0
    x0.3610.1040.1570.5510.0010.0016682.05414.01.0
    \n", + "" + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "y_sigma 0.948 0.067 0.824 1.073 0.001 0.001 8107.0 \n", + "Intercept 0.025 0.095 -0.152 0.200 0.001 0.001 6772.0 \n", + "x 0.361 0.104 0.157 0.551 0.001 0.001 6682.0 \n", + "\n", + " ess_tail r_hat \n", + "y_sigma 5585.0 1.0 \n", + "Intercept 5624.0 1.0 \n", + "x 5414.0 1.0 " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(tfp_nuts_idata)" + ] + }, { "cell_type": "code", "execution_count": 17, @@ -5439,15 +6802,13 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The watermark extension is already loaded. To reload it, use:\n", - " %reload_ext watermark\n", "Last updated: Fri Mar 01 2024\n", "\n", "Python implementation: CPython\n", @@ -5456,10 +6817,10 @@ "\n", "arviz : 0.17.0\n", "bambi : 0.13.1.dev16+g9a1387a7.d20240204\n", - "bayeux : 0.1.9\n", "numpy : 1.26.3\n", - "matplotlib: 3.8.2\n", "pandas : 2.2.0\n", + "bayeux : 0.1.9\n", + "matplotlib: 3.8.2\n", "\n", "Watermark: 2.4.3\n", "\n" From 27a41e61df4c4d0acdd00ad5972d68483c3f112b Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Fri, 1 Mar 2024 17:01:15 +0100 Subject: [PATCH 19/34] add TFP MCMC methods --- bambi/backend/pymc.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 8bb6439ac..991f7a354 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -220,7 +220,7 @@ def _run_mcmc( bx_sampler = operator.attrgetter(sampler_backend)( bx_model.mcmc ) # pylint: disable=no-member - idata = bx_sampler(seed=jax.random.key(random_seed), **kwargs) + idata = bx_sampler(seed=jax.random.PRNGKey(random_seed), **kwargs) idata_from = "bayeux" else: raise ValueError( @@ -432,13 +432,4 @@ def _get_bayeux_methods(): bx_modules.append(getattr(module, 
k).name) bx_methods[mname] = bx_modules - # TFP based methods do not work with Bambi models yet - tfp_mcmc = ["tfp_hmc", "tfp_nuts", "tfp_snaper_hmc"] - for method in tfp_mcmc: - bx_methods["mcmc"].remove(method) - - tfp_vi = ["tfp_factored_surrogate_posterior"] - for method in tfp_vi: - bx_methods["vi"].remove(method) - return bx_methods From 98f7da8860b7c9d330fa508c98e6ff15275f9689 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sun, 3 Mar 2024 08:46:48 +0100 Subject: [PATCH 20/34] don't use flowmc, chees, meads for categorical model --- tests/test_alternative_samplers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_alternative_samplers.py b/tests/test_alternative_samplers.py index a8cfc6805..6222f3df3 100644 --- a/tests/test_alternative_samplers.py +++ b/tests/test_alternative_samplers.py @@ -6,8 +6,9 @@ import pytest -# Tensorflow probability based samplers do not work with Bambi models yet. -MCMC_METHODS = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__ if "tfp" not in getattr(bx.mcmc, k).name ] +MCMC_METHODS = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__] +MCMC_METHODS_FILTERED = [i for i in MCMC_METHODS if not any(x in i for x in ("flowmc", "chees", "meads"))] + @pytest.fixture(scope="module") def data_n100(): @@ -55,8 +56,8 @@ def test_vi(): (mode_n.item(), std_n.item()), (mode_a.item(), std_a.item()), decimal=2 ) - -@pytest.mark.parametrize("sampler", MCMC_METHODS) +# +@pytest.mark.parametrize("sampler", MCMC_METHODS_FILTERED) def test_logistic_regression_categoric_alternative_samplers(data_n100, sampler): model = bmb.Model("b1 ~ n1", data_n100, family="bernoulli") model.fit(inference_method=sampler) From 4ae1092d0c1a1e61f2e309030573763fd6a3e8ef Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sun, 3 Mar 2024 08:47:28 +0100 Subject: [PATCH 21/34] call model.backend.inference_methods to show list of samplers --- docs/notebooks/alternative_samplers.ipynb | 194 +++++++++++++++++----- 1 file changed, 150 
insertions(+), 44 deletions(-) diff --git a/docs/notebooks/alternative_samplers.ipynb b/docs/notebooks/alternative_samplers.ipynb index 164d40bee..24d610d96 100644 --- a/docs/notebooks/alternative_samplers.ipynb +++ b/docs/notebooks/alternative_samplers.ipynb @@ -15,14 +15,13 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import bambi as bmb\n", "import bayeux as bx\n", - "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd" ] @@ -35,13 +34,154 @@ "\n", "Bambi leverages `bayeux` to access different sampling backends. In short, `bayeux` lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. \n", "\n", - "Since the underlying Bambi model is a PyMC model, this PyMC model can be \"given\" to `bayeux`. Then, we can choose from a variety of MCMC methods to perform inference. Below, the list of alternative MCMC methods to use in Bambi is shown." + "Since the underlying Bambi model is a PyMC model, this PyMC model can be \"given\" to `bayeux`. Then, we can choose from a variety of MCMC methods to perform inference. \n", + "\n", + "To demonstrate the available backends, we will first simulate data and build a model." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, + "outputs": [], + "source": [ + "num_samples = 100\n", + "num_features = 1\n", + "noise_std = 1.0\n", + "random_seed = 42\n", + "\n", + "np.random.seed(random_seed)\n", + "\n", + "coefficients = np.random.randn(num_features)\n", + "X = np.random.randn(num_samples, num_features)\n", + "error = np.random.normal(scale=noise_std, size=num_samples)\n", + "y = X @ coefficients + error\n", + "\n", + "data = pd.DataFrame({\"y\": y, \"x\": X.flatten()})" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "model = bmb.Model(\"y ~ x\", data)\n", + "model.build()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can call `model.backend.inference_methods` that returns a nested dictionary of the backends and list of inference methods." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'pymc': {'mcmc': ['mcmc'], 'vi': ['vi']},\n", + " 'bayeux': {'mcmc': ['tfp_hmc',\n", + " 'tfp_nuts',\n", + " 'tfp_snaper_hmc',\n", + " 'blackjax_hmc',\n", + " 'blackjax_chees_hmc',\n", + " 'blackjax_meads_hmc',\n", + " 'blackjax_nuts',\n", + " 'blackjax_hmc_pathfinder',\n", + " 'blackjax_nuts_pathfinder',\n", + " 'flowmc_rqspline_hmc',\n", + " 'flowmc_rqspline_mala',\n", + " 'flowmc_realnvp_hmc',\n", + " 'flowmc_realnvp_mala',\n", + " 'numpyro_hmc',\n", + " 'numpyro_nuts'],\n", + " 'optimize': ['jaxopt_bfgs',\n", + " 'jaxopt_gradient_descent',\n", + " 'jaxopt_lbfgs',\n", + " 'jaxopt_nonlinear_cg',\n", + " 'optimistix_bfgs',\n", + " 'optimistix_chord',\n", + " 'optimistix_dogleg',\n", + " 'optimistix_gauss_newton',\n", + " 'optimistix_indirect_levenberg_marquardt',\n", + " 'optimistix_levenberg_marquardt',\n", + " 'optimistix_nelder_mead',\n", + " 'optimistix_newton',\n", + " 'optimistix_nonlinear_cg',\n", + " 'optax_adabelief',\n", + " 'optax_adafactor',\n", + " 
'optax_adagrad',\n", + " 'optax_adam',\n", + " 'optax_adamw',\n", + " 'optax_adamax',\n", + " 'optax_amsgrad',\n", + " 'optax_fromage',\n", + " 'optax_lamb',\n", + " 'optax_lion',\n", + " 'optax_noisy_sgd',\n", + " 'optax_novograd',\n", + " 'optax_radam',\n", + " 'optax_rmsprop',\n", + " 'optax_sgd',\n", + " 'optax_sm3',\n", + " 'optax_yogi'],\n", + " 'vi': ['tfp_factored_surrogate_posterior']}}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "methods = model.backend.inference_methods\n", + "methods" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the PyMC backend, we have access to their implementation of the NUTS sampler and mean-field variational inference." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'mcmc': ['mcmc'], 'vi': ['vi']}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "methods[\"pymc\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`bayeux` lets us have access to Tensorflow probability, Blackjax, FlowMC, and NumPyro backends." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, "outputs": [ { "data": { @@ -63,68 +203,34 @@ " 'numpyro_nuts']" ] }, - "execution_count": 2, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Tensorflow probability based methods are currently not supported\n", - "mcmc_methods = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__]\n", - "mcmc_methods" + "methods[\"bayeux\"][\"mcmc\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "`bayeux` lets us have access to Blackjax, FlowMC, and NumPyro backends. In the section below, we will show how to use these backends in Bambi." 
+ "The values of the MCMC and VI keys in the dictionary are the names of the argument you would pass to `inference_method` in `model.fit`. This is shown in the section below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Specifying an `inference_method`\n", - "\n", - "First, we simulate some data to use in the examples." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "num_samples = 100\n", - "num_features = 1\n", - "noise_std = 1.0\n", - "random_seed = 42\n", - "\n", - "np.random.seed(random_seed)\n", - "\n", - "coefficients = np.random.randn(num_features)\n", - "X = np.random.randn(num_samples, num_features)\n", - "error = np.random.normal(scale=noise_std, size=num_samples)\n", - "y = X @ coefficients + error\n", - "\n", - "data = pd.DataFrame({\"y\": y, \"x\": X.flatten()})" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "model = bmb.Model(\"y ~ x\", data)" + "## Specifying an `inference_method`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To use a different backend, we pass the name of the `bayeux` MCMC inference method to the `inference_method` parameter of the `fit` method." + "By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass the name of the `bayeux` MCMC method to the `inference_method` parameter of the `fit` method." 
 ] }, { @@ -1405,7 +1511,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -1430,7 +1536,7 @@ " 'return_pytree': False}}" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } From 81936a2418ba61b6fdfd40cfc81120603c519d67 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sun, 3 Mar 2024 08:47:47 +0100 Subject: [PATCH 22/34] docstring changes --- bambi/models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/bambi/models.py b/bambi/models.py index 5dac88881..9b5421d51 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -266,9 +266,9 @@ def fit( using the ``fit`` function. Finally, ``"laplace"``, in which case a Laplace approximation is used and is not recommended other than for pedagogical use. - To use the PyMC numpyro and blackjax samplers, use ``numpyro_nuts`` or ``blackjax_nuts`` - respectively. Both methods will only work if you can use NUTS sampling, so your model - must be differentiable. + To get a list of JAX based inference methods, call + ``model.backend.inference_methods['bayeux']``. This will return a dictionary of the + available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others. init : str Initialization method. Defaults to ``"auto"``. The available methods are: * auto: Use ``"jitter+adapt_diag"`` and if this method fails it uses ``"adapt_diag"``. @@ -306,7 +306,8 @@ def fit( Returns ------- An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default), - "numpyro_nuts", "blackjax_nuts" or "laplace". + "laplace", or one of the MCMC methods in + ``model.backend.inference_methods['bayeux']['mcmc']``. An ``Approximation`` object if ``"vi"``. 
""" method = kwargs.pop("method", None) From f6d8894715c5f624fb8a6f91a8d752058b9e5fa2 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sun, 3 Mar 2024 08:48:19 +0100 Subject: [PATCH 23/34] inference_methods attribute and change JAX random seed --- bambi/backend/pymc.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 991f7a354..2cd2441ec 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -47,7 +47,7 @@ def __init__(self): self.spec = None self.components = {} self.bayeux_methods = _get_bayeux_methods() - self.pymc_methods = {"mcmc": ["mcmc"]} + self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]} def build(self, spec): """Compile the PyMC model from an abstract model specification. @@ -112,7 +112,7 @@ def run( inference_method, **kwargs, ) - elif inference_method == "vi": + elif inference_method in self.pymc_methods["vi"]: result = self._run_vi(**kwargs) elif inference_method == "laplace": result = self._run_laplace(draws, omit_offsets, include_mean) @@ -212,15 +212,19 @@ def _run_mcmc( import bayeux as bx # pylint: disable=import-outside-toplevel import jax # pylint: disable=import-outside-toplevel - # Seed is required for bayeux - if random_seed is None: - random_seed = 0 + # Set the seed for reproducibility if provided + if random_seed is not None: + if not isinstance(random_seed, int): + random_seed = random_seed[0] + np.random.seed(random_seed) + + jax_seed = jax.random.PRNGKey(np.random.randint(2**32 - 1)) bx_model = bx.Model.from_pymc(self.model) bx_sampler = operator.attrgetter(sampler_backend)( - bx_model.mcmc - ) # pylint: disable=no-member - idata = bx_sampler(seed=jax.random.PRNGKey(random_seed), **kwargs) + bx_model.mcmc # pylint: disable=no-member + ) + idata = bx_sampler(seed=jax_seed, **kwargs) idata_from = "bayeux" else: raise ValueError( @@ -317,8 +321,8 @@ def _run_laplace(self, draws, omit_offsets, include_mean): Mainly 
for pedagogical use, provides reasonable results for approximately Gaussian posteriors. The approximation can be very poor for some models - like hierarchical ones. Use ``mcmc``, ``numpyro_nuts``, ``blackjax_nuts`` - or ``vi`` for better approximations. + like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods + for better approximations. Parameters ---------- @@ -367,6 +371,10 @@ def constant_components(self): def distributional_components(self): return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)} + @property + def inference_methods(self): + return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods} + def _posterior_samples_to_idata(samples, model): """Create InferenceData from samples. @@ -424,12 +432,5 @@ def _get_bayeux_methods(): import bayeux as bx # pylint: disable=import-outside-toplevel - bx_methods = {} - for module in bx._src.bayeux._MODULES: # pylint: disable=protected-access - mname = module.__name__.rsplit(".", maxsplit=1)[-1] - bx_modules = [] - for k in module.__all__: - bx_modules.append(getattr(module, k).name) - bx_methods[mname] = bx_modules - - return bx_methods + # Dummy log density to get access to all methods + return bx.Model(lambda x: -(x**2), 0.0).methods From 02d1df6737e53ba5085c796755590a3059677722 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Mon, 4 Mar 2024 18:52:36 +0100 Subject: [PATCH 24/34] Add FutureWarning to inference_method parameter --- bambi/backend/pymc.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 2cd2441ec..8a8922ea1 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -3,6 +3,7 @@ import logging import operator import traceback +import warnings from copy import deepcopy from importlib.metadata import version @@ -96,6 +97,14 @@ def run( ): """Run PyMC sampler.""" inference_method = inference_method.lower() + + if inference_method == "nuts_numpyro": + 
inference_method = "numpyro_nuts" + warnings.warn("'nuts_numpyro' has been replaced by 'numpyro_nuts' and will be removed in a future release", category=FutureWarning) + elif inference_method == "nuts_blackjax": + inference_method = "blackjax_nuts" + warnings.warn("'nuts_blackjax' has been replaced by 'blackjax_nuts' and will be removed in a future release", category=FutureWarning) + # NOTE: Methods return different types of objects (idata, approximation, and dictionary) if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]): result = self._run_mcmc( @@ -426,9 +435,8 @@ def _get_bayeux_methods(): A dict where the keys are the module names and the values are the methods available in that module. """ - bx_methods = {} if importlib.util.find_spec("bayeux") is None: - return bx_methods + return {"mcmc": []} import bayeux as bx # pylint: disable=import-outside-toplevel From dd278d42a783e5cc7f6c9bba0302423d3f0d82ec Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Mon, 4 Mar 2024 19:01:29 +0100 Subject: [PATCH 25/34] black formatting and resolve pylint errors --- bambi/backend/pymc.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 8a8922ea1..82b646ebe 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -97,13 +97,21 @@ def run( ): """Run PyMC sampler.""" inference_method = inference_method.lower() - + if inference_method == "nuts_numpyro": inference_method = "numpyro_nuts" - warnings.warn("'nuts_numpyro' has been replaced by 'numpyro_nuts' and will be removed in a future release", category=FutureWarning) + warnings.warn( + "'nuts_numpyro' has been replaced by 'numpyro_nuts' and will be " + "removed in a future release", + category=FutureWarning, + ) elif inference_method == "nuts_blackjax": inference_method = "blackjax_nuts" - warnings.warn("'nuts_blackjax' has been replaced by 'blackjax_nuts' and will be removed in a future release", 
category=FutureWarning) + warnings.warn( + "'nuts_blackjax' has been replaced by 'blackjax_nuts' and will " + "be removed in a future release", + category=FutureWarning, + ) # NOTE: Methods return different types of objects (idata, approximation, and dictionary) if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]): From b0e94a4a904fc3e438ee3d9ea79afdf2f1a7a79d Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Mon, 4 Mar 2024 19:15:54 +0100 Subject: [PATCH 26/34] fix package name --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 21d977e54..5825e9b40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dev = [ "seaborn>=0.9.0", ] jax = [ - "bayeux>=0.1.9", + "bayeux-ml>=0.1.9", "blackjax>=1.0.0", "jax>=0.3.1", "jaxlib>=0.3.1", From 65fd9459668530960cf59997a8f51c6a90de2096 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Tue, 19 Mar 2024 18:26:47 +0100 Subject: [PATCH 27/34] drop 3.9 and add 3.12 to testing matrix --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 403d5be77..574bcd7fd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] name: Set up Python ${{ matrix.python-version }} steps: From 4712f1a1589e3045fd0cc90b9e475db0b33889fc Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Tue, 19 Mar 2024 18:28:11 +0100 Subject: [PATCH 28/34] change Python versions in requires-python and target-version --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5825e9b40..5cd9c1140 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires = ["setuptools>=61.0", "setuptools_scm>=8"] [project] name = 
"bambi" description = "BAyesian Model Building Interface in Python" -requires-python = ">=3.8" +requires-python = ">=3.10" readme = "README.md" license = {file = "LICENSE"} dynamic = ["version"] @@ -63,4 +63,4 @@ packages = [ [tool.black] line-length = 100 -target-version = ["py39", "py310"] \ No newline at end of file +target-version = ["py310", "py311"] \ No newline at end of file From d508214435dd9f10ecc1f3636186659bbf03dee3 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Tue, 19 Mar 2024 18:51:18 +0100 Subject: [PATCH 29/34] remove python 3.11 black target-version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5cd9c1140..c15a72e03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,4 +63,4 @@ packages = [ [tool.black] line-length = 100 -target-version = ["py310", "py311"] \ No newline at end of file +target-version = ["py310"] \ No newline at end of file From 1d05684dbcd0d7d1bd9260170c4c20fb182e90fb Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Tue, 19 Mar 2024 19:35:33 +0100 Subject: [PATCH 30/34] pin requires-python to <3.13 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c15a72e03..d23173bf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires = ["setuptools>=61.0", "setuptools_scm>=8"] [project] name = "bambi" description = "BAyesian Model Building Interface in Python" -requires-python = ">=3.10" +requires-python = ">=3.10,<3.13" readme = "README.md" license = {file = "LICENSE"} dynamic = ["version"] From f06715e83deea977cddb3ae0b9140aa6cc5a9fab Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Tue, 19 Mar 2024 21:07:56 +0100 Subject: [PATCH 31/34] pip upgrade setuptools --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 574bcd7fd..2894c4b2e 100644 --- a/.github/workflows/test.yml +++ 
b/.github/workflows/test.yml @@ -38,6 +38,7 @@ jobs: run: | conda install -c conda-forge python-graphviz conda install pip + pip install --upgrade setuptools # TODO: make conditional on Python version pip install . pip install .[dev] pip install .[jax] From ef575d370e9548ec2dd0a7a96009d99ed5e8124e Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Thu, 28 Mar 2024 18:34:39 -0300 Subject: [PATCH 32/34] Bump PyMC to 5.12 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d23173bf8..6de7a8eae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "formulae>=0.5.3", "graphviz", "pandas>=1.0.0", - "pymc>=5.5.0", + "pymc>=5.12.0", ] [project.optional-dependencies] From 9bf90a64674f83153c437882708dcc25a5cabeef Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Thu, 28 Mar 2024 18:54:13 -0300 Subject: [PATCH 33/34] Upgrade black and pylint --- bambi/data/__init__.py | 1 + bambi/data/datasets.py | 1 + bambi/defaults/__init__.py | 1 + bambi/families/__init__.py | 1 + bambi/interpret/utils.py | 1 + bambi/priors/__init__.py | 1 + bambi/terms/base.py | 18 ++++++------------ bambi/transformations.py | 1 + pyproject.toml | 4 ++-- 9 files changed, 15 insertions(+), 14 deletions(-) diff --git a/bambi/data/__init__.py b/bambi/data/__init__.py index 38adba619..1f6fb200c 100644 --- a/bambi/data/__init__.py +++ b/bambi/data/__init__.py @@ -1,4 +1,5 @@ """Code for loading datasets.""" + from .datasets import clear_data_home, load_data __all__ = ["clear_data_home", "load_data"] diff --git a/bambi/data/datasets.py b/bambi/data/datasets.py index dc063cb18..0be57e9ba 100644 --- a/bambi/data/datasets.py +++ b/bambi/data/datasets.py @@ -1,4 +1,5 @@ """Base IO code for datasets. 
Heavily influenced by Arviz's (and scikit-learn's) implementation.""" + import hashlib import itertools import os diff --git a/bambi/defaults/__init__.py b/bambi/defaults/__init__.py index 3dedec422..53ef82498 100644 --- a/bambi/defaults/__init__.py +++ b/bambi/defaults/__init__.py @@ -1,4 +1,5 @@ """Settings for default priors, families, etc. in Bambi.""" + from bambi.defaults.utils import get_default_prior from bambi.defaults.families import get_builtin_family diff --git a/bambi/families/__init__.py b/bambi/families/__init__.py index df1a84a5a..645d27bee 100644 --- a/bambi/families/__init__.py +++ b/bambi/families/__init__.py @@ -1,4 +1,5 @@ """Classes to construct model families.""" + from bambi.families.family import Family from bambi.families.likelihood import Likelihood from bambi.families.link import Link diff --git a/bambi/interpret/utils.py b/bambi/interpret/utils.py index 47bc02864..cbb7bde19 100644 --- a/bambi/interpret/utils.py +++ b/bambi/interpret/utils.py @@ -102,6 +102,7 @@ def set_default_variable_values(self) -> np.ndarray: If categoric dtype the returned value is the unique levels of `variable'. """ + values = None # Otherwise pylint complains terms = get_model_terms(self.model) # get default values for each variable in the model for term in terms.values(): diff --git a/bambi/priors/__init__.py b/bambi/priors/__init__.py index c90e68945..6884486a6 100644 --- a/bambi/priors/__init__.py +++ b/bambi/priors/__init__.py @@ -1,4 +1,5 @@ """Classes to represent prior distributions and methods to set automatic priors""" + from .prior import Prior from .scaler import PriorScaler diff --git a/bambi/terms/base.py b/bambi/terms/base.py index 81fb77a2a..c11f55bc6 100644 --- a/bambi/terms/base.py +++ b/bambi/terms/base.py @@ -13,33 +13,27 @@ class BaseTerm(ABC): @property @abstractmethod - def term(self): - ... + def term(self): ... @property @abstractmethod - def data(self): - ... + def data(self): ... @property @abstractmethod - def name(self): - ... 
+ def name(self): ... @property @abstractmethod - def shape(self): - ... + def shape(self): ... @property @abstractmethod - def levels(self): - ... + def levels(self): ... @property @abstractmethod - def categorical(self): - ... + def categorical(self): ... @property def alias(self): diff --git a/bambi/transformations.py b/bambi/transformations.py index 9442b0a15..eb226ed43 100644 --- a/bambi/transformations.py +++ b/bambi/transformations.py @@ -175,6 +175,7 @@ def weighted(x, weights): weighted.__metadata__ = {"kind": "weighted"} + # pylint: disable = invalid-name @register_stateful_transform class HSGP: # pylint: disable = too-many-instance-attributes diff --git a/pyproject.toml b/pyproject.toml index 6de7a8eae..262482de1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,10 +26,10 @@ dependencies = [ [project.optional-dependencies] dev = [ - "black==22.3.0", + "black==24.3.0", "ipython>=5.8.0,!=8.7.0", "pre-commit>=2.19", - "pylint==2.17.5", + "pylint==3.1.0", "pytest-cov>=2.6.1", "pytest>=4.4.0", "quartodoc==0.6.1", From 9f9d7698e443da7ba53303c02842f8ddb244d1e6 Mon Sep 17 00:00:00 2001 From: Gabriel Stechschulte <63432018+GStechschulte@users.noreply.github.com> Date: Fri, 29 Mar 2024 06:17:50 +0100 Subject: [PATCH 34/34] remove upgrading of setup tools --- .github/workflows/test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2894c4b2e..574bcd7fd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,7 +38,6 @@ jobs: run: | conda install -c conda-forge python-graphviz conda install pip - pip install --upgrade setuptools # TODO: make conditional on Python version pip install . pip install .[dev] pip install .[jax]