Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLV plot_expected_purchases_ppc #1222

Merged
merged 23 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,292 changes: 0 additions & 4,292 deletions docs/source/notebooks/clv/bg_nbd.ipynb

This file was deleted.

1,054 changes: 0 additions & 1,054 deletions docs/source/notebooks/clv/dev/beta_geo_dev.ipynb

This file was deleted.

757 changes: 0 additions & 757 deletions docs/source/notebooks/clv/dev/utilities.ipynb

This file was deleted.

6,267 changes: 6,267 additions & 0 deletions docs/source/notebooks/clv/dev/utilities_plotting.ipynb

Large diffs are not rendered by default.

643 changes: 334 additions & 309 deletions docs/source/notebooks/clv/pareto_nbd.ipynb

Large diffs are not rendered by default.

Binary file removed docs/source/notebooks/clv/pnbd.nc
Binary file not shown.
2 changes: 2 additions & 0 deletions pymc_marketing/clv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pymc_marketing.clv.plotting import (
plot_customer_exposure,
plot_expected_purchases,
plot_expected_purchases_ppc,
plot_frequency_recency_matrix,
plot_probability_alive_matrix,
)
Expand All @@ -46,6 +47,7 @@
"plot_frequency_recency_matrix",
"plot_expected_purchases",
"plot_probability_alive_matrix",
"plot_expected_purchases_ppc",
"rfm_segments",
"rfm_summary",
"rfm_train_test_split",
Expand Down
103 changes: 102 additions & 1 deletion pymc_marketing/clv/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
from matplotlib.lines import Line2D

from pymc_marketing.clv import BetaGeoModel, ParetoNBDModel
from pymc_marketing.clv.utils import _expected_cumulative_transactions

__all__ = [
"plot_customer_exposure",
"plot_expected_purchases",
"plot_frequency_recency_matrix",
"plot_probability_alive_matrix",
"plot_expected_purchases",
"plot_expected_purchases_ppc",
]


Expand Down Expand Up @@ -474,6 +476,105 @@
return ax


def plot_expected_purchases_ppc(
model,
ppc: str = "posterior",
max_purchases: int = 10,
samples: int = 1000,
random_seed: int = 45,
ax: plt.Axes | None = None,
**kwargs,
) -> plt.Axes:
"""Plot a prior or posterior predictive check for the customer purchase frequency distribution.

At this time only ParetoNBDModel and BetaGeoBetaBinomModel are supported.

Adapted from legacy ``lifetimes`` library:
https://github.com/CamDavidsonPilon/lifetimes/blob/master/lifetimes/plotting.py#L25

Parameters
----------
model : CLV model
Prior predictive checks can be performed before or after a model is fit.
Posterior predictive checks require a fitted model.
ppc : string, optional
Type of predictive check to perform. Options are 'prior' or 'posterior'; defaults to 'posterior'.
max_purchases : int, optional
Cutoff for bars of purchase counts to plot. Default is 10.
samples : int, optional
Number of samples to draw for prior predictive checks. This is not used for posterior predictive checks.
random_seed : int, optional
Random seed to fix sampling results
ax : matplotlib.AxesSubplot, optional
A matplotlib axes instance. Creates new axes instance by default.
**kwargs
Additional arguments to pass into the pandas.DataFrame.plot command.

Returns
-------
axes: matplotlib.AxesSubplot
"""
# TODO: BetaGeoModel requires its own dist class in distributions.py for this function.
if isinstance(model, BetaGeoModel):
raise AttributeError("BetaGeoModel is unsupported for this function.")

Check warning on line 519 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L518-L519

Added lines #L518 - L519 were not covered by tests

if ax is None:
ax = plt.subplot(111)

Check warning on line 522 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L521-L522

Added lines #L521 - L522 were not covered by tests

match ppc:
case "prior":

Check warning on line 525 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L524-L525

Added lines #L524 - L525 were not covered by tests
# build model if it has not been fit yet
model.build_model()

Check warning on line 527 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L527

Added line #L527 was not covered by tests

prior_idata = pm.sample_prior_predictive(

Check warning on line 529 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L529

Added line #L529 was not covered by tests
samples=samples,
model=model.model,
random_seed=random_seed,
)

# obs_var must be retrieved from prior_idata if model has not been fit
obs_freq = prior_idata.observed_data["recency_frequency"].sel(

Check warning on line 536 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L536

Added line #L536 was not covered by tests
obs_var="frequency"
)
ppc_freq = prior_idata.prior_predictive["recency_frequency"].sel(

Check warning on line 539 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L539

Added line #L539 was not covered by tests
obs_var="frequency"
)
title = "Prior Predictive Check for Customer Frequency"
case "posterior":
obs_freq = model.idata.observed_data["recency_frequency"].sel(

Check warning on line 544 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L542-L544

Added lines #L542 - L544 were not covered by tests
obs_var="frequency"
)
# Keep samples at 1 here because (chain * draw * customer) samples are already being drawn
ppc_freq = model.distribution_new_customer_recency_frequency(

Check warning on line 548 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L548

Added line #L548 was not covered by tests
random_seed=random_seed,
n_samples=1,
).sel(obs_var="frequency")
title = "Posterior Predictive Check for Customer Frequency"
case _:
raise NameError("Specify 'prior' or 'posterior' for 'ppc' parameter.")

Check warning on line 554 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L552-L554

Added lines #L552 - L554 were not covered by tests

# convert estimated and observed xarrays into dataframes for plotting
estimated = ppc_freq.to_dataframe().value_counts(normalize=True).sort_index()
observed = obs_freq.to_dataframe().value_counts(normalize=True).sort_index()

Check warning on line 558 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L557-L558

Added lines #L557 - L558 were not covered by tests

# PPC histogram plot
ax = pd.DataFrame(

Check warning on line 561 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L561

Added line #L561 was not covered by tests
{
"Estimated": estimated.reset_index()["proportion"].head(max_purchases),
"Observed": observed.reset_index()["proportion"].head(max_purchases),
},
).plot(
kind="bar",
ax=ax,
title=title,
xlabel="Repeat Purchases",
ylabel="% of Customer Population",
rot=0.0,
**kwargs,
)
return ax

Check warning on line 575 in pymc_marketing/clv/plotting.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/plotting.py#L575

Added line #L575 was not covered by tests


def _force_aspect(ax: plt.Axes, aspect=1):
im = ax.get_images()
extent = im[0].get_extent()
Expand Down
35 changes: 34 additions & 1 deletion tests/clv/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import xarray as xr
from pytensor.tensor import TensorVariable

from pymc_marketing.clv.plotting import (
from pymc_marketing.clv import (
plot_customer_exposure,
plot_expected_purchases,
plot_expected_purchases_ppc,
plot_frequency_recency_matrix,
plot_probability_alive_matrix,
)
Expand All @@ -29,6 +30,7 @@
class MockModel:
def __init__(self, data: pd.DataFrame):
self.data = data
self._model_type = None

def _mock_posterior(self, data: pd.DataFrame) -> xr.DataArray:
n_customers = len(data)
Expand Down Expand Up @@ -178,3 +180,34 @@ def test_plot_expected_purchases(

# clear any existing pyplot figures
plt.clf()


def test_plot_expected_purchases_ppc_exceptions(fitted_bg, fitted_pnbd):
with pytest.raises(
AttributeError, match="BetaGeoModel is unsupported for this function."
):
plot_expected_purchases_ppc(fitted_bg)

with pytest.raises(
NameError, match="Specify 'prior' or 'posterior' for 'ppc' parameter."
):
plot_expected_purchases_ppc(fitted_pnbd, ppc="ppc")


@pytest.mark.parametrize(
"ppc, max_purchases, samples, subplot",
[("prior", 10, 100, None), ("posterior", 20, 50, plt.subplot())],
)
def test_plot_expected_purchases_ppc(fitted_pnbd, ppc, max_purchases, samples, subplot):
ax = plot_expected_purchases_ppc(
model=fitted_pnbd,
ppc=ppc,
max_purchases=max_purchases,
samples=samples,
ax=subplot,
)

assert isinstance(ax, plt.Axes)

# clear any existing pyplot figures
plt.clf()
53 changes: 1 addition & 52 deletions tests/clv/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import xarray
from pandas.testing import assert_frame_equal

from pymc_marketing.clv import BetaGeoModel, GammaGammaModel, ParetoNBDModel
from pymc_marketing.clv import GammaGammaModel, ParetoNBDModel
from pymc_marketing.clv.utils import (
_expected_cumulative_transactions,
_find_first_transactions,
Expand Down Expand Up @@ -59,57 +59,6 @@ def test_to_xarray():
np.testing.assert_array_equal(new_y.coords["test_dim"], customer_id)


@pytest.fixture(scope="module")
def fitted_bg(test_summary_data) -> BetaGeoModel:
rng = np.random.default_rng(13)

model_config = {
# Narrow Gaussian centered at MLE params from lifetimes BetaGeoFitter
"a_prior": Prior("DiracDelta", c=1.85034151),
"alpha_prior": Prior("DiracDelta", c=1.86428187),
"b_prior": Prior("DiracDelta", c=3.18105431),
"r_prior": Prior("DiracDelta", c=0.16385072),
}
model = BetaGeoModel(
data=test_summary_data,
model_config=model_config,
)
model.build_model()
fake_fit = pm.sample_prior_predictive(
samples=50, model=model.model, random_seed=rng
).prior
set_model_fit(model, fake_fit)

return model


@pytest.fixture(scope="module")
def fitted_pnbd(test_summary_data) -> ParetoNBDModel:
rng = np.random.default_rng(45)

model_config = {
# Narrow Gaussian centered at MLE params from lifetimes ParetoNBDFitter
"r_prior": Prior("DiracDelta", c=0.560),
"alpha_prior": Prior("DiracDelta", c=10.591),
"s_prior": Prior("DiracDelta", c=0.550),
"beta_prior": Prior("DiracDelta", c=9.756),
}
pnbd_model = ParetoNBDModel(
data=test_summary_data,
model_config=model_config,
)
pnbd_model.build_model()

# Mock an idata object for tests requiring a fitted model
# TODO: This is quite slow. Check similar fixtures in the model tests to speed this up.
fake_fit = pm.sample_prior_predictive(
samples=50, model=pnbd_model.model, random_seed=rng
).prior
set_model_fit(pnbd_model, fake_fit)

return pnbd_model


@pytest.fixture(scope="module")
def fitted_gg(test_summary_data) -> GammaGammaModel:
rng = np.random.default_rng(40)
Expand Down
62 changes: 61 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from arviz import InferenceData
from xarray import DataArray, Dataset

from pymc_marketing.clv.models import CLVModel
from pymc_marketing.clv.models import BetaGeoModel, CLVModel, ParetoNBDModel
from pymc_marketing.prior import Prior


def pytest_addoption(parser):
Expand Down Expand Up @@ -152,3 +153,62 @@ def mock_fit_MAP(self, *args, **kwargs):
idata = mock_sample(*args, **kwargs, chains=chains, draws=draws, model=self.model)

return idata.sel(chain=[0], draw=[0])


# TODO: This fixture is used in the plotting and utils test modules.
# Consider creating a MockModel class to replace this and other fitted model fixtures.
@pytest.fixture(scope="module")
def fitted_bg(test_summary_data) -> BetaGeoModel:
rng = np.random.default_rng(13)

model_config = {
# Narrow Gaussian centered at MLE params from lifetimes BetaGeoFitter
"a_prior": Prior("DiracDelta", c=1.85034151),
"alpha_prior": Prior("DiracDelta", c=1.86428187),
"b_prior": Prior("DiracDelta", c=3.18105431),
"r_prior": Prior("DiracDelta", c=0.16385072),
}
model = BetaGeoModel(
data=test_summary_data,
model_config=model_config,
)
model.build_model()
fake_fit = pm.sample_prior_predictive(
samples=50, model=model.model, random_seed=rng
)
# posterior group required to pass L80 assert check
fake_fit.add_groups(posterior=fake_fit.prior)
set_model_fit(model, fake_fit)

return model


# TODO: This fixture is used in the plotting and utils test modules.
# Consider creating a MockModel class to replace this and other fitted model fixtures.
@pytest.fixture(scope="module")
def fitted_pnbd(test_summary_data) -> ParetoNBDModel:
rng = np.random.default_rng(45)

model_config = {
# Narrow Gaussian centered at MLE params from lifetimes ParetoNBDFitter
"r_prior": Prior("DiracDelta", c=0.560),
"alpha_prior": Prior("DiracDelta", c=10.591),
"s_prior": Prior("DiracDelta", c=0.550),
"beta_prior": Prior("DiracDelta", c=9.756),
}
pnbd_model = ParetoNBDModel(
data=test_summary_data,
model_config=model_config,
)
pnbd_model.build_model()

# Mock an idata object for tests requiring a fitted model
# TODO: This is quite slow. Check similar fixtures in the model tests to speed this up.
fake_fit = pm.sample_prior_predictive(
samples=50, model=pnbd_model.model, random_seed=rng
)
# posterior group required to pass L80 assert check
fake_fit.add_groups(posterior=fake_fit.prior)
set_model_fit(pnbd_model, fake_fit)

return pnbd_model
Loading