Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Pass kwargs to nutpie + create env.yml file #855

Merged
merged 8 commits into from
Dec 21, 2024
71 changes: 55 additions & 16 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,26 @@
import operator
import traceback
import warnings

from copy import deepcopy
from importlib.metadata import version

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr

from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_observations
from pymc.util import get_default_varnames
from pytensor.tensor.special import softmax

from bambi.backend.inference_methods import inference_methods
from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
from bambi.backend.links import (
arctan_2,
cloglog,
identity,
inverse_squared,
logit,
probit,
)
from bambi.backend.model_components import (
ConstantComponent,
DistributionalComponent,
Expand Down Expand Up @@ -127,7 +132,9 @@ def run(
)

# NOTE: Methods return different types of objects (idata, approximation, and dictionary)
if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]):
if inference_method in (
self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]
):
result = self._run_mcmc(
draws,
tune,
Expand All @@ -147,7 +154,9 @@ def run(
elif inference_method == "laplace":
result = self._run_laplace(draws, omit_offsets, include_response_params)
else:
raise NotImplementedError(f"'{inference_method}' method has not been implemented")
raise NotImplementedError(
f"'{inference_method}' method has not been implemented"
)

self.fit = True
return result
Expand Down Expand Up @@ -258,15 +267,24 @@ def _run_mcmc(
bx_sampler = operator.attrgetter(sampler_backend)(
bx_model.mcmc # pylint: disable=no-member
)
idata = bx_sampler(seed=jax_seed, **kwargs)
idata = bx_sampler(
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
seed=jax_seed,
draws=draws,
tune=tune,
chains=chains,
cores=cores,
**kwargs,
)
idata_from = "bayeux"
else:
raise ValueError(
f"sampler_backend value {sampler_backend} is not valid. Please choose one of"
f" {self.pymc_methods['mcmc'] + self.bayeux_methods['mcmc']}"
)

idata = self._clean_results(idata, omit_offsets, include_response_params, idata_from)
idata = self._clean_results(
idata, omit_offsets, include_response_params, idata_from
)
return idata

def _clean_results(self, idata, omit_offsets, include_response_params, idata_from):
Expand Down Expand Up @@ -300,7 +318,9 @@ def _clean_results(self, idata, omit_offsets, include_response_params, idata_fro
getattr(idata, group).attrs["modeling_interface_version"] = __version__

if omit_offsets:
offset_vars = [var for var in idata.posterior.data_vars if var.endswith("_offset")]
offset_vars = [
var for var in idata.posterior.data_vars if var.endswith("_offset")
]
idata.posterior = idata.posterior.drop_vars(offset_vars)

dims_original = list(self.model.coords)
Expand Down Expand Up @@ -346,7 +366,9 @@ def _clean_results(self, idata, omit_offsets, include_response_params, idata_fro
dims += tuple(response_coords)

posterior = idata.posterior.stack(samples=dims)
coefs = np.vstack([np.atleast_2d(posterior[name].values) for name in common_terms])
coefs = np.vstack(
[np.atleast_2d(posterior[name].values) for name in common_terms]
)
name = get_aliased_name(bambi_component.intercept_term)
center_factor = np.dot(X.mean(0), coefs).reshape(shape)
idata.posterior[name] = idata.posterior[name] - center_factor
Expand Down Expand Up @@ -409,16 +431,24 @@ def _run_laplace(self, draws, omit_offsets, include_response_params):
samples = np.random.multivariate_normal(modes, cov, size=draws)

idata = _posterior_samples_to_idata(samples, self.model)
idata = self._clean_results(idata, omit_offsets, include_response_params, idata_from="pymc")
idata = self._clean_results(
idata, omit_offsets, include_response_params, idata_from="pymc"
)
return idata

@property
def constant_components(self):
return {k: v for k, v in self.components.items() if isinstance(v, ConstantComponent)}
return {
k: v for k, v in self.components.items() if isinstance(v, ConstantComponent)
}

@property
def distributional_components(self):
return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)}
return {
k: v
for k, v in self.components.items()
if isinstance(v, DistributionalComponent)
}


def _posterior_samples_to_idata(samples, model):
Expand Down Expand Up @@ -486,21 +516,30 @@ def create_posterior_bayeux(posterior, pm_model):
data_vars_dims = {}
for data_var_name in data_vars_names:
if data_var_name in vars_to_dims:
data_vars_dims[data_var_name] = ["chain", "draw"] + list(vars_to_dims[data_var_name])
data_vars_dims[data_var_name] = ["chain", "draw"] + list(
vars_to_dims[data_var_name]
)
else:
data_vars_dims[data_var_name] = ["chain", "draw"]

# Create dictionary with data var dims and values (as required by xr.Dataset)
# https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html
data_vars_values = {}
for data_var_name, data_var_dims in data_vars_dims.items():
data_vars_values[data_var_name] = (data_var_dims, posterior[data_var_name].to_numpy())
data_vars_values[data_var_name] = (
data_var_dims,
posterior[data_var_name].to_numpy(),
)

# Get coords
dims_in_use = set(dim for dims in data_vars_dims.values() for dim in dims)
coords_in_use = {coord_name: np.array(coords[coord_name]) for coord_name in dims_in_use}
coords_in_use = {
coord_name: np.array(coords[coord_name]) for coord_name in dims_in_use
}

return xr.Dataset(data_vars=data_vars_values, coords=coords_in_use, attrs=posterior.attrs)
return xr.Dataset(
data_vars=data_vars_values, coords=coords_in_use, attrs=posterior.attrs
)


def create_observed_data_bayeux(pm_model):
Expand Down
23 changes: 23 additions & 0 deletions env-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Conda environment specification for the "bambi-env" environment (env-dev.yml).
# Lists the runtime dependencies first, then the development tooling, and
# finally the packages that are installed through pip.
name: bambi-env
channels:
  - conda-forge
  - defaults
dependencies:
  - python>=3.10,<3.13
  - arviz>=0.12.0
  - formulae>=0.5.3
  - graphviz
  - pandas>=1.0.0
  - pymc>=5.16.1
  # Dev dependencies
  - black=24.3.0
  - ipython>=5.8.0,!=8.7.0
  - pre-commit>=2.19
  - pylint=3.1.0
  - pytest-cov>=2.6.1
  - pytest>=4.4.0
  - seaborn>=0.9.0
  # `pip` itself must be listed so the nested `pip:` section below can be used.
  - pip
  # Packages installed with pip; they must stay nested under the `pip:` key,
  # otherwise they are read as conda packages.
  - pip:
      - quartodoc==0.6.1
      - bayeux-ml>=0.1.13  # Optional JAX dependency