diff --git a/pymc_marketing/clv/models/pareto_nbd.py b/pymc_marketing/clv/models/pareto_nbd.py index 8e2b2a7c..5fec84e3 100644 --- a/pymc_marketing/clv/models/pareto_nbd.py +++ b/pymc_marketing/clv/models/pareto_nbd.py @@ -384,10 +384,13 @@ def _extract_predictive_variables( must_be_unique=["customer_id"], ) + customer_id = data["customer_id"] + model_coords = self.model.coords # type: ignore if self.purchase_covariate_cols: purchase_xarray = xarray.DataArray( data[self.purchase_covariate_cols], dims=["customer_id", "purchase_covariate"], + coords=[customer_id, list(model_coords["purchase_covariate"])], ) alpha_scale = self.fit_result["alpha_scale"] purchase_coefficient = self.fit_result["purchase_coefficient"] @@ -404,6 +407,7 @@ def _extract_predictive_variables( dropout_xarray = xarray.DataArray( data[self.dropout_covariate_cols], dims=["customer_id", "dropout_covariate"], + coords=[customer_id, list(model_coords["dropout_covariate"])], ) beta_scale = self.fit_result["beta_scale"] dropout_coefficient = self.fit_result["dropout_coefficient"] diff --git a/tests/clv/models/test_pareto_nbd.py b/tests/clv/models/test_pareto_nbd.py index 9ad845a8..12d29234 100644 --- a/tests/clv/models/test_pareto_nbd.py +++ b/tests/clv/models/test_pareto_nbd.py @@ -485,12 +485,20 @@ def test_extract_predictive_covariates(self): new_data = self.data.assign( purchase_cov1=1.0, dropout_cov=1.0, + customer_id=self.data["customer_id"] + 1, ) different_vars = model._extract_predictive_variables(data=new_data) - different_alpha = different_vars["alpha"] - different_beta = different_vars["beta"] + different_alpha = different_vars["alpha"] + assert np.all( + different_alpha.customer_id.values == alpha_model.customer_id.values + 1 + ) assert not np.allclose(alpha_model, different_alpha) + + different_beta = different_vars["beta"] + assert np.all( + different_beta.customer_id.values == beta_model.customer_id.values + 1 + ) assert not np.allclose(beta_model, different_beta) def test_logp(self):