Skip to content

Commit

Permalink
test for (mock) fit of more MMMMs
Browse files (browse the repository at this point in the history)
  • Loading branch information
wd60622 committed Feb 8, 2025
1 parent 94764dc commit b198043
Showing 1 changed file with 32 additions and 136 deletions.
168 changes: 32 additions & 136 deletions tests/mmm/test_multidimensional.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -76,7 +76,7 @@ def fit_mmm(df, mmm, mock_pymc_sample):
return mmm


def test_fit(fit_mmm):
def test_simple_fit(fit_mmm):
assert isinstance(fit_mmm.posterior, xr.Dataset)
assert isinstance(fit_mmm.idata.constant_data, xr.Dataset)

Expand Down Expand Up @@ -164,29 +164,34 @@ def multi_dim_data():


@pytest.mark.parametrize(
"time_varying_intercept, time_varying_media, yearly_seasonality, dims",
"fixture_name, dims",
[
(False, False, None, ()), # no time-varying, no seasonality, no extra dims
(False, False, 4, ()), # no time-varying, has seasonality, no extra dims
(
True,
False,
None,
(),
), # time-varying intercept only, no seasonality, no extra dims
(False, True, 4, ()), # time-varying media only, has seasonality, no extra dims
(True, True, 4, ()), # both time-varying, has seasonality, no extra dims
pytest.param("single_dim_data", (), id="Marginal model"),
pytest.param("multi_dim_data", ("country",), id="County model"),
],
)
def test_build_model_single_dim(
single_dim_data,
@pytest.mark.parametrize(
"time_varying_intercept, time_varying_media, yearly_seasonality",
[
pytest.param(False, False, None, id="no tvps or fourier"),
pytest.param(False, False, 4, id="no tvps with fourier"),
pytest.param(True, False, None, id="tvp intercept only, no fourier"),
pytest.param(False, True, 4, id="tvp media only with fourier"),
pytest.param(True, True, 4, id="tvps and fourier"),
],
)
def test_fit(
request,
fixture_name,
time_varying_intercept,
time_varying_media,
yearly_seasonality,
dims,
mock_pymc_sample,
):
"""Test that building the model works with different configurations (single-dim)."""
X, y = single_dim_data
X, y = request.getfixturevalue(fixture_name)

adstock = GeometricAdstock(l_max=2)
saturation = LogisticSaturation()

Expand All @@ -202,11 +207,18 @@ def test_build_model_single_dim(
time_varying_media=time_varying_media,
)

mmm.build_model(X, y)
seed = sum(map(ord, "Fitting the MMMM"))
random_seed = np.random.default_rng(seed)

idata = mmm.fit(X, y, random_seed=random_seed)

# Assertions
assert hasattr(mmm, "model"), "Model attribute should be set after build_model."
assert isinstance(mmm.model, pm.Model), "mmm.model should be a PyMC Model instance."
for dim in dims:
assert dim in mmm.model.coords, (
f"Extra dimension '{dim}' should be in model coords."
)

# Basic checks to confirm presence of key variables
var_names = mmm.model.named_vars.keys()
Expand All @@ -219,89 +231,6 @@ def test_build_model_single_dim(
if yearly_seasonality is not None:
assert "fourier_contribution" in var_names


@pytest.mark.parametrize(
"time_varying_intercept, time_varying_media, yearly_seasonality, dims",
[
(
False,
False,
None,
("country",),
), # no time-varying, no seasonality, 1 extra dim
(
True,
False,
4,
("country",),
), # time-varying intercept only, has seasonality, 1 extra dim
(
False,
True,
4,
("country",),
), # time-varying media only, has seasonality, 1 extra dim
(
True,
True,
2,
("country",),
), # both time-varying, has seasonality, 1 extra dim
],
)
def test_build_model_multi_dim(
multi_dim_data, time_varying_intercept, time_varying_media, yearly_seasonality, dims
):
"""Test building the model when extra dimensions (like 'country') are present."""
X, y = multi_dim_data
adstock = GeometricAdstock(l_max=2)
saturation = LogisticSaturation()

mmm = MMM(
date_column="date",
target_column="target",
channel_columns=["channel_1", "channel_2"],
dims=dims,
adstock=adstock,
saturation=saturation,
yearly_seasonality=yearly_seasonality,
time_varying_intercept=time_varying_intercept,
time_varying_media=time_varying_media,
)

mmm.build_model(X, y)

assert hasattr(mmm, "model"), "Model attribute should be set after build_model."
assert isinstance(mmm.model, pm.Model), "mmm.model should be a PyMC Model instance."
assert "country" in mmm.model.coords, (
"Extra dimension 'country' should be in model coords."
)


def test_fit_single_dim(single_dim_data, mock_pymc_sample):
"""Test fitting the model on a single-dimension dataset."""
X, y = single_dim_data

adstock = GeometricAdstock(l_max=2)
saturation = LogisticSaturation()

mmm = MMM(
date_column="date",
target_column="target",
channel_columns=["channel_1", "channel_2"],
dims=(),
adstock=adstock,
saturation=saturation,
yearly_seasonality=None, # disable yearly seasonality
time_varying_intercept=False,
time_varying_media=False,
)

# Build and fit
mmm.build_model(X, y)

# To keep tests fast, set small number of draws/tune
idata = mmm.fit(X, y, draws=10, tune=10, chains=1)
assert isinstance(idata, az.InferenceData), (
"fit should return an InferenceData object."
)
Expand All @@ -314,45 +243,12 @@ def test_fit_single_dim(single_dim_data, mock_pymc_sample):
"InferenceData should have a posterior group."
)


def test_fit_multi_dim(multi_dim_data, mock_pymc_sample):
"""Test fitting the model on a multi-dimensional dataset (e.g. with 'country')."""
X, y = multi_dim_data

adstock = GeometricAdstock(l_max=2)
saturation = LogisticSaturation()

mmm = MMM(
date_column="date",
target_column="target",
channel_columns=["channel_1", "channel_2"],
dims=("country",),
adstock=adstock,
saturation=saturation,
yearly_seasonality=2,
time_varying_intercept=True,
time_varying_media=True,
)

# Build and fit
mmm.build_model(X, y)

# Again, keep the sampler small for test speed
idata = mmm.fit(X, y, draws=10, tune=10, chains=1)
assert isinstance(idata, az.InferenceData), (
"fit should return an InferenceData object."
)
assert hasattr(mmm, "idata"), (
"MMM instance should store the inference data as 'idata'."
)

# Check if 'country' is in the posterior dimensions
assert "country" in mmm.idata.posterior.dims, (
"Posterior should have 'country' dimension."
)
for dim in dims:
assert dim in mmm.idata.posterior.dims, (
f"Extra dimension '{dim}' should be in posterior dims."
)


# @pytest.mark.xfail(reason="Need to work through the new data.")
def test_sample_posterior_predictive_new_data(single_dim_data, mock_pymc_sample):
"""
Test that sampling from the posterior predictive with new/unseen data
Expand Down

0 comments on commit b198043

Please sign in to comment.