Skip to content

Commit

Permalink
add some validation and guardrails
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Feb 7, 2025
1 parent fc24994 commit 9365e8c
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 3 deletions.
5 changes: 5 additions & 0 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ def variable_mapping(self) -> dict[str, str]:
for parameter in self.default_priors.keys()
}

@property
def combined_dims(self) -> tuple[str, ...]:
"""Get the combined dims for all the parameters."""
return tuple(self._infer_output_core_dims())

def _infer_output_core_dims(self) -> tuple[str, ...]:
parameter_dims = sorted(
[
Expand Down
25 changes: 23 additions & 2 deletions pymc_marketing/mmm/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def create_basis_matrix(df_events: pd.DataFrame, model_dates: np.ndarray):
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import BaseModel, Field, InstanceOf, validate_call
from pydantic import BaseModel, Field, InstanceOf, model_validator, validate_call
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.deserialize import deserialize, register_deserialization
Expand Down Expand Up @@ -179,7 +179,28 @@ class EventEffect(BaseModel):

basis: InstanceOf[Basis]
effect_size: InstanceOf[Prior]
dims: tuple[str, ...]
dims: str | tuple[str, ...]

@model_validator(mode="before")
def _dims_to_tuple(self):
if isinstance(self["dims"], str):
self["dims"] = (self["dims"],)

return self

@model_validator(mode="after")
def _validate_dims(self):
print(self)
if not self.dims:
raise ValueError("The dims must not be empty.")

if not set(self.basis.combined_dims).issubset(set(self.dims)):
raise ValueError("The dims must contain all dimensions of the basis.")

if not set(self.effect_size.dims).issubset(set(self.dims)):
raise ValueError("The dims must contain all dimensions of the effect size.")

return self

def apply(self, X: pt.TensorLike, name: str = "event") -> TensorVariable:
"""Apply the event effect to the data."""
Expand Down
5 changes: 5 additions & 0 deletions pymc_marketing/mmm/multidimensional.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ def add_events(
This must be called before building the model.
"""
if not set((prefix, self.dims)).issubset(set(effect.dims)):
raise ValueError(
f"Event effect dims {effect.dims} must contain {prefix} and {self.dims}"
)

event_effect = create_event_mu_effect(df_events, prefix, effect)
self.mu_effects.append(event_effect)

Expand Down
19 changes: 19 additions & 0 deletions tests/mmm/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,22 @@ def test_days_from_reference(dates_constructor, reference_constructor):
)

np.testing.assert_allclose(result, np.arange(-4, 6))


@pytest.mark.parametrize(
"sigma_dims, effect_dims",
[
pytest.param("something else", "event", id="basis_not_subset"),
pytest.param("event", "something else", id="effect_not_subset"),
],
)
def test_event_effect_dim_validation(sigma_dims, effect_dims) -> None:
basis = GaussianBasis(
priors={
"sigma": Prior("HalfNormal", dims=sigma_dims),
}
)
effect_size = Prior("Normal", dims=effect_dims)

with pytest.raises(ValueError):
EventEffect(basis=basis, effect_size=effect_size, dims="event")
6 changes: 5 additions & 1 deletion tests/mmm/test_multidimensional.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,11 @@ class MMM:

@pytest.fixture
def create_event_effect() -> Callable[[str], EventEffect]:
def create(prefix: str = "holiday"):
def create(
prefix: str = "holiday",
sigma_dims: str | None = None,
effect_size: Prior | None = None,
):
basis = GaussianBasis()
return EventEffect(
basis=basis,
Expand Down

0 comments on commit 9365e8c

Please sign in to comment.