Skip to content

Commit

Permalink
Fix unpacking for 3d padded batch, update plot style (#306)
Browse files Browse the repository at this point in the history
* fix unpacking

* fix plot labels
  • Loading branch information
jpata authored Apr 8, 2024
1 parent 8420a5d commit 4caf602
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 19 deletions.
27 changes: 9 additions & 18 deletions mlpf/plotting/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def get_fake(df, pid):


def experiment_label(
ax, experiment="CMS", tag1="Simulation Preliminary", tag2="Run 3 (14 TeV)", x0=0.01, x1=0.15, x2=0.98, y=1.01
ax, experiment="CMS", tag1="Simulation Preliminary", tag2="Run 3 (14 TeV)", x0=0.01, x1=0.17, x2=0.98, y=1.01
):
plt.figtext(
x0,
Expand Down Expand Up @@ -238,15 +238,15 @@ def experiment_label(


def cms_label(ax):
return experiment_label(ax, experiment="CMS", tag1="Simulation Preliminary", tag2="Run 3 (14 TeV)")
return experiment_label(ax, experiment="CMS", tag1="Simulation (Private Work)", tag2="Run 3 (14 TeV)", x1=0.13)


def clic_label(ax):
return experiment_label(ax, experiment="Key4HEP-CLICdp", tag1="Simulation Preliminary", tag2="ee (380 GeV)")
return experiment_label(ax, experiment="Key4HEP-CLICdp", tag1="Simulation", tag2="ee (380 GeV)", x1=0.35)


def delphes_label(ax):
return experiment_label(ax, experiment="Delphes-CMS", tag1="Simulation Preliminary", tag2="pp (14 TeV)")
return experiment_label(ax, experiment="Delphes-CMS", tag1="Simulation", tag2="pp (14 TeV)", x1=0.30)


EXPERIMENT_LABELS = {
Expand Down Expand Up @@ -450,7 +450,6 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None,
bins=b,
histtype="step",
lw=2,
# label="PF $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n),
label="PF",
)

Expand All @@ -460,7 +459,6 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None,
bins=b,
histtype="step",
lw=2,
# label="MLPF $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n),
label="MLPF",
)

Expand All @@ -470,7 +468,6 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None,
bins=b,
histtype="step",
lw=2,
# label="Gen $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n),
label="Truth",
)

Expand Down Expand Up @@ -507,7 +504,6 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None,
bins=b,
histtype="step",
lw=2,
# label="PF $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n),
label="PF",
)

Expand All @@ -517,7 +513,6 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None,
bins=b,
histtype="step",
lw=2,
# label="MLPF $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n),
label="MLPF",
)

Expand All @@ -527,7 +522,6 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None,
bins=b,
histtype="step",
lw=2,
# label="Gen $(M={:.2f}, IQR={:.2f}, N={})$".format(p[0], p[1], n),
label="Truth",
)

Expand Down Expand Up @@ -586,7 +580,7 @@ def plot_jet_ratio(
bins=bins,
histtype="step",
lw=2,
label="MLPF $(M={:.2f}\pm{:.2f})$".format(p[0], p[1]),
label="MLPF $({:.2f}\pm{:.2f})$".format(p[0], p[1]),
)
plt.xlabel(labels["reco_gen_jet_ratio"])
plt.ylabel("Matched jets / bin")
Expand Down Expand Up @@ -639,23 +633,20 @@ def plot_met(met_ratio, epoch=None, cp_dir=None, comet_experiment=None, title=No
bins=b,
histtype="step",
lw=2,
# label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]),
label="PF",
)
plt.hist(
met_ratio["pred_met"],
bins=b,
histtype="step",
lw=2,
# label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]),
label="MLPF",
)
plt.hist(
met_ratio["gen_met"],
bins=b,
histtype="step",
lw=2,
# label="Truth $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]),
label="Truth",
)
plt.xlabel(labels["met"])
Expand Down Expand Up @@ -735,15 +726,15 @@ def plot_met_ratio(
bins=bins,
histtype="step",
lw=2,
label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]),
label="PF $({:.2f}, IQR={:.2f})$".format(p[0], p[1]),
)
p = med_iqr(met_ratio["ratio_pred"])
plt.hist(
met_ratio["ratio_pred"],
bins=bins,
histtype="step",
lw=2,
label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]),
label="MLPF $({:.2f}, IQR={:.2f})$".format(p[0], p[1]),
)
plt.xlabel(labels["reco_gen_met_ratio"])
plt.ylabel("Events / bin")
Expand Down Expand Up @@ -783,15 +774,15 @@ def plot_3dmomentum_ratio(
bins=bins,
histtype="step",
lw=2,
label="PF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]),
label="PF $({:.2f}, IQR={:.2f})$".format(p[0], p[1]),
)
p = med_iqr(mom_ratio["ratio_pred"])
plt.hist(
mom_ratio["ratio_pred"],
bins=bins,
histtype="step",
lw=2,
label="MLPF $(M={:.2f}, IQR={:.2f})$".format(p[0], p[1]),
label="MLPF $({:.2f}, IQR={:.2f})$".format(p[0], p[1]),
)
plt.xlabel(labels["reco_gen_mom_ratio"])
plt.ylabel("Events / bin")
Expand Down
2 changes: 1 addition & 1 deletion mlpf/pyg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def unpack_target(y):
# note ~ momentum = ["pt", "eta", "sin_phi", "cos_phi", "energy"]
ret["momentum"] = y[..., 2:7].to(dtype=torch.float32)
ret["p4"] = torch.cat(
[ret["pt"].unsqueeze(1), ret["eta"].unsqueeze(1), ret["phi"].unsqueeze(1), ret["energy"].unsqueeze(1)], axis=1
[ret["pt"].unsqueeze(-1), ret["eta"].unsqueeze(-1), ret["phi"].unsqueeze(-1), ret["energy"].unsqueeze(-1)], axis=-1
)

ret["genjet_idx"] = y[..., -1].long()
Expand Down

0 comments on commit 4caf602

Please sign in to comment.