Skip to content

Commit

Permalink
Figure presenting BMMs vs other estimators (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz authored May 14, 2024
1 parent 7694a7d commit 34b4de0
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 24 deletions.
20 changes: 7 additions & 13 deletions scripts/Mixtures/plot_appearing_and_vanishing_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,30 @@ def main() -> None:

X, Y = np.meshgrid(x, y)

fig, axs = plt.subplots(2, 3, dpi=300)
fig, axs = plt.subplots(1, 5, dpi=300, sharex=True, sharey=True, figsize=(5.5, 1.2))

# First row: appearing MI
# Component 1 (bottom left)
ax = axs[0, 0]
ax = axs[0]
mask1 = (0 < X) & (X < 1) & (0 < Y) & (Y < 1)
plot_density(ax, mask1, "$I=0$")

# Component 2 (top right)
ax = axs[0, 1]
ax = axs[1]
mask2 = (1 < X) & (X < 2) & (1 < Y) & (Y < 2)
plot_density(ax, mask2, "$I=0$")

# Mixture
ax = axs[0, 2]
ax = axs[2]
mask3 = mask1 | mask2
plot_density(ax, 0.5 * mask3, "$I=\\log 2$")

# Second row
# Component 1: mixture from first row
ax = axs[1, 0]
plot_density(ax, 0.5 * mask3, "$I=\\log 2$")

# Component 2: symmetric mixture
ax = axs[1, 1]
# A "complementary" mixture
ax = axs[3]
mask4 = (0 < X) & (X < 1) & (1 < Y) & (Y < 2) | (1 < X) & (X < 2) & (0 < Y) & (Y < 1)
plot_density(ax, 0.5 * mask4, "$I=\\log 2$")

# Mixture: independent
ax = axs[1, 2]
ax = axs[4]
plot_density(ax, 0.25 * (mask3 | mask4), "$I=0$")

fig.tight_layout()
Expand Down
7 changes: 4 additions & 3 deletions workflows/projects/Mixtures/distinct_profiles.smk
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ rule plot_samples:
output:
"figure_distinct_profiles.pdf"
run:
fig, axs = plt.subplots(1, 4, figsize=(7, 2))
fig, axs = plt.subplots(1, 4, figsize=(7, 1.5), dpi=500)

color1 = "navy"
color2 = "salmon"
color1 = "mediumblue"
color2 = "forestgreen"

# Plot normal distribution
ax = axs[0]
Expand Down Expand Up @@ -163,6 +163,7 @@ rule plot_samples:
ax.hist(pmi_u, bins=bins, density=True, color=color2, alpha=0.5, label="Mixture")
ax.set_title("PMI profiles")
ax.set_xlabel("PMI")
ax.set_xlim(-1, 2)
ax.set_ylabel("")
ax.set_yticks([])
ax.spines[['right', 'top', 'left']].set_visible(False)
Expand Down
150 changes: 150 additions & 0 deletions workflows/projects/Mixtures/figure_bmm_vs_other.smk
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Figure comparing BMMs and other estimators on selected problems.
# Note: to run this workflow, you need to have the results.csv files from:
# - the benchmark (version 2) in `generated/benchmark/v2/results.csv`
# - the BMM minibenchmark in `generated/projects/Mixtures/gmm_benchmark/results.csv`
from dataclasses import dataclass

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("Agg")

import numpy as np
import pandas as pd
from subplots_from_axsize import subplots_from_axsize



# Default target: requesting it builds the comparison figure.
rule all:
    input:
        "generated/projects/Mixtures/figure_bmm_vs_other.pdf"


class YScaler:
    """Assign vertical plotting positions to estimators sharing one axis.

    The unit interval ``[0, 1]`` is split into ``n`` equal horizontal bands,
    one per estimator, in the order given by ``estimator_ids``. The points
    of an estimator are spread evenly inside its band, leaving a margin of
    ``eps`` at both ends of the band.

    Note:
        The margin is an absolute amount, not a fraction of the band width.
        If ``2 * eps`` exceeds ``1 / n`` the end points cross and ``get_y``
        returns positions in decreasing order.
    """

    def __init__(self, estimator_ids: list[str], eps: float = 0.1):
        """Store the estimator order and the margin.

        Args:
            estimator_ids: identifiers of the estimators; the position in
                this list decides which band an estimator gets.
            eps: strictly positive margin left at both ends of each band.

        Raises:
            ValueError: if ``eps`` is not strictly positive.
        """
        # An explicit check instead of ``assert``: assertions are removed
        # when Python runs with ``-O``. ``not eps > 0`` also rejects NaN.
        if not eps > 0:
            raise ValueError(f"eps must be positive, got {eps!r}")
        self._estimator_ids = estimator_ids
        self._eps = eps

    @property
    def n(self) -> int:
        """Number of estimators, i.e. the number of bands."""
        return len(self._estimator_ids)

    @property
    def offset(self) -> float:
        """Half of ``eps`` divided by the number of bands.

        Not used by the other methods of this class.
        """
        return self._eps * 0.5 * 1 / self.n

    def get_y(self, estimator_id: str, n_points: int) -> np.ndarray:
        """Return ``n_points`` evenly spaced heights inside an estimator's band.

        Args:
            estimator_id: one of the identifiers passed to the constructor.
            n_points: number of heights to generate.

        Returns:
            Array of length ``n_points`` running from ``band_start + eps``
            to ``band_end - eps``.

        Raises:
            ValueError: if ``estimator_id`` is not a known identifier.
        """
        index = self._estimator_ids.index(estimator_id)
        y0 = index / self.n
        y1 = (index + 1) / self.n

        return np.linspace(y0 + self._eps, y1 - self._eps, n_points)

    def get_tick_locations(self) -> np.ndarray:
        """Return the centre height of every band, in estimator order."""
        return (np.arange(self.n, dtype=float) + 0.5) / self.n


@dataclass
class TaskConfig:
name: str
xlim: tuple[float, float]
xticks: list[float] | tuple[float, ...]

@dataclass
class EstimatorConfig:
    """Display settings of a single estimator drawn in the figure."""

    # Identifier matched against the ``estimator_id`` column of the results.
    id: str
    # Label used for the estimator on the y-axis.
    name: str
    # Colour of the estimator's scatter points.
    color: str

# Tasks shown in the figure, keyed by task identifier. The insertion order
# is the left-to-right order of the panels.
TASKS = {
    '1v1-AI': TaskConfig(
        name="AI",
        xlim=(0.5, 0.85),
        xticks=[0.6, 0.7, 0.8],
    ),
    'mult-sparse-w-inliers-5-5-2-2.0-0.2': TaskConfig(
        name="Inliers (5-dim, 0.2)",
        xlim=(0.4, 0.8),
        xticks=[0.45, 0.55, 0.65, 0.75],
    ),
    '5v1-concentric_gaussians-5': TaskConfig(
        name="Concentric (5-dim, 5)",
        xlim=(0.35, 0.75),
        xticks=[0.4, 0.5, 0.6, 0.7],
    ),
    'multinormal-sparse-5-5-2-2.0': TaskConfig(
        name="Normal (5-dim, sparse)",
        xlim=(0.65, 1.15),
        xticks=[0.7, 0.8, 0.9, 1.0, 1.1],
    ),
}


# NAMES = {
# # one-dimensional
# '1v1-additive-0.75': "Additive",
# '1v1-AI': "AI",
# '1v1-X-0.9': "X",
# '2v1-galaxy-0.5-3.0': "Galaxy",
# # Concentric
# '3v1-concentric_gaussians-10': "Concentric (3-dim, 10)",
# '3v1-concentric_gaussians-5': "Concentric (3-dim, 5)",
# '5v1-concentric_gaussians-10': "Concentric (5-dim, 10)",
# '5v1-concentric_gaussians-5': "Concentric (5-dim, 5)",
# # Inliers
# 'mult-sparse-w-inliers-5-5-2-2.0-0.2': "Inliers (5-dim, 0.2)",
# 'mult-sparse-w-inliers-5-5-2-2.0-0.5': "Inliers (5-dim, 0.5)",
# # Multivariate normal
# 'multinormal-dense-5-5-0.5': "Normal (5-dim, dense)",
# 'multinormal-sparse-5-5-2-2.0': "Normal (5-dim, sparse)",
# # Student
# 'asinh-student-identity-1-1-1': "Student (1-dim)",
# 'asinh-student-identity-2-2-1': "Student (2-dim)",
# 'asinh-student-identity-3-3-2': "Student (3-dim)",
# 'asinh-student-identity-5-5-2': "Student (5-dim)",
# }

# TASKS = {
# id_v: TaskConfig(name=name, xlim=(0.2, 1), xticks=[]) for id_v, name in NAMES.items()
# }


# Only benchmark runs with this many samples are shown for the point estimators.
N_SAMPLES = 5_000

# Estimators that return a single number per run; each run becomes one dot.
POINT_ESTIMATORS = [
    EstimatorConfig(
        id="KSG-10",
        name="KSG",
        color="green",
    ),
    EstimatorConfig(
        id="InfoNCE",
        name="InfoNCE",
        color="magenta",
    ),
]

# Marker size passed to ``scatter``.
DOT_SIZE = 7


# Draws one panel per entry of TASKS. Each panel shows the ground-truth MI as
# a dotted vertical line, the BMM estimates with their error bars, and the
# estimates of every estimator in POINT_ESTIMATORS as a strip of dots.
rule generate_figure:
    output: "generated/projects/Mixtures/figure_bmm_vs_other.pdf"
    input:
        v2 = "generated/benchmark/v2/results.csv",
        bmm = "generated/projects/Mixtures/gmm_benchmark/results.csv"
    run:
        data_v2 = pd.read_csv(input.v2)
        data_bmm = pd.read_csv(input.bmm)

        # Ground-truth MI of every task. The aggregation does not depend on
        # the panel, so it is computed once instead of once per task.
        mi_true_by_task = data_v2.groupby("task_id")["mi_true"].mean()

        fig, axs = subplots_from_axsize(1, len(TASKS), (2.3, 0.8), left=0.8, right=0.05, top=0.3, bottom=0.3, dpi=350, wspace=0.05)

        # The BMM occupies the first band; the point estimators follow in order.
        y_scaler = YScaler(estimator_ids=["BMM"] + [config.id for config in POINT_ESTIMATORS], eps=0.12)

        for ax, (task_id, task_config) in zip(axs.ravel(), TASKS.items()):
            ax.set_title(task_config.name)
            ax.set_xlim(*task_config.xlim)
            ax.set_xticks(task_config.xticks)
            ax.set_yticks([])
            ax.set_ylim(-0.05, 1.01)
            ax.spines[["top", "left", "right"]].set_visible(False)

            mi_true = mi_true_by_task[task_id]
            ax.axvline(mi_true, linestyle=":", color="black", linewidth=2)

            # Plot credible intervals from the BMM
            bmm_subtable = data_bmm[(data_bmm["task_id"] == task_id)].copy()
            # errorbar() expects distances from the centre, not absolute bounds.
            bmm_subtable["errorbar_low"] = bmm_subtable["mi_mean"] - bmm_subtable["mi_q_low"]
            bmm_subtable["errorbar_high"] = bmm_subtable["mi_q_high"] - bmm_subtable["mi_mean"]

            y = y_scaler.get_y(estimator_id="BMM", n_points=len(bmm_subtable))
            ax.errorbar(x=bmm_subtable["mi_mean"].values, y=y, xerr=bmm_subtable[["errorbar_low", "errorbar_high"]].T, capsize=3, ls="none", color="darkblue")
            ax.scatter(x=bmm_subtable["mi_mean"].values, y=y, color="darkblue", s=DOT_SIZE)

            # Plot the scatterplot representing estimators
            for estimator_config in POINT_ESTIMATORS:
                estimator_id = estimator_config.id

                index = (data_v2["task_id"] == task_id) & (data_v2["estimator_id"] == estimator_id) & (data_v2["n_samples"] == N_SAMPLES)
                estimates = data_v2[index]["mi_estimate"].values
                y = y_scaler.get_y(estimator_id=estimator_id, n_points=len(estimates))
                ax.scatter(estimates, y, color=estimator_config.color, s=DOT_SIZE, alpha=0.4)

        # Only the leftmost panel carries the estimator labels and a left spine.
        ax = axs[0]
        ax.set_yticks(y_scaler.get_tick_locations(), ["BMM"] + [config.name for config in POINT_ESTIMATORS])
        ax.spines["left"].set_visible(True)

        fig.savefig(str(output))
        # Release the figure once it is written so it does not stay
        # registered with pyplot for the rest of the process.
        plt.close(fig)
20 changes: 12 additions & 8 deletions workflows/projects/Mixtures/fitting_gmm.smk
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ DISTRIBUTIONS = {
rule all:
# For the main part of the manuscript
input:
expand("plots/{dist_name}-{n_points}-10.pdf", dist_name=["AI", "Galaxy"], n_points=[250])
expand("plots/{dist_name}-{n_points}-10.pdf", dist_name=["AI", "Galaxy"], n_points=[500])


rule plots_all:
Expand Down Expand Up @@ -192,20 +192,22 @@ rule plot_pdf:
approx_sample = "approx_samples/{dist_name}-{n_points}-{n_components}-0.npz",
output: "plots/{dist_name}-{n_points}-{n_components}.pdf"
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(1.5, 1.5), top=0.3, wspace=0.3)
fig, axs = subplots_from_axsize(1, 4, axsize=(1.2, 1.2), top=0.3, wspace=[0.3, 0.05, 0.05], left=0.5, right=0.15)

for ax in axs:
ax.spines[['right', 'top']].set_visible(False)

FONTDICT = {'fontsize': 10}

# Visualise true sample
ax = axs[0]
ax.set_title("Ground-truth sample")
ax.set_title("Ground-truth sample", fontdict=FONTDICT)
true_sample = np.load(input.true_sample)
visualise_points(true_sample["xs"], true_sample["ys"], ax)

# Visualise approximate sample
ax = axs[1]
ax.set_title("Simulated sample")
ax.set_title("Simulated sample", fontdict=FONTDICT)
approx_sample = np.load(input.approx_sample)
visualise_points(approx_sample["xs"], approx_sample["ys"], ax)

Expand All @@ -214,7 +216,7 @@ rule plot_pdf:

# Visualise posterior on mutual information
ax = axs[2]
ax.set_title("Posterior MI")
ax.set_title("Posterior MI", fontdict=FONTDICT)
mi_true = np.mean(pmi_true)
mi_approx = np.mean(pmi_approx, axis=1) # (num_mcmc_samples,)
ax.set_xlabel("MI")
Expand All @@ -225,11 +227,13 @@ rule plot_pdf:

# Visualise posterior on profile
ax = axs[3]
ax.set_title("Posterior PMI profile")
ax.set_title("Posterior PMI profile", fontdict=FONTDICT)
ax.set_xlabel("PMI")

min_val = np.min([pmi_true.min(), pmi_approx.min()])
max_val = np.max([pmi_true.max(), pmi_approx.max()])
quantile_min = 0.02
quantile_max = 1 - quantile_min
min_val = np.min([np.quantile(pmi_true, quantile_min), np.quantile(pmi_approx, quantile_min)])
max_val = np.max([np.quantile(pmi_true, quantile_max), np.quantile(pmi_approx, quantile_max)])

bins = np.linspace(min_val, max_val, 50)
for pmi_vals in pmi_approx:
Expand Down

0 comments on commit 34b4de0

Please sign in to comment.