diff --git a/tests/mmm/test_multidimensional.py b/tests/mmm/test_multidimensional.py index 4fd6e589..a5500635 100644 --- a/tests/mmm/test_multidimensional.py +++ b/tests/mmm/test_multidimensional.py @@ -76,7 +76,7 @@ def fit_mmm(df, mmm, mock_pymc_sample): return mmm -def test_fit(fit_mmm): +def test_simple_fit(fit_mmm): assert isinstance(fit_mmm.posterior, xr.Dataset) assert isinstance(fit_mmm.idata.constant_data, xr.Dataset) @@ -164,29 +164,34 @@ def multi_dim_data(): @pytest.mark.parametrize( - "time_varying_intercept, time_varying_media, yearly_seasonality, dims", + "fixture_name, dims", [ - (False, False, None, ()), # no time-varying, no seasonality, no extra dims - (False, False, 4, ()), # no time-varying, has seasonality, no extra dims - ( - True, - False, - None, - (), - ), # time-varying intercept only, no seasonality, no extra dims - (False, True, 4, ()), # time-varying media only, has seasonality, no extra dims - (True, True, 4, ()), # both time-varying, has seasonality, no extra dims + pytest.param("single_dim_data", (), id="Marginal model"), + pytest.param("multi_dim_data", ("country",), id="County model"), ], ) -def test_build_model_single_dim( - single_dim_data, +@pytest.mark.parametrize( + "time_varying_intercept, time_varying_media, yearly_seasonality", + [ + pytest.param(False, False, None, id="no tvps or fourier"), + pytest.param(False, False, 4, id="no tvps with fourier"), + pytest.param(True, False, None, id="tvp intercept only, no fourier"), + pytest.param(False, True, 4, id="tvp media only with fourier"), + pytest.param(True, True, 4, id="tvps and fourier"), + ], +) +def test_fit( + request, + fixture_name, time_varying_intercept, time_varying_media, yearly_seasonality, dims, + mock_pymc_sample, ): """Test that building the model works with different configurations (single-dim).""" - X, y = single_dim_data + X, y = request.getfixturevalue(fixture_name) + adstock = GeometricAdstock(l_max=2) saturation = LogisticSaturation() @@ -202,11 +207,18 @@ def test_build_model_single_dim( time_varying_media=time_varying_media, ) - mmm.build_model(X, y) + seed = sum(map(ord, "Fitting the MMMM")) + random_seed = np.random.default_rng(seed) + + idata = mmm.fit(X, y, random_seed=random_seed) # Assertions assert hasattr(mmm, "model"), "Model attribute should be set after build_model." assert isinstance(mmm.model, pm.Model), "mmm.model should be a PyMC Model instance." + for dim in dims: + assert dim in mmm.model.coords, ( + f"Extra dimension '{dim}' should be in model coords." + ) # Basic checks to confirm presence of key variables var_names = mmm.model.named_vars.keys() @@ -219,89 +231,6 @@ def test_build_model_single_dim( if yearly_seasonality is not None: assert "fourier_contribution" in var_names - -@pytest.mark.parametrize( - "time_varying_intercept, time_varying_media, yearly_seasonality, dims", - [ - ( - False, - False, - None, - ("country",), - ), # no time-varying, no seasonality, 1 extra dim - ( - True, - False, - 4, - ("country",), - ), # time-varying intercept only, has seasonality, 1 extra dim - ( - False, - True, - 4, - ("country",), - ), # time-varying media only, has seasonality, 1 extra dim - ( - True, - True, - 2, - ("country",), - ), # both time-varying, has seasonality, 1 extra dim - ], -) -def test_build_model_multi_dim( - multi_dim_data, time_varying_intercept, time_varying_media, yearly_seasonality, dims -): - """Test building the model when extra dimensions (like 'country') are present.""" - X, y = multi_dim_data - adstock = GeometricAdstock(l_max=2) - saturation = LogisticSaturation() - - mmm = MMM( - date_column="date", - target_column="target", - channel_columns=["channel_1", "channel_2"], - dims=dims, - adstock=adstock, - saturation=saturation, - yearly_seasonality=yearly_seasonality, - time_varying_intercept=time_varying_intercept, - time_varying_media=time_varying_media, - ) - - mmm.build_model(X, y) - - assert hasattr(mmm, "model"), "Model attribute should be set after build_model." - assert isinstance(mmm.model, pm.Model), "mmm.model should be a PyMC Model instance." - assert "country" in mmm.model.coords, ( - "Extra dimension 'country' should be in model coords." - ) - - -def test_fit_single_dim(single_dim_data, mock_pymc_sample): - """Test fitting the model on a single-dimension dataset.""" - X, y = single_dim_data - - adstock = GeometricAdstock(l_max=2) - saturation = LogisticSaturation() - - mmm = MMM( - date_column="date", - target_column="target", - channel_columns=["channel_1", "channel_2"], - dims=(), - adstock=adstock, - saturation=saturation, - yearly_seasonality=None, # disable yearly seasonality - time_varying_intercept=False, - time_varying_media=False, - ) - - # Build and fit - mmm.build_model(X, y) - - # To keep tests fast, set small number of draws/tune - idata = mmm.fit(X, y, draws=10, tune=10, chains=1) assert isinstance(idata, az.InferenceData), ( "fit should return an InferenceData object." ) @@ -314,45 +243,12 @@ def test_fit_single_dim(single_dim_data, mock_pymc_sample): "InferenceData should have a posterior group." ) - -def test_fit_multi_dim(multi_dim_data, mock_pymc_sample): - """Test fitting the model on a multi-dimensional dataset (e.g. with 'country').""" - X, y = multi_dim_data - - adstock = GeometricAdstock(l_max=2) - saturation = LogisticSaturation() - - mmm = MMM( - date_column="date", - target_column="target", - channel_columns=["channel_1", "channel_2"], - dims=("country",), - adstock=adstock, - saturation=saturation, - yearly_seasonality=2, - time_varying_intercept=True, - time_varying_media=True, - ) - - # Build and fit - mmm.build_model(X, y) - - # Again, keep the sampler small for test speed - idata = mmm.fit(X, y, draws=10, tune=10, chains=1) - assert isinstance(idata, az.InferenceData), ( - "fit should return an InferenceData object." - ) - assert hasattr(mmm, "idata"), ( - "MMM instance should store the inference data as 'idata'." - ) - - # Check if 'country' is in the posterior dimensions - assert "country" in mmm.idata.posterior.dims, ( - "Posterior should have 'country' dimension." - ) + for dim in dims: + assert dim in mmm.idata.posterior.dims, ( + f"Extra dimension '{dim}' should be in posterior dims." + ) -# @pytest.mark.xfail(reason="Need to work through the new data.") def test_sample_posterior_predictive_new_data(single_dim_data, mock_pymc_sample): """ Test that sampling from the posterior predictive with new/unseen data