Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass kwargs to nutpie + create env.yml file #855

Merged
merged 8 commits into from
Dec 21, 2024
43 changes: 36 additions & 7 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,26 @@
import operator
import traceback
import warnings

from copy import deepcopy
from importlib.metadata import version

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr

from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_observations
from pymc.util import get_default_varnames
from pytensor.tensor.special import softmax

from bambi.backend.inference_methods import inference_methods
from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
from bambi.backend.links import (
arctan_2,
cloglog,
identity,
inverse_squared,
logit,
probit,
)
from bambi.backend.model_components import (
ConstantComponent,
DistributionalComponent,
Expand Down Expand Up @@ -246,6 +251,17 @@ def _run_mcmc(
import bayeux as bx # pylint: disable=import-outside-toplevel
import jax # pylint: disable=import-outside-toplevel

# pylint: disable=import-outside-toplevel
from pymc.sampling.parallel import (
_cpu_count,
)

# handle case where cores and chains are not provided
if cores is None:
cores = min(4, _cpu_count())
if chains is None:
chains = max(2, cores)

# Set the seed for reproducibility if provided
if random_seed is not None:
if not isinstance(random_seed, int):
Expand All @@ -255,10 +271,20 @@ def _run_mcmc(
jax_seed = jax.random.PRNGKey(np.random.randint(2**31 - 1))

bx_model = bx.Model.from_pymc(self.model)
bx_sampler = operator.attrgetter(sampler_backend)(
bx_model.mcmc # pylint: disable=no-member
# pylint: disable=no-member
bx_sampler = operator.attrgetter(sampler_backend)(bx_model.mcmc)

# We pass 'draws', 'tune', 'chains', and 'cores' because they can be used by some
# samplers. Since those are keyword arguments of `Model.fit()`, they would not
# be passed in the `kwargs` dict.
idata = bx_sampler(
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
seed=jax_seed,
draws=draws,
tune=tune,
chains=chains,
cores=cores,
**kwargs,
)
idata = bx_sampler(seed=jax_seed, **kwargs)
idata_from = "bayeux"
else:
raise ValueError(
Expand Down Expand Up @@ -494,7 +520,10 @@ def create_posterior_bayeux(posterior, pm_model):
# https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html
data_vars_values = {}
for data_var_name, data_var_dims in data_vars_dims.items():
data_vars_values[data_var_name] = (data_var_dims, posterior[data_var_name].to_numpy())
data_vars_values[data_var_name] = (
data_var_dims,
posterior[data_var_name].to_numpy(),
)

# Get coords
dims_in_use = set(dim for dims in data_vars_dims.values() for dim in dims)
Expand Down
27 changes: 27 additions & 0 deletions conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: bambi-env
channels:
- conda-forge
- defaults
dependencies:
- python>=3.10,<3.13
- arviz>=0.12.0
- formulae>=0.5.3
- graphviz
- pandas>=1.0.0
- pymc>=5.16.1
# Dev dependencies
- black=24.3.0
- ipython>=5.8.0,!=8.7.0
- pre-commit>=2.19
- pylint=3.1.0
- pytest-cov>=2.6.1
- pytest>=4.4.0
- seaborn>=0.9.0
- pip
- watermark
- pip:
- quartodoc==0.6.1
- bayeux-ml==0.1.14 # Optional JAX dependency
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
- blackjax==1.2.3
- jax==0.4.33
- jaxlib==0.4.33
Loading
Loading