Skip to content

Commit

Permalink
Flag axis as broadcastable
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Apr 9, 2024
1 parent f704c5e commit c927f64
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 8 deletions.
5 changes: 0 additions & 5 deletions bambi/backend/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,6 @@ def build_group_specific_terms(self, pymc_backend, bmb_model):
elif isinstance(bmb_model.family, (MultivariateFamily, Categorical)):
self.output += coef * predictor[:, np.newaxis]
else:
# FIXME: here we see why it fails
print("coef shape", coef.shape.eval())
print("coef squeezed shape", coef.squeeze().shape.eval())
print("predictor shape", predictor.shape)
print((coef * predictor).shape.eval())
self.output += coef * predictor

def build_response(self, pymc_backend, bmb_model):
Expand Down
6 changes: 4 additions & 2 deletions bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,13 @@ def build(self, spec):
response_dims = list(spec.response_component.response_term.coords)

dims = list(self.coords) + response_dims
coef = self.build_distribution(self.term.prior, label, dims=dims, **kwargs)

# Squeeze ensures we don't have a shape of (n, 1) when we mean (n, )
# This happens with categorical predictors with two levels and intercept.
# FIXME: This is not working anymore!
# See https://github.com/pymc-devs/pymc/issues/7246
coef = self.build_distribution(self.term.prior, label, dims=dims, **kwargs).squeeze()
if len(coef.shape.eval()) == 2 and coef.shape.eval()[-1] == 1:
coef = pt.specify_broadcastable(coef, 1).squeeze()
coef = coef[self.term.group_index]

return coef, predictor
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"formulae>=0.5.3",
"graphviz",
"pandas>=1.0.0",
"pymc>=5.12.0,<5.13.0",
"pymc>=5.13.0",
]

[project.optional-dependencies]
Expand Down

0 comments on commit c927f64

Please sign in to comment.