diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 7e879917..311a8b94 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -3,7 +3,6 @@ import operator import traceback import warnings - from copy import deepcopy from importlib.metadata import version @@ -11,13 +10,19 @@ import pymc as pm import pytensor.tensor as pt import xarray as xr - from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_observations from pymc.util import get_default_varnames from pytensor.tensor.special import softmax from bambi.backend.inference_methods import inference_methods -from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2 +from bambi.backend.links import ( + arctan_2, + cloglog, + identity, + inverse_squared, + logit, + probit, +) from bambi.backend.model_components import ( ConstantComponent, DistributionalComponent, @@ -246,6 +251,17 @@ def _run_mcmc( import bayeux as bx # pylint: disable=import-outside-toplevel import jax # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from pymc.sampling.parallel import ( + _cpu_count, + ) + + # handle case where cores and chains are not provided + if cores is None: + cores = min(4, _cpu_count()) + if chains is None: + chains = max(2, cores) + # Set the seed for reproducibility if provided if random_seed is not None: if not isinstance(random_seed, int): @@ -255,10 +271,20 @@ def _run_mcmc( jax_seed = jax.random.PRNGKey(np.random.randint(2**31 - 1)) bx_model = bx.Model.from_pymc(self.model) - bx_sampler = operator.attrgetter(sampler_backend)( - bx_model.mcmc # pylint: disable=no-member + # pylint: disable=no-member + bx_sampler = operator.attrgetter(sampler_backend)(bx_model.mcmc) + + # We pass 'draws', 'tune', 'chains', and 'cores' because they can be used by some + # samplers. Since those are keyword arguments of `Model.fit()`, they would not + # be passed in the `kwargs` dict. + idata = bx_sampler( + seed=jax_seed, + draws=draws, + tune=tune, + chains=chains, + cores=cores, + **kwargs, ) - idata = bx_sampler(seed=jax_seed, **kwargs) idata_from = "bayeux" else: raise ValueError( @@ -494,7 +520,10 @@ def create_posterior_bayeux(posterior, pm_model): # https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html data_vars_values = {} for data_var_name, data_var_dims in data_vars_dims.items(): - data_vars_values[data_var_name] = (data_var_dims, posterior[data_var_name].to_numpy()) + data_vars_values[data_var_name] = ( + data_var_dims, + posterior[data_var_name].to_numpy(), + ) # Get coords dims_in_use = set(dim for dims in data_vars_dims.values() for dim in dims) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml new file mode 100644 index 00000000..b0aa377c --- /dev/null +++ b/conda-envs/environment-dev.yml @@ -0,0 +1,24 @@ +name: bambi-env +channels: + - conda-forge + - defaults +dependencies: + - python>=3.10,<3.13 + - arviz>=0.12.0 + - formulae>=0.5.3 + - graphviz + - pandas>=1.0.0 + - pymc>=5.16.1 + # Dev dependencies + - black=24.3.0 + - ipython>=5.8.0,!=8.7.0 + - pre-commit>=2.19 + - pylint=3.1.0 + - pytest-cov>=2.6.1 + - pytest>=4.4.0 + - seaborn>=0.9.0 + - pip + - watermark + - pip: + - quartodoc==0.6.1 + - bayeux-ml==0.1.15 # Optional JAX dependency diff --git a/docs/notebooks/alternative_samplers.ipynb b/docs/notebooks/alternative_samplers.ipynb index bfe314d8..88728c41 100644 --- a/docs/notebooks/alternative_samplers.ipynb +++ b/docs/notebooks/alternative_samplers.ipynb @@ -21,7 +21,6 @@ "source": [ "import arviz as az\n", "import bambi as bmb\n", - "import bayeux as bx\n", "import numpy as np\n", "import pandas as pd" ] @@ -30,11 +29,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## bayeux\n", + "## Bayeux\n", "\n", - "Bambi leverages `bayeux` to access different sampling backends. In short, `bayeux` lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. \n", + "Bambi leverages [`bayeux`](https://jax-ml.github.io/bayeux/) to access different sampling backends. In short, `bayeux` lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods.\n", "\n", - "Since the underlying Bambi model is a PyMC model, this PyMC model can be \"given\" to `bayeux`. Then, we can choose from a variety of MCMC methods to perform inference. \n", + "Since the underlying Bambi model is a PyMC model, this PyMC model can be \"given\" to `bayeux`. Then, we can choose from a variety of MCMC methods to perform inference.\n", "\n", "To demonstrate the available backends, we will fist simulate data and build a model." ] @@ -50,11 +49,11 @@ "noise_std = 1.0\n", "random_seed = 42\n", "\n", - "np.random.seed(random_seed)\n", + "rng = np.random.default_rng(random_seed)\n", "\n", - "coefficients = np.random.randn(num_features)\n", - "X = np.random.randn(num_samples, num_features)\n", - "error = np.random.normal(scale=noise_std, size=num_samples)\n", + "coefficients = rng.normal(size=num_features)\n", + "X = rng.normal(size=(num_samples, num_features))\n", + "error = rng.normal(scale=noise_std, size=num_samples)\n", "y = X @ coefficients + error\n", "\n", "data = pd.DataFrame({\"y\": y, \"x\": X.flatten()})" @@ -100,7 +99,8 @@ " 'flowmc_realnvp_hmc',\n", " 'flowmc_realnvp_mala',\n", " 'numpyro_hmc',\n", - " 'numpyro_nuts']}}" + " 'numpyro_nuts',\n", + " 'nutpie']}}" ] }, "execution_count": 4, @@ -169,7 +169,8 @@ " 'flowmc_realnvp_hmc',\n", " 'flowmc_realnvp_mala',\n", " 'numpyro_hmc',\n", - " 'numpyro_nuts']}" + " 'numpyro_nuts',\n", + " 'nutpie']}" ] }, "execution_count": 6, @@ -214,6 +215,13 @@ "execution_count": 7, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:2024-12-21 13:43:24,702:jax._src.xla_bridge:969: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + }, { "data": { "text/html": [ @@ -225,8 +233,8 @@ "
\n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-12-21T16:43:32.791248+00:00
    arviz_version :
    0.19.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.14.1.dev17+g25798ce7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-12-21T16:43:32.789194+00:00
    arviz_version :
    0.19.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.14.1.dev17+g25798ce7

  • \n", " \n", " \n", " \n", @@ -1902,14 +1928,15 @@ { "data": { "text/plain": [ - "{ blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': .wrap_log_density..wrapped(args)>,\n", + "{, integrator=.euclidean_integrator at 0x7f164c15c680>, **extra_parameters) -> blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': .wrap_log_density..wrapped(args)>,\n", " 'is_mass_matrix_diagonal': True,\n", " 'initial_step_size': 1.0,\n", " 'target_acceptance_rate': 0.8,\n", " 'progress_bar': False,\n", - " 'algorithm': GenerateSamplingAPI(differentiable=, init=, build_kernel=)},\n", + " 'adaptation_info_fn': ,\n", + " 'algorithm': GenerateSamplingAPI(differentiable=, init=, build_kernel=)},\n", " 'adapt.run': {'num_steps': 500},\n", - " .euclidean_integrator at 0x77fd1f323c40>) -> blackjax.base.SamplingAlgorithm>: {'max_num_doublings': 10,\n", + " .euclidean_integrator at 0x7f164c15c680>) -> blackjax.base.SamplingAlgorithm>: {'max_num_doublings': 10,\n", " 'divergence_threshold': 1000,\n", " 'integrator': .euclidean_integrator(logdensity_fn: Callable, kinetic_energy_fn: blackjax.mcmc.metrics.KineticEnergy) -> Callable[[blackjax.mcmc.integrators.IntegratorState, float], blackjax.mcmc.integrators.IntegratorState]>,\n", " 'logdensity_fn': .wrap_log_density..wrapped(args)>,\n", @@ -1953,8 +1980,8 @@ "
      \n", " \n", "
    • \n", - " \n", - " \n", + " \n", + " \n", "
      \n", "
      \n", "
        \n", @@ -1989,6 +2016,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -2039,7 +2067,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -2047,7 +2075,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -2059,6 +2088,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -2324,147 +2357,147 @@ "
        <xarray.Dataset> Size: 26kB\n",
                "Dimensions:    (chain: 4, draw: 250)\n",
                "Coordinates:\n",
        -       "  * chain      (chain) int64 32B 0 1 2 3\n",
                "  * draw       (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 243 244 245 246 247 248 249\n",
        +       "  * chain      (chain) int64 32B 0 1 2 3\n",
                "Data variables:\n",
        -       "    Intercept  (chain, draw) float64 8kB 0.1186 0.1811 0.1516 ... 0.104 -0.01889\n",
        -       "    sigma      (chain, draw) float64 8kB 0.9543 0.976 0.9225 ... 0.8462 0.9206\n",
        -       "    x          (chain, draw) float64 8kB 0.1962 0.2625 0.2581 ... 0.3441 0.3412\n",
        +       "    Intercept  (chain, draw) float64 8kB -0.1701 0.1002 ... 0.09008 -0.07872\n",
        +       "    sigma      (chain, draw) float64 8kB 1.024 0.9962 0.9826 ... 0.9153 1.042\n",
        +       "    x          (chain, draw) float64 8kB 0.468 0.5335 0.4088 ... 0.5823 0.2556\n",
                "Attributes:\n",
        -       "    created_at:                  2024-06-02T15:41:41.635714+00:00\n",
        -       "    arviz_version:               0.18.0\n",
        +       "    created_at:                  2024-12-21T16:43:38.392870+00:00\n",
        +       "    arviz_version:               0.19.0\n",
                "    modeling_interface:          bambi\n",
        -       "    modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602
      • chain
        PandasIndex
        PandasIndex(Index([0, 1, 2, 3], dtype='int64', name='chain'))
    • created_at :
      2024-12-21T16:43:38.392870+00:00
      arviz_version :
      0.19.0
      modeling_interface :
      bambi
      modeling_interface_version :
      0.14.1.dev17+g25798ce7

    \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -2499,6 +2532,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -2549,7 +2583,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -2557,7 +2591,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -2569,6 +2604,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -2837,58 +2876,58 @@ " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 2kB 0 1 2 3 4 5 6 ... 244 245 246 247 248 249\n", "Data variables:\n", - " acceptance_rate (chain, draw) float64 8kB 0.9572 0.9862 1.0 ... 0.921 1.0\n", + " acceptance_rate (chain, draw) float64 8kB 0.972 0.993 0.9295 ... 0.94 1.0\n", " diverging (chain, draw) bool 1kB False False False ... False False\n", - " energy (chain, draw) float64 8kB 144.3 142.1 141.9 ... 140.8 140.6\n", - " lp (chain, draw) float64 8kB -141.3 -141.4 ... -140.7 -139.4\n", - " n_steps (chain, draw) int64 8kB 3 7 3 3 7 3 3 3 ... 3 3 7 3 7 7 1 7\n", - " step_size (chain, draw) float64 8kB 0.8903 0.8903 ... 0.7551 0.7551\n", - " tree_depth (chain, draw) int64 8kB 2 3 2 2 3 2 2 2 ... 2 2 3 2 3 3 1 3\n", + " energy (chain, draw) float64 8kB 145.9 145.9 145.6 ... 146.0 145.4\n", + " lp (chain, draw) float64 8kB -145.5 -144.5 ... -145.3 -145.2\n", + " n_steps (chain, draw) int64 8kB 7 3 3 3 3 3 7 7 ... 7 3 3 3 3 3 3 7\n", + " step_size (chain, draw) float64 8kB 0.8512 0.8512 ... 0.8232 0.8232\n", + " tree_depth (chain, draw) int64 8kB 3 2 2 2 2 2 3 3 ... 3 2 2 2 2 2 2 3\n", "Attributes:\n", - " created_at: 2024-06-02T15:41:41.639906+00:00\n", - " arviz_version: 0.18.0\n", + " created_at: 2024-12-21T16:43:38.394782+00:00\n", + " arviz_version: 0.19.0\n", " modeling_interface: bambi\n", - " modeling_interface_version: 0.13.1.dev44+g55aac858.d20240602
  • created_at :
    2024-12-21T16:43:38.394782+00:00
    arviz_version :
    0.19.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.14.1.dev17+g25798ce7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -3173,6 +3212,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -3223,7 +3263,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -3231,7 +3271,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -3243,6 +3284,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -3510,42 +3555,42 @@ "Coordinates:\n", " * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99\n", "Data variables:\n", - " y (__obs__) float64 800B -0.4893 -0.021 -0.04577 ... -1.259 -0.3452\n", + " y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223\n", "Attributes:\n", - " created_at: 2024-06-02T15:41:41.635714+00:00\n", - " arviz_version: 0.18.0\n", + " created_at: 2024-12-21T16:43:38.392870+00:00\n", + " arviz_version: 0.19.0\n", " modeling_interface: bambi\n", - " modeling_interface_version: 0.13.1.dev44+g55aac858.d20240602
  • created_at :
    2024-12-21T16:43:38.392870+00:00
    arviz_version :
    0.19.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.14.1.dev17+g25798ce7

  • \n", " \n", " \n", " \n", @@ -3911,7 +3956,7 @@ " \"adapt.run\": {\"num_steps\": 500},\n", " \"num_chains\": 4,\n", " \"num_draws\": 250,\n", - " \"num_adapt_draws\": 250\n", + " \"num_adapt_draws\": 250,\n", "}\n", "\n", "blackjax_nuts_idata = model.fit(inference_method=\"blackjax_nuts\", **kwargs)\n", @@ -3941,8 +3986,8 @@ "
      \n", " \n", "
    • \n", - " \n", - " \n", + " \n", + " \n", "
      \n", "
      \n", "
        \n", @@ -3977,6 +4022,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -4027,7 +4073,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -4035,7 +4081,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -4047,6 +4094,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -4312,63 +4363,63 @@ "
        <xarray.Dataset> Size: 200kB\n",
                "Dimensions:    (chain: 8, draw: 1000)\n",
                "Coordinates:\n",
        -       "  * chain      (chain) int64 64B 0 1 2 3 4 5 6 7\n",
                "  * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n",
        +       "  * chain      (chain) int64 64B 0 1 2 3 4 5 6 7\n",
                "Data variables:\n",
        -       "    Intercept  (chain, draw) float64 64kB 0.2415 -0.0268 ... -0.07376 -0.05367\n",
        -       "    sigma      (chain, draw) float64 64kB 0.9948 0.9385 0.9726 ... 0.8749 1.129\n",
        -       "    x          (chain, draw) float64 64kB 0.3051 0.3062 0.1433 ... 0.2551 0.5439\n",
        +       "    Intercept  (chain, draw) float64 64kB -0.06265 -0.06601 ... 0.08766 0.08766\n",
        +       "    sigma      (chain, draw) float64 64kB 0.9457 0.9487 0.9521 ... 0.9434 0.9434\n",
        +       "    x          (chain, draw) float64 64kB 0.3832 0.3474 0.276 ... 0.395 0.395\n",
                "Attributes:\n",
        -       "    created_at:                  2024-06-02T15:41:52.350361+00:00\n",
        -       "    arviz_version:               0.18.0\n",
        +       "    created_at:                  2024-12-21T16:43:45.717159+00:00\n",
        +       "    arviz_version:               0.19.0\n",
                "    modeling_interface:          bambi\n",
        -       "    modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602
      • chain
        PandasIndex
        PandasIndex(Index([0, 1, 2, 3, 4, 5, 6, 7], dtype='int64', name='chain'))
    • created_at :
      2024-12-21T16:43:45.717159+00:00
      arviz_version :
      0.19.0
      modeling_interface :
      bambi
      modeling_interface_version :
      0.14.1.dev17+g25798ce7

    \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -4403,6 +4454,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -4453,7 +4505,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -4461,7 +4513,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -4473,6 +4526,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -4741,49 +4798,49 @@ " * chain (chain) int64 64B 0 1 2 3 4 5 6 7\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", "Data variables:\n", - " accept_ratio (chain, draw) float64 64kB 0.9302 1.0 ... 0.9067 0.7528\n", + " accept_ratio (chain, draw) float64 64kB 0.9721 0.9725 ... 0.9694 0.8617\n", " diverging (chain, draw) bool 8kB False False False ... False False\n", - " is_accepted (chain, draw) bool 8kB True True True ... True True True\n", - " n_steps (chain, draw) int32 32kB 7 3 3 3 3 7 7 3 ... 7 7 7 3 3 7 7\n", - " step_size (chain, draw) float64 64kB 0.545 0.545 0.545 ... nan nan\n", - " target_log_prob (chain, draw) float64 64kB -142.4 -139.5 ... -140.7 -144.1\n", + " is_accepted (chain, draw) bool 8kB True True True ... True True False\n", + " n_steps (chain, draw) int32 32kB 7 3 7 3 7 7 7 7 ... 7 3 3 3 3 3 7\n", + " step_size (chain, draw) float64 64kB 0.563 0.563 0.563 ... nan nan\n", + " target_log_prob (chain, draw) float64 64kB -144.0 -144.2 ... -144.2 -144.2\n", " tune (chain, draw) float64 64kB 0.0 0.0 0.0 0.0 ... nan nan nan\n", "Attributes:\n", - " created_at: 2024-06-02T15:41:52.353603+00:00\n", - " arviz_version: 0.18.0\n", + " created_at: 2024-12-21T16:43:45.718997+00:00\n", + " arviz_version: 0.19.0\n", " modeling_interface: bambi\n", - " modeling_interface_version: 0.13.1.dev44+g55aac858.d20240602
  • created_at :
    2024-12-21T16:43:45.718997+00:00
    arviz_version :
    0.19.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.14.1.dev17+g25798ce7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -4857,6 +4914,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -4907,7 +4965,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -4915,7 +4973,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -4927,6 +4986,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -5194,42 +5257,42 @@ "Coordinates:\n", " * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99\n", "Data variables:\n", - " y (__obs__) float64 800B -0.4893 -0.021 -0.04577 ... -1.259 -0.3452\n", + " y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223\n", "Attributes:\n", - " created_at: 2024-06-02T15:41:52.350361+00:00\n", - " arviz_version: 0.18.0\n", + " created_at: 2024-12-21T16:43:45.717159+00:00\n", + " arviz_version: 0.19.0\n", " modeling_interface: bambi\n", - " modeling_interface_version: 0.13.1.dev44+g55aac858.d20240602
  • created_at :
    2024-12-21T16:43:45.717159+00:00
    arviz_version :
    0.19.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.14.1.dev17+g25798ce7

  • \n", " \n", " \n", " \n", @@ -5611,7 +5674,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "sample: 100%|██████████| 1500/1500 [00:06<00:00, 242.04it/s]\n" + "sample: 100%|██████████| 1500/1500 [00:03<00:00, 386.97it/s]\n" ] }, { @@ -5625,8 +5688,8 @@ "
      \n", " \n", "
    • \n", - " \n", - " \n", + " \n", + " \n", "
      \n", "
      \n", "
        \n", @@ -5661,6 +5724,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -5711,7 +5775,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -5719,7 +5783,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -5731,6 +5796,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -5996,65 +6065,65 @@ "
        <xarray.Dataset> Size: 200kB\n",
                "Dimensions:    (chain: 8, draw: 1000)\n",
                "Coordinates:\n",
        -       "  * chain      (chain) int64 64B 0 1 2 3 4 5 6 7\n",
                "  * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n",
        +       "  * chain      (chain) int64 64B 0 1 2 3 4 5 6 7\n",
                "Data variables:\n",
        -       "    Intercept  (chain, draw) float64 64kB -0.01687 0.06615 ... 0.1263 0.03044\n",
        -       "    sigma      (chain, draw) float64 64kB 0.965 0.8374 1.078 ... 1.002 0.8794\n",
        -       "    x          (chain, draw) float64 64kB 0.2405 0.4685 0.2349 ... 0.3402 0.3522\n",
        +       "    Intercept  (chain, draw) float64 64kB 0.04368 -0.1021 ... -0.00282 0.1476\n",
        +       "    sigma      (chain, draw) float64 64kB 0.9309 0.9906 0.9233 ... 0.9424 0.9128\n",
        +       "    x          (chain, draw) float64 64kB 0.6003 0.3584 0.5494 ... 0.3202 0.2671\n",
                "Attributes:\n",
        -       "    created_at:                  2024-06-02T15:42:01.224796+00:00\n",
        -       "    arviz_version:               0.18.0\n",
        +       "    created_at:                  2024-12-21T16:43:50.477087+00:00\n",
        +       "    arviz_version:               0.19.0\n",
                "    inference_library:           numpyro\n",
        -       "    inference_library_version:   0.15.0\n",
        +       "    inference_library_version:   0.15.3\n",
                "    modeling_interface:          bambi\n",
        -       "    modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602
      • chain
        PandasIndex
        PandasIndex(Index([0, 1, 2, 3, 4, 5, 6, 7], dtype='int64', name='chain'))
    • created_at :
      2024-12-21T16:43:50.477087+00:00
      arviz_version :
      0.19.0
      inference_library :
      numpyro
      inference_library_version :
      0.15.3
      modeling_interface :
      bambi
      modeling_interface_version :
      0.14.1.dev17+g25798ce7

    \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -6089,6 +6158,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -6139,7 +6209,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -6147,7 +6217,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -6159,6 +6230,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -6427,96 +6502,96 @@ " * chain (chain) int64 64B 0 1 2 3 4 5 6 7\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", "Data variables:\n", - " acceptance_rate (chain, draw) float64 64kB 0.9735 0.7476 ... 0.9667 0.8725\n", + " acceptance_rate (chain, draw) float64 64kB 0.9297 0.9775 ... 0.9538 0.7392\n", " diverging (chain, draw) bool 8kB False False False ... False False\n", - " energy (chain, draw) float64 64kB 140.5 145.1 ... 140.7 141.8\n", - " lp (chain, draw) float64 64kB 140.1 141.2 ... 140.4 139.6\n", - " n_steps (chain, draw) int64 64kB 7 7 7 3 3 1 3 3 ... 11 3 3 3 3 3 3\n", - " step_size (chain, draw) float64 64kB 0.7685 0.7685 ... 0.8865 0.8865\n", - " tree_depth (chain, draw) int64 64kB 3 3 3 2 2 1 2 2 ... 4 2 2 2 2 2 2\n", + " energy (chain, draw) float64 64kB 145.1 146.0 ... 147.0 147.1\n", + " lp (chain, draw) float64 64kB 145.0 144.4 ... 144.1 146.4\n", + " n_steps (chain, draw) int64 64kB 7 7 7 7 3 7 7 7 ... 3 3 3 7 7 3 3\n", + " step_size (chain, draw) float64 64kB 0.7792 0.7792 ... 0.703 0.703\n", + " tree_depth (chain, draw) int64 64kB 3 3 3 3 2 3 3 3 ... 2 2 2 3 3 2 2\n", "Attributes:\n", - " created_at: 2024-06-02T15:42:01.260288+00:00\n", - " arviz_version: 0.18.0\n", + " created_at: 2024-12-21T16:43:50.504626+00:00\n", + " arviz_version: 0.19.0\n", " inference_library: numpyro\n", - " inference_library_version: 0.15.0\n", + " inference_library_version: 0.15.3\n", " modeling_interface: bambi\n", - " modeling_interface_version: 0.13.1.dev44+g55aac858.d20240602
  • created_at :
    2024-12-21T16:43:50.504626+00:00
    arviz_version :
    0.19.0
    inference_library :
    numpyro
    inference_library_version :
    0.15.3
    modeling_interface :
    bambi
    modeling_interface_version :
    0.14.1.dev17+g25798ce7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -6551,6 +6626,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -6601,7 +6677,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -6609,7 +6685,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -6621,6 +6698,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -6888,44 +6969,44 @@ "Coordinates:\n", " * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99\n", "Data variables:\n", - " y (__obs__) float64 800B -0.4893 -0.021 -0.04577 ... -1.259 -0.3452\n", + " y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223\n", "Attributes:\n", - " created_at: 2024-06-02T15:42:01.224796+00:00\n", - " arviz_version: 0.18.0\n", + " created_at: 2024-12-21T16:43:50.477087+00:00\n", + " arviz_version: 0.19.0\n", " inference_library: numpyro\n", - " inference_library_version: 0.15.0\n", + " inference_library_version: 0.15.3\n", " modeling_interface: bambi\n", - " modeling_interface_version: 0.13.1.dev44+g55aac858.d20240602
  • created_at :
    2024-12-21T16:43:50.477087+00:00
    arviz_version :
    0.19.0
    inference_library :
    numpyro
    inference_library_version :
    0.15.3
    modeling_interface :
    bambi
    modeling_interface_version :
    0.14.1.dev17+g25798ce7

  • \n", " \n", " \n", " \n", @@ -7314,8 +7395,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Global Tuning: 100%|██████████| 5/5 [00:44<00:00, 8.89s/it]\n", - "Global Sampling: 100%|██████████| 5/5 [00:00<00:00, 25.89it/s]\n" + "Global Tuning: 100%|██████████| 5/5 [00:20<00:00, 4.05s/it]\n", + "Global Sampling: 100%|██████████| 5/5 [00:00<00:00, 26.22it/s]\n" ] }, { @@ -7329,8 +7410,8 @@ "
      \n", " \n", "
    • \n", - " \n", - " \n", + " \n", + " \n", "
      \n", "
      \n", "
        \n", @@ -7365,6 +7446,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -7415,7 +7497,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -7423,7 +7505,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -7435,6 +7518,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -7700,64 +7787,64 @@ "
        <xarray.Dataset> Size: 244kB\n",
                "Dimensions:    (chain: 20, draw: 500)\n",
                "Coordinates:\n",
        -       "  * chain      (chain) int64 160B 0 1 2 3 4 5 6 7 8 ... 12 13 14 15 16 17 18 19\n",
                "  * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
        +       "  * chain      (chain) int64 160B 0 1 2 3 4 5 6 7 8 ... 12 13 14 15 16 17 18 19\n",
                "Data variables:\n",
        -       "    Intercept  (chain, draw) float64 80kB 0.07083 0.06709 ... 0.06182 -0.04028\n",
        -       "    sigma      (chain, draw) float64 80kB 0.9755 0.9504 0.9298 ... 0.8554 0.9118\n",
        -       "    x          (chain, draw) float64 80kB 0.382 0.3589 0.2673 ... 0.4581 0.3594\n",
        +       "    Intercept  (chain, draw) float64 80kB 0.2975 0.2975 ... 0.08134 0.03252\n",
        +       "    sigma      (chain, draw) float64 80kB 0.97 0.97 1.024 ... 0.9849 0.9851\n",
        +       "    x          (chain, draw) float64 80kB 0.5371 0.5371 0.5067 ... 0.4151 0.4007\n",
                "Attributes:\n",
        -       "    created_at:                  2024-06-02T15:42:49.303545+00:00\n",
        -       "    arviz_version:               0.18.0\n",
        +       "    created_at:                  2024-12-21T16:44:12.534363+00:00\n",
        +       "    arviz_version:               0.19.0\n",
                "    modeling_interface:          bambi\n",
        -       "    modeling_interface_version:  0.13.1.dev44+g55aac858.d20240602
      • chain
        PandasIndex
        PandasIndex(Index([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype='int64', name='chain'))
    • created_at :
      2024-12-21T16:44:12.534363+00:00
      arviz_version :
      0.19.0
      modeling_interface :
      bambi
      modeling_interface_version :
      0.14.1.dev17+g25798ce7

    \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -7792,6 +7879,7 @@ "}\n", "\n", "html[theme=dark],\n", + "html[data-theme=dark],\n", "body[data-theme=dark],\n", "body.vscode-dark {\n", " --xr-font-color0: rgba(255, 255, 255, 1);\n", @@ -7842,7 +7930,7 @@ ".xr-sections {\n", " padding-left: 0 !important;\n", " display: grid;\n", - " grid-template-columns: 150px auto auto 1fr 20px 20px;\n", + " grid-template-columns: 150px auto auto 1fr 0 20px 0 20px;\n", "}\n", "\n", ".xr-section-item {\n", @@ -7850,7 +7938,8 @@ "}\n", "\n", ".xr-section-item input {\n", - " display: none;\n", + " display: inline-block;\n", + " opacity: 0;\n", "}\n", "\n", ".xr-section-item input + label {\n", @@ -7862,6 +7951,10 @@ " color: var(--xr-font-color2);\n", "}\n", "\n", + ".xr-section-item input:focus + label {\n", + " border: 2px solid var(--xr-font-color0);\n", + "}\n", + "\n", ".xr-section-item input:enabled + label:hover {\n", " color: var(--xr-font-color0);\n", "}\n", @@ -8129,42 +8222,42 @@ "Coordinates:\n", " * __obs__ (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99\n", "Data variables:\n", - " y (__obs__) float64 800B -0.4893 -0.021 -0.04577 ... -1.259 -0.3452\n", + " y (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223\n", "Attributes:\n", - " created_at: 2024-06-02T15:42:49.303545+00:00\n", - " arviz_version: 0.18.0\n", + " created_at: 2024-12-21T16:44:12.534363+00:00\n", + " arviz_version: 0.19.0\n", " modeling_interface: bambi\n", - " modeling_interface_version: 0.13.1.dev44+g55aac858.d20240602
  • created_at :
    2024-12-21T16:44:12.534363+00:00
    arviz_version :
    0.19.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.14.1.dev17+g25798ce7

  • \n", " \n", " \n", " \n", @@ -8533,9 +8626,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Sampler comparisons\n", - "\n", - "With ArviZ, we can compare the inference result summaries of the samplers. _Note:_ We can't use `az.compare` as not each inference data object returns the pointwise log-probabilities. Thus, an error would be raised." + "### nutpie" ] }, { @@ -8545,87 +8636,30 @@ "outputs": [ { "data": { - "text/html": [ - "
    \n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
    meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
    Intercept0.0260.098-0.1520.2060.0030.003796.0648.01.01
    sigma0.9450.0700.8171.0740.0020.002970.0759.01.00
    x0.3550.1030.1570.5320.0030.0021067.0692.01.00
    \n", - "
    " - ], "text/plain": [ - " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "Intercept 0.026 0.098 -0.152 0.206 0.003 0.003 796.0 \n", - "sigma 0.945 0.070 0.817 1.074 0.002 0.002 970.0 \n", - "x 0.355 0.103 0.157 0.532 0.003 0.002 1067.0 \n", - "\n", - " ess_tail r_hat \n", - "Intercept 648.0 1.01 \n", - "sigma 759.0 1.00 \n", - "x 692.0 1.00 " + "{: {'ndim': 1,\n", + " 'make_logp_fn': .make_logp_fn()>,\n", + " 'make_expand_fn': .make_expand_fn(*args, **kwargs)>,\n", + " 'expanded_shapes': [(1,)],\n", + " 'expanded_names': ['x'],\n", + " 'expanded_dtypes': [numpy.float64]},\n", + " arviz.data.inference_data.InferenceData>: {'draws': 1000,\n", + " 'tune': 300,\n", + " 'chains': 8,\n", + " 'cores': 8,\n", + " 'seed': None,\n", + " 'save_warmup': True,\n", + " 'progress_bar': True,\n", + " 'low_rank_modified_mass_matrix': False,\n", + " 'init_mean': None,\n", + " 'return_raw_trace': False,\n", + " 'blocking': True,\n", + " 'progress_template': None,\n", + " 'progress_style': None,\n", + " 'progress_rate': 100},\n", + " 'extra_parameters': {'flatten': .flatten(pytree)>,\n", + " 'unflatten': ,\n", + " 'return_pytree': False}}" ] }, "execution_count": 13, @@ -8634,7 +8668,7 @@ } ], "source": [ - "az.summary(blackjax_nuts_idata)" + "bmb.inference_methods.get_kwargs(\"nutpie\")" ] }, { @@ -8645,119 +8679,2832 @@ { "data": { "text/html": [ - "
    \n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
    meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
    Intercept0.0240.096-0.1570.2040.0010.0017048.05524.01.0
    sigma0.9480.0680.8291.0830.0010.0017933.05659.01.0
    x0.3610.1030.1680.5500.0010.0016986.05702.01.0
    \n", - "
    " + " progress::-webkit-progress-bar {\n", + " background-color: #eee;\n", + " border-radius: 5px;\n", + " }\n", + " progress::-webkit-progress-value {\n", + " background-color: #5cb85c;\n", + " border-radius: 5px;\n", + " }\n", + " progress::-moz-progress-bar {\n", + " background-color: #5cb85c;\n", + " border-radius: 5px;\n", + " }\n", + " .nutpie .progress-cell {\n", + " width: 100%;\n", + " }\n", + "\n", + " .nutpie p strong { font-size: 16px; font-weight: bold; }\n", + "\n", + " @media (prefers-color-scheme: dark) {\n", + " .nutpie {\n", + " //color: #ddd;\n", + " //background-color: #1e1e1e;\n", + " box-shadow: 0 4px 6px rgba(0,0,0,0.2);\n", + " }\n", + " .nutpie table, .nutpie th, .nutpie td {\n", + " border-color: #555;\n", + " color: #ccc;\n", + " }\n", + " .nutpie th {\n", + " background-color: #2a2a2a;\n", + " }\n", + " .nutpie progress::-webkit-progress-bar {\n", + " background-color: #444;\n", + " }\n", + " .nutpie progress::-webkit-progress-value {\n", + " background-color: #3178c6;\n", + " }\n", + " .nutpie progress::-moz-progress-bar {\n", + " background-color: #3178c6;\n", + " }\n", + " }\n", + "\n" ], "text/plain": [ - " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "Intercept 0.024 0.096 -0.157 0.204 0.001 0.001 7048.0 \n", - "sigma 0.948 0.068 0.829 1.083 0.001 0.001 7933.0 \n", - "x 0.361 0.103 0.168 0.550 0.001 0.001 6986.0 \n", - "\n", - " ess_tail r_hat \n", - "Intercept 5524.0 1.0 \n", - "sigma 5659.0 1.0 \n", - "x 5702.0 1.0 " + "" ] }, - "execution_count": 14, "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "az.summary(tfp_nuts_idata)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ + "output_type": "display_data" + }, { "data": { "text/html": [ - "
    \n", - "\n", + "
    \n", + "
    \n", + "
    arviz.InferenceData
    \n", + "
    \n", + "
      \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset> Size: 40kB\n",
        +       "Dimensions:    (chain: 3, draw: 500)\n",
        +       "Coordinates:\n",
        +       "  * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
        +       "  * chain      (chain) int64 24B 0 1 2\n",
        +       "Data variables:\n",
        +       "    Intercept  (chain, draw) float64 12kB 0.08496 -0.02695 ... 0.005357 0.1237\n",
        +       "    sigma      (chain, draw) float64 12kB 1.116 0.89 0.8934 ... 0.9256 0.926\n",
        +       "    x          (chain, draw) float64 12kB 0.3081 0.4959 0.3477 ... 0.4546 0.638\n",
        +       "Attributes:\n",
        +       "    created_at:                  2024-12-21T16:44:15.471804+00:00\n",
        +       "    arviz_version:               0.19.0\n",
        +       "    modeling_interface:          bambi\n",
        +       "    modeling_interface_version:  0.14.1.dev17+g25798ce7

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset> Size: 127kB\n",
        +       "Dimensions:               (chain: 3, draw: 500)\n",
        +       "Coordinates:\n",
        +       "  * chain                 (chain) int64 24B 0 1 2\n",
        +       "  * draw                  (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499\n",
        +       "Data variables:\n",
        +       "    depth                 (chain, draw) uint64 12kB 2 2 2 2 2 2 ... 2 2 2 2 2 2\n",
        +       "    diverging             (chain, draw) bool 2kB False False ... False False\n",
        +       "    energy                (chain, draw) float64 12kB 146.6 147.2 ... 144.6 146.6\n",
        +       "    energy_error          (chain, draw) float64 12kB 0.5871 -0.6172 ... 0.704\n",
        +       "    index_in_trajectory   (chain, draw) int64 12kB 2 3 1 -2 -1 ... -2 -1 3 1 -1\n",
        +       "    logp                  (chain, draw) float64 12kB -146.1 -144.8 ... -146.2\n",
        +       "    maxdepth_reached      (chain, draw) bool 2kB False False ... False False\n",
        +       "    mean_tree_accept      (chain, draw) float64 12kB 0.9476 0.5462 ... 1.0 1.0\n",
        +       "    mean_tree_accept_sym  (chain, draw) float64 12kB 0.8644 0.7061 ... 0.8824\n",
        +       "    n_steps               (chain, draw) uint64 12kB 3 3 3 3 3 3 ... 3 3 3 3 3 3\n",
        +       "    step_size             (chain, draw) float64 12kB 1.039 1.039 ... 0.9917\n",
        +       "    step_size_bar         (chain, draw) float64 12kB 1.039 1.039 ... 0.9917\n",
        +       "Attributes:\n",
        +       "    created_at:                  2024-12-21T16:44:15.348609+00:00\n",
        +       "    arviz_version:               0.19.0\n",
        +       "    modeling_interface:          bambi\n",
        +       "    modeling_interface_version:  0.14.1.dev17+g25798ce7

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset> Size: 2kB\n",
        +       "Dimensions:  (__obs__: 100)\n",
        +       "Coordinates:\n",
        +       "  * __obs__  (__obs__) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99\n",
        +       "Data variables:\n",
        +       "    y        (__obs__) float64 800B 0.9823 -0.1276 1.024 ... -0.4394 0.2223\n",
        +       "Attributes:\n",
        +       "    created_at:                  2024-12-21T16:44:15.471804+00:00\n",
        +       "    arviz_version:               0.19.0\n",
        +       "    modeling_interface:          bambi\n",
        +       "    modeling_interface_version:  0.14.1.dev17+g25798ce7

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset> Size: 32kB\n",
        +       "Dimensions:    (chain: 3, draw: 400)\n",
        +       "Coordinates:\n",
        +       "  * chain      (chain) int64 24B 0 1 2\n",
        +       "  * draw       (draw) int64 3kB 0 1 2 3 4 5 6 7 ... 393 394 395 396 397 398 399\n",
        +       "Data variables:\n",
        +       "    Intercept  (chain, draw) float64 10kB 0.4285 0.4285 ... 0.05143 0.1415\n",
        +       "    sigma      (chain, draw) float64 10kB 1.157 1.157 0.9778 ... 0.7789 0.8057\n",
        +       "    x          (chain, draw) float64 10kB -0.1518 -0.1518 ... 0.5574 0.378\n",
        +       "Attributes:\n",
        +       "    created_at:                  2024-12-21T16:44:15.473126+00:00\n",
        +       "    arviz_version:               0.19.0\n",
        +       "    modeling_interface:          bambi\n",
        +       "    modeling_interface_version:  0.14.1.dev17+g25798ce7

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    • \n", + " \n", + " \n", + "
      \n", + "
      \n", + "
        \n", + "
        \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
        <xarray.Dataset> Size: 102kB\n",
        +       "Dimensions:               (chain: 3, draw: 400)\n",
        +       "Coordinates:\n",
        +       "  * chain                 (chain) int64 24B 0 1 2\n",
        +       "  * draw                  (draw) int64 3kB 0 1 2 3 4 5 ... 395 396 397 398 399\n",
        +       "Data variables:\n",
        +       "    depth                 (chain, draw) uint64 10kB 2 0 2 1 1 3 ... 2 2 2 2 3 2\n",
        +       "    diverging             (chain, draw) bool 1kB False True ... False False\n",
        +       "    energy                (chain, draw) float64 10kB 191.2 163.4 ... 151.0 153.1\n",
        +       "    energy_error          (chain, draw) float64 10kB -0.388 0.0 ... -0.1098\n",
        +       "    index_in_trajectory   (chain, draw) int64 10kB -3 0 -1 0 0 3 ... -1 -2 2 4 1\n",
        +       "    logp                  (chain, draw) float64 10kB -161.4 -161.4 ... -149.8\n",
        +       "    maxdepth_reached      (chain, draw) bool 1kB False False ... False False\n",
        +       "    mean_tree_accept      (chain, draw) float64 10kB 0.0 0.9011 ... 0.8973\n",
        +       "    mean_tree_accept_sym  (chain, draw) float64 10kB 0.0 0.8825 ... 0.7341\n",
        +       "    n_steps               (chain, draw) uint64 10kB 0 3 1 3 3 2 ... 3 3 3 3 3 7\n",
        +       "    step_size             (chain, draw) float64 10kB 0.4 4.807 ... 0.8206 0.7726\n",
        +       "    step_size_bar         (chain, draw) float64 10kB 0.4 4.807 ... 0.9982 0.9953\n",
        +       "Attributes:\n",
        +       "    created_at:                  2024-12-21T16:44:15.351287+00:00\n",
        +       "    arviz_version:               0.19.0\n",
        +       "    modeling_interface:          bambi\n",
        +       "    modeling_interface_version:  0.14.1.dev17+g25798ce7

        \n", + "
      \n", + "
      \n", + "
    • \n", + " \n", + "
    \n", + "
    \n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats\n", + "\t> observed_data\n", + "\n", + "Warmup iterations saved (warmup_*)." + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nutpie_idata = model.fit(inference_method=\"nutpie\", tune=400, draws=500, chains=3)\n", + "nutpie_idata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sampler comparisons\n", + "\n", + "With ArviZ, we can compare the inference result summaries of the samplers. _Note:_ We can't use `az.compare` as not each inference data object returns the pointwise log-probabilities. Thus, an error would be raised." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
    Intercept-0.0000.097-0.1800.1830.0030.003938.0752.01.0
    sigma0.9870.0730.8591.1260.0020.002913.0739.01.0
    x0.4230.1250.1510.6290.0040.0031044.0820.01.0
    \n", + "
    " + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "Intercept -0.000 0.097 -0.180 0.183 0.003 0.003 938.0 \n", + "sigma 0.987 0.073 0.859 1.126 0.002 0.002 913.0 \n", + "x 0.423 0.125 0.151 0.629 0.004 0.003 1044.0 \n", + "\n", + " ess_tail r_hat \n", + "Intercept 752.0 1.0 \n", + "sigma 739.0 1.0 \n", + "x 820.0 1.0 " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(blackjax_nuts_idata)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", "\n", " \n", " \n", @@ -8776,38 +11523,137 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    Intercept0.0250.095-0.1620.1960.0020.099-0.1830.1900.0010.0017396.05859.06775.05598.01.0
    sigma0.9460.0680.8191.0750.9870.0710.8481.1140.0010.0017131.05580.08338.05715.01.0
    x0.3610.1060.1710.5690.4240.1270.1860.6610.0020.0016244.05267.01.0
    \n", + "
    " + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "Intercept 0.002 0.099 -0.183 0.190 0.001 0.001 6775.0 \n", + "sigma 0.987 0.071 0.848 1.114 0.001 0.001 8338.0 \n", + "x 0.424 0.127 0.186 0.661 0.002 0.001 6244.0 \n", + "\n", + " ess_tail r_hat \n", + "Intercept 5598.0 1.0 \n", + "sigma 5715.0 1.0 \n", + "x 5267.0 1.0 " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(tfp_nuts_idata)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", @@ -8816,17 +11662,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "Intercept 0.025 0.095 -0.162 0.196 0.001 0.001 7396.0 \n", - "sigma 0.946 0.068 0.819 1.075 0.001 0.001 7131.0 \n", - "x 0.361 0.106 0.171 0.569 0.001 0.001 7673.0 \n", + "Intercept 0.005 0.098 -0.180 0.188 0.001 0.001 9065.0 \n", + "sigma 0.988 0.074 0.856 1.127 0.001 0.001 7217.0 \n", + "x 0.423 0.130 0.179 0.661 0.002 0.001 7449.0 \n", "\n", " ess_tail r_hat \n", - "Intercept 5859.0 1.0 \n", - "sigma 5580.0 1.0 \n", - "x 5905.0 1.0 " + "Intercept 6523.0 1.0 \n", + "sigma 5477.0 1.0 \n", + "x 6203.0 1.0 " ] }, - "execution_count": 15, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -8837,7 +11683,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -8875,39 +11721,39 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", "
    meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
    Intercept0.0050.098-0.1800.1880.0010.0019065.06523.01.0
    sigma0.9880.0740.8561.1270.0010.0017217.05477.01.0
    x0.4230.1300.1790.6610.0020.0017673.05905.07449.06203.01.0
    Intercept0.0240.096-0.1490.2070.0030.0040.101-0.1840.1930.002876.0615.01.020.0012352.03365.01.01
    sigma0.9470.0670.8221.0660.9870.0700.8611.1230.0010.0015554.05920.01.004252.04034.01.01
    x0.3610.1040.1610.5500.4250.1290.1710.6560.0010.0015081.04653.01.007504.03764.01.01
    \n", @@ -8915,17 +11761,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "Intercept 0.024 0.096 -0.149 0.207 0.003 0.002 876.0 \n", - "sigma 0.947 0.067 0.822 1.066 0.001 0.001 5554.0 \n", - "x 0.361 0.104 0.161 0.550 0.001 0.001 5081.0 \n", + "Intercept 0.004 0.101 -0.184 0.193 0.002 0.001 2352.0 \n", + "sigma 0.987 0.070 0.861 1.123 0.001 0.001 4252.0 \n", + "x 0.425 0.129 0.171 0.656 0.001 0.001 7504.0 \n", "\n", " ess_tail r_hat \n", - "Intercept 615.0 1.02 \n", - "sigma 5920.0 1.00 \n", - "x 4653.0 1.00 " + "Intercept 3365.0 1.01 \n", + "sigma 4034.0 1.01 \n", + "x 3764.0 1.01 " ] }, - "execution_count": 16, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -8934,6 +11780,105 @@ "az.summary(flowmc_idata)" ] }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
    \n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
    meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
    Intercept0.0020.098-0.1790.1810.0020.0032288.01040.01.0
    sigma0.9890.0720.8571.1180.0020.0012199.01155.01.0
    x0.4230.1280.1760.6570.0030.0021956.01287.01.0
    \n", + "
    " + ], + "text/plain": [ + " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", + "Intercept 0.002 0.098 -0.179 0.181 0.002 0.003 2288.0 \n", + "sigma 0.989 0.072 0.857 1.118 0.002 0.001 2199.0 \n", + "x 0.423 0.128 0.176 0.657 0.003 0.002 1956.0 \n", + "\n", + " ess_tail r_hat \n", + "Intercept 1040.0 1.0 \n", + "sigma 1155.0 1.0 \n", + "x 1287.0 1.0 " + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "az.summary(nutpie_idata)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -8945,26 +11890,25 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Last updated: Sun Jun 02 2024\n", + "Last updated: Sat Dec 21 2024\n", "\n", "Python implementation: CPython\n", "Python version : 3.11.9\n", - "IPython version : 8.24.0\n", + "IPython version : 8.27.0\n", "\n", - "arviz : 0.18.0\n", - "pandas: 2.2.2\n", - "bayeux: 0.1.12\n", - "bambi : 0.13.1.dev44+g55aac858.d20240602\n", + "bambi : 0.14.1.dev17+g25798ce7\n", + "arviz : 0.19.0\n", + "pandas: 2.2.3\n", "numpy : 1.26.4\n", "\n", - "Watermark: 2.4.3\n", + "Watermark: 2.5.0\n", "\n" ] } @@ -8977,7 +11921,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "bambi-dev", "language": "python", "name": "python3" }, diff --git a/pyproject.toml b/pyproject.toml index 7d03b483..1c7afdb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,12 +38,8 @@ dev = [ "seaborn>=0.9.0", ] -# TODO: Unpin this before making a release jax = [ - "bayeux-ml==0.1.14", - "blackjax==1.2.3", - "jax<=0.4.33", - "jaxlib<=0.4.33", + "bayeux-ml==0.1.15", ] [project.urls]