Numpyro quasiposterior for the quasimultinomial model #32

Merged: 5 commits, Nov 22, 2024
263 changes: 263 additions & 0 deletions examples/frequentist_notebook_jax.py
@@ -288,3 +288,266 @@ def format_date(x, pos):


figure_spec.map(plot_city, range(len(cities)))
# -

# ## Quasiposterior modelling
#
# Above we fitted the model using the maximum quasilikelihood approach and then constructed approximate confidence intervals based on the assumed covariance matrix structure, adjusted by the estimated overdispersion factor.
# There is also another way of quantifying uncertainty, based on the generalized Bayesian paradigm, in which the likelihood is replaced by the quasilikelihood.
#
# These two approaches to quantifying uncertainty need not agree; for example, the quasiposterior on the growth advantage estimates may turn out to be asymmetric.
#
# In fact, we also try using a separate overdispersion factor for each city. Let's compare both approaches.
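#
# As a brief sketch (using generic notation, which may not match the exact parameterization inside `qm.construct_model`), the quasiposterior combines a prior $p(\theta)$ with the quasilikelihood in a Gibbs-posterior fashion:
#
# $$p_q(\theta \mid y) \;\propto\; p(\theta)\, \exp\!\left\{ \frac{\ell_q(\theta;\, y)}{\psi} \right\},$$
#
# where $\ell_q$ is the quasi-log-likelihood and $\psi$ is the overdispersion factor; with per-city overdispersion, presumably each city's contribution to $\ell_q$ is scaled by its own $\psi_c$.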

# +
import arviz as az
from numpyro.infer import MCMC, NUTS
from functools import partial


def sample_from_model(share_overdispersion: bool):
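    # Select either the shared (overall) overdispersion estimate or the per-city estimates.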
if share_overdispersion:
_overdispersion = overdispersion_tuple.overall
else:
_overdispersion = overdispersion_tuple.cities

model = qm.construct_model(
ys=ys_effective,
ts=ts_lst_scaled,
overdispersion=_overdispersion,
sigma_offset=100.0,
)

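    # Run NUTS with 4 chains, using 2000 warm-up and 2000 retained draws per chain.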
mcmc = MCMC(NUTS(model), num_chains=4, num_samples=2000, num_warmup=2000)
mcmc.run(jax.random.PRNGKey(42))
return mcmc


mcmc_shared = sample_from_model(share_overdispersion=True)
mcmc_indivi = sample_from_model(share_overdispersion=False)
# -

# Before we proceed with the analysis of the quasiposteriors, let's check whether we can trust the obtained samples (e.g., whether `r_hat` is close to 1 and the effective sample sizes are reasonably large).
#
# **Shared overdispersion**

idata = az.from_numpyro(mcmc_shared)
az.summary(idata, filter_vars="regex", var_names="^r.*")

az.plot_trace(idata, filter_vars="regex", var_names="^r.*")
plt.tight_layout()
plt.show()

# **Individual overdispersion parameters**

idata = az.from_numpyro(mcmc_indivi)
az.summary(idata, filter_vars="regex", var_names="^r.*")

az.plot_trace(idata, filter_vars="regex", var_names="^r.*")
plt.tight_layout()
plt.show()

# If we do not see any sampling problems, we can proceed to interpreting the quasiposterior distributions.
#
# Let's additionally compare both quasiposteriors with the confidence intervals obtained earlier.

# +
from subplots_from_axsize import subplots_from_axsize


def plot_posterior(ax, i, mcmc):
max_quasilikelihood = qm.get_relative_growths(
theta_star, n_variants=n_variants_effective
)
lower = qm.get_relative_growths(
confints_estimates[0], n_variants=n_variants_effective
)
upper = qm.get_relative_growths(
confints_estimates[1], n_variants=n_variants_effective
)

# Plot maximum quasilikelihood and confidence interval bands
ax.axvline(max_quasilikelihood[i], c="k")
ax.axvspan(lower[i], upper[i], alpha=0.3, facecolor="k", edgecolor=None)

# Plot quasiposterior samples using a histogram
samples = mcmc.get_samples()["relative_growths"][:, i]
ax.hist(samples, bins=40, color="maroon")

# Plot the credible interval calculated using quantiles
credibility = 0.95
_a = (1 - credibility) / 2.0
ax.axvline(jnp.quantile(samples, q=_a), c="maroon", linestyle=":")
ax.axvline(jnp.quantile(samples, q=1.0 - _a), c="maroon", linestyle=":")

# Apply some styling
ax.spines[["left", "right", "top"]].set_visible(False)
ax.set_yticks([])


fig, axs = subplots_from_axsize(
ncols=n_variants_effective - 1,
axsize=(2, 0.8),
nrows=2,
sharex="col",
hspace=0.25,
dpi=400,
)

for i in range(n_variants_effective - 1):
plot_posterior(axs[0, i], i, mcmc_shared)
plot_posterior(axs[1, i], i, mcmc_indivi)

axs[0, 0].set_ylabel("Shared")
axs[1, 0].set_ylabel("Individual")

for i, variant in enumerate(variants_effective[1:]):
axs[0, i].set_title(f"Advantage of {variant}")


# -

# We see two things:
#
# - The quasiposterior employing shared overdispersion gives results similar to those obtained with the confidence intervals.
# - When we use individual overdispersion factors (one per city), we see a discrepancy.
#
# Let's compare the predictive plots from the two quasiposteriors with the confidence bands obtained earlier.


# +
def plot_predictions(
ax,
i: int,
*,
fitted_line,
fitted_lower,
fitted_upper,
predicted_line,
predicted_lower,
predicted_upper,
) -> None:
def remove_0th(arr):
"""We don't plot the artificial 0th variant 'other'."""
return arr[:, 1:]

# Plot fits in observed and unobserved time intervals.
plot_ts.plot_fit(ax, ts_lst[i], remove_0th(fitted_line[i]), colors=colors)
plot_ts.plot_fit(
ax, ts_pred_lst[i], remove_0th(predicted_line[i]), colors=colors, linestyle="--"
)
plot_ts.plot_confidence_bands(
ax,
ts_lst[i],
(remove_0th(fitted_lower[i]), remove_0th(fitted_upper[i])),
colors=colors,
)
plot_ts.plot_confidence_bands(
ax,
ts_pred_lst[i],
(remove_0th(predicted_lower[i]), remove_0th(predicted_upper[i])),
colors=colors,
)

# Plot the data points
plot_ts.plot_data(ax, ts_lst[i], remove_0th(ys_effective[i]), colors=colors)

# Plot the complements
plot_ts.plot_complement(ax, ts_lst[i], remove_0th(fitted_line[i]), alpha=0.3)
plot_ts.plot_complement(
ax, ts_pred_lst[i], remove_0th(predicted_line[i]), linestyle="--", alpha=0.3
)

    # Format the axes: dates on the x-axis, proportions on the y-axis
def format_date(x, pos):
return plot_ts.num_to_date(x, date_min=start_date)

date_formatter = ticker.FuncFormatter(format_date)
ax.xaxis.set_major_formatter(date_formatter)
tick_positions = [0, 0.5, 1]
tick_labels = ["0%", "50%", "100%"]
ax.set_yticks(tick_positions)
ax.set_yticklabels(tick_labels)


fig, axs = subplots_from_axsize(
ncols=3,
axsize=(2, 0.8),
nrows=len(cities),
sharex=True,
sharey=True,
hspace=0.4,
dpi=400,
)

for i, city in enumerate(cities):
axs[i, 0].set_ylabel(city)

for ax, name in zip(
axs[0, :], ["Confidence", "Credible (shared)", "Credible (individual)"]
):
ax.set_title(name)


# Plot the quasilikelihood fits
for i, ax in enumerate(axs[:, 0]):
plot_predictions(
ax,
i,
fitted_line=ys_fitted,
fitted_lower=[y.lower for y in ys_fitted_confint],
fitted_upper=[y.upper for y in ys_fitted_confint],
predicted_line=ys_pred,
predicted_lower=[y.lower for y in ys_pred_confint],
predicted_upper=[y.upper for y in ys_pred_confint],
)


# Plot the predictions from the shared-overdispersion quasiposterior


def obtain_predictions(mcmc, _a=0.05):
def get_fit(sample):
theta = qm.construct_theta(
relative_growths=sample["relative_growths"],
relative_midpoints=sample["relative_offsets"],
)

y_fit = qm.fitted_values(
ts_lst_scaled, theta, cities=cities, n_variants=n_variants_effective
)
y_pre = qm.fitted_values(
ts_pred_lst_scaled, theta, cities=cities, n_variants=n_variants_effective
)
return y_fit, y_pre

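    # Point estimate: posterior mean; uncertainty bands: equal-tailed quantile interval at level 1 - _a.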
def get_line(ys):
return jnp.mean(ys, axis=0)

def get_lower(ys):
return jnp.quantile(ys, q=_a / 2, axis=0)

def get_upper(ys):
return jnp.quantile(ys, q=1 - _a / 2, axis=0)

# Apply some thinning for computational speedup
samples = jax.tree.map(lambda x: x[::10, ...], mcmc.get_samples())

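    # Evaluate fitted and predicted trajectories for every retained sample via vmap.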
fits, preds = jax.vmap(get_fit)(samples)
return dict(
fitted_line=jax.tree.map(get_line, fits),
fitted_lower=jax.tree.map(get_lower, fits),
fitted_upper=jax.tree.map(get_upper, fits),
predicted_line=jax.tree.map(get_line, preds),
predicted_lower=jax.tree.map(get_lower, preds),
predicted_upper=jax.tree.map(get_upper, preds),
)


for i, ax in enumerate(axs[:, 1]):
plot_predictions(ax, i, **obtain_predictions(mcmc_shared))

# Plot the predictions from the individual-overdispersion quasiposterior
for i, ax in enumerate(axs[:, 2]):
plot_predictions(ax, i, **obtain_predictions(mcmc_indivi))
# -
8 changes: 7 additions & 1 deletion src/covvfit/_padding.py
@@ -7,7 +7,13 @@


def _is_scalar(value) -> bool:
-    return not hasattr(value, "__len__")
+    try:
+        length = len(value)
+        if length != 0:
+            return False
+        return True
+    except TypeError:
+        return True


def create_padded_array(