From c927f64983e1c7b29e512d0caee0e42ad0f1daf6 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Tue, 9 Apr 2024 16:40:50 -0300 Subject: [PATCH] Flag axis as broadcastable --- bambi/backend/model_components.py | 5 ----- bambi/backend/terms.py | 6 ++++-- pyproject.toml | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/bambi/backend/model_components.py b/bambi/backend/model_components.py index cace08267..686959fa5 100644 --- a/bambi/backend/model_components.py +++ b/bambi/backend/model_components.py @@ -154,11 +154,6 @@ def build_group_specific_terms(self, pymc_backend, bmb_model): elif isinstance(bmb_model.family, (MultivariateFamily, Categorical)): self.output += coef * predictor[:, np.newaxis] else: - # FIXME: here we see why it fails - print("coef shape", coef.shape.eval()) - print("coef squeezed shape", coef.squeeze().shape.eval()) - print("predictor shape", predictor.shape) - print((coef * predictor).shape.eval()) self.output += coef * predictor def build_response(self, pymc_backend, bmb_model): diff --git a/bambi/backend/terms.py b/bambi/backend/terms.py index 943857b91..ed2a2e02c 100644 --- a/bambi/backend/terms.py +++ b/bambi/backend/terms.py @@ -110,11 +110,13 @@ def build(self, spec): response_dims = list(spec.response_component.response_term.coords) dims = list(self.coords) + response_dims + coef = self.build_distribution(self.term.prior, label, dims=dims, **kwargs) + # Squeeze ensures we don't have a shape of (n, 1) when we mean (n, ) # This happens with categorical predictors with two levels and intercept. - # FIXME: This is not working anymore! # See https://github.com/pymc-devs/pymc/issues/7246 - coef = self.build_distribution(self.term.prior, label, dims=dims, **kwargs).squeeze() + if len(coef.shape.eval()) == 2 and coef.shape.eval()[-1] == 1: + coef = pt.specify_broadcastable(coef, 1).squeeze() coef = coef[self.term.group_index] return coef, predictor diff --git a/pyproject.toml b/pyproject.toml index f1f00dc88..f8b5ac674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "formulae>=0.5.3", "graphviz", "pandas>=1.0.0", - "pymc>=5.12.0,<5.13.0", + "pymc>=5.13.0", ] [project.optional-dependencies]