Skip to content

Commit

Permalink
update roc
Browse files Browse the repository at this point in the history
  • Loading branch information
rkansal47 committed May 22, 2024
1 parent 4946784 commit bc5e762
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 222 deletions.
215 changes: 54 additions & 161 deletions src/HHbbVV/postprocessing/InferenceAnalysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@
"metadata": {},
"outputs": [],
"source": [
"# plot_dir = f\"{MAIN_DIR}/plots/TaggerAnalysis/23Aug28\"\n",
"plot_dir = MAIN_DIR / \"plots/BDT/24Apr9\"\n",
"plot_dir = MAIN_DIR / \"plots/TaggerAnalysis/24May20\"\n",
"# plot_dir = MAIN_DIR / \"plots/BDT/24Apr9\"\n",
"plot_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"samples_dir = f\"{MAIN_DIR}/../data/skimmer/24Mar14UpdateData\"\n",
Expand Down Expand Up @@ -129,12 +129,28 @@
"metadata": {},
"outputs": [],
"source": [
"# (column name, number of subcolumns)\n",
"load_columns = [\n",
" (\"weight\", 1),\n",
" (\"weight_noTrigEffs\", 1),\n",
" (\"ak8FatJetPt\", 2),\n",
" (\"ak8FatJetParticleNetMass\", 2),\n",
" (\"VVFatJetParTMD_THWWvsT\", 1),\n",
" (\"VVFatJetParTMD_probHWW4q\", 1),\n",
" (\"VVFatJetParTMD_probHWW3q\", 1),\n",
" (\"VVFatJetParTMD_probQCD\", 1),\n",
" (\"VVFatJetParTMD_probT\", 1),\n",
" # (\"VVFatJetParticleNet_Th4q\", 2),\n",
"]\n",
"\n",
"# Both Jet's Regressed Mass above 50, electron veto\n",
"events_dict = utils.load_samples(\n",
"events_dict = postprocessing.load_samples(\n",
" samples_dir,\n",
" {sig_key: samples[sig_key] for sig_key in nonres_sig_keys},\n",
" {key: samples[key] for key in nonres_sig_keys + [\"QCD\", \"TT\"]},\n",
" year,\n",
" filters=postprocessing.load_filters,\n",
" columns=utils.format_columns(load_columns),\n",
" variations=False,\n",
")"
]
},
Expand All @@ -144,144 +160,9 @@
"metadata": {},
"outputs": [],
"source": [
"# (column name, number of subcolumns)\n",
"save_columns = [\n",
" (\"weight\", 1),\n",
" (\"ak8FatJetPt\", 2),\n",
" (\"ak8FatJetMsd\", 2),\n",
" (\"ak8FatJetParTMD_THWW4q\", 2),\n",
" (\"ak8FatJetParTMD_probHWW4q\", 2),\n",
" (\"ak8FatJetParTMD_probHWW3q\", 2),\n",
" (\"ak8FatJetParTMD_probQCD\", 2),\n",
" (\"ak8FatJetParTMD_probT\", 2),\n",
" (\"ak8FatJetParticleNet_Th4q\", 2),\n",
"]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Signal Processing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for sig_key, events in list(events_dict.items()):\n",
" sig_dict = {}\n",
" masks = events[\"ak8FatJetHVV\"].astype(bool)\n",
"\n",
" for column, num_idx in save_columns:\n",
" if num_idx == 1:\n",
" sig_dict[column] = np.tile(events[column].values, 2)[masks]\n",
" else:\n",
" sig_dict[column] = np.nan_to_num(events[column].values[masks], copy=True, nan=0)\n",
"\n",
" events_dict[sig_key] = sig_dict"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Background Processing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"full_samples_list = listdir(f\"{samples_dir}/{year}\")\n",
"\n",
"# reformat into (\"column name\", \"idx\") format for reading multiindex columns\n",
"bg_column_labels = []\n",
"for key, num_columns in save_columns:\n",
" for i in range(num_columns):\n",
" bg_column_labels.append(f\"('{key}', '{i}')\")\n",
"\n",
"\n",
"bg_keys = [\"TT\", \"QCD\"]\n",
"# bg_keys = [\"QCD\"]\n",
"\n",
"for bg_key in bg_keys:\n",
" events_dict[bg_key] = {}\n",
" for sample in full_samples_list:\n",
" if bg_key not in sample:\n",
" continue\n",
"\n",
" # doesn't have probT for some reason\n",
" if sample in [\"QCD_HT300to500\", \"QCD_HT200to300\"]:\n",
" continue\n",
"\n",
" if \"HH\" in sample or \"GluGluH\" in sample:\n",
" continue\n",
"\n",
" if not exists(f\"{samples_dir}/{year}/{sample}/parquet\"):\n",
" print(f\"No parquet file for {sample}\")\n",
" continue\n",
"\n",
" print(sample)\n",
"\n",
" with utils.timer():\n",
" events = pd.read_parquet(\n",
" f\"{samples_dir}/{year}/{sample}/parquet\",\n",
" columns=bg_column_labels,\n",
" )\n",
"\n",
" pickles_path = f\"{samples_dir}/{year}/{sample}/pickles\"\n",
" n_events = utils.get_nevents(pickles_path, year, sample)\n",
" events[\"weight\"] /= n_events\n",
"\n",
" for var, num_idx in save_columns:\n",
" if num_idx == 1:\n",
" values = np.tile(events[var].values, 2).reshape(-1)\n",
" else:\n",
" values = np.reshape(events[var].values, -1)\n",
"\n",
" if var in events_dict[bg_key]:\n",
" events_dict[bg_key][var] = np.concatenate(\n",
" (events_dict[bg_key][var], values), axis=0\n",
" )\n",
" else:\n",
" events_dict[bg_key][var] = values"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# print weighted sample yields\n",
"for sample in events_dict:\n",
" tot_weight = np.sum(events_dict[sample][\"weight\"])\n",
" print(f\"Pre-selection {sample} yield: {tot_weight:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for sample, events in events_dict.items():\n",
" if \"ak8FatJetParTMD_THWWvsT\" not in events:\n",
" events[\"ak8FatJetParTMD_THWWvsT\"] = (\n",
" events[\"ak8FatJetParTMD_probHWW3q\"] + events[\"ak8FatJetParTMD_probHWW4q\"]\n",
" ) / (\n",
" events[\"ak8FatJetParTMD_probHWW3q\"]\n",
" + events[\"ak8FatJetParTMD_probHWW4q\"]\n",
" + events[\"ak8FatJetParTMD_probQCD\"]\n",
" + events[\"ak8FatJetParTMD_probT\"]\n",
" )"
"cutflow = pd.DataFrame(index=list(events_dict.keys()))\n",
"utils.add_to_cutflow(events_dict, \"Preselection\", \"finalWeight\", cutflow)\n",
"cutflow"
]
},
{
Expand Down Expand Up @@ -311,7 +192,7 @@
"\"\"\"\n",
"\n",
"pt_key = \"Pt\"\n",
"msd_key = \"Msd\"\n",
"msd_key = \"ParticleNetMass\"\n",
"var_prefix = \"ak8FatJet\"\n",
"\n",
"cutvars_dict = {\"Pt\": \"pt\", \"Msd\": \"msoftdrop\"}\n",
Expand All @@ -322,7 +203,7 @@
" # {pt_key: [300, 1500], msd_key: [110, 140]},\n",
"]\n",
"\n",
"var_labels = {pt_key: \"pT\", msd_key: \"mSD\"}\n",
"var_labels = {pt_key: r\"$p_T$\", msd_key: r\"$m_{Reg}$\"}\n",
"\n",
"cuts_dict = {}\n",
"cut_labels = {} # labels for plot titles, formatted as \"var1label: [min, max] var2label...\"\n",
Expand Down Expand Up @@ -363,19 +244,19 @@
"outputs": [],
"source": [
"plot_vars = {\n",
" \"th4q\": {\n",
" \"title\": \"ParticleNet Non-MD Th4q\",\n",
" \"score_label\": \"ak8FatJetParticleNet_Th4q\",\n",
" \"colour\": \"orange\",\n",
" },\n",
" \"thvv4q\": {\n",
" \"title\": \"ParT MD THVV\",\n",
" \"score_label\": \"ak8FatJetParTMD_THWW4q\",\n",
" \"colour\": \"green\",\n",
" },\n",
" # \"th4q\": {\n",
" # \"title\": \"ParticleNet Non-MD Th4q\",\n",
" # \"score_label\": \"ak8FatJetParticleNet_Th4q\",\n",
" # \"colour\": \"orange\",\n",
" # },\n",
" # \"thvv4q\": {\n",
" # \"title\": \"ParT MD THVV\",\n",
" # \"score_label\": \"ak8FatJetParTMD_THWW4q\",\n",
" # \"colour\": \"green\",\n",
" # },\n",
" \"thvv4qt\": {\n",
" \"title\": \"ParT MD THVV\",\n",
" \"score_label\": \"ak8FatJetParTMD_THWWvsT\",\n",
" \"title\": r\"GloParT $T_{HVV}$\",\n",
" \"score_label\": \"VVFatJetParTMD_THWWvsT\",\n",
" \"colour\": \"green\",\n",
" },\n",
"}"
Expand Down Expand Up @@ -434,18 +315,19 @@
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import roc_curve, auc\n",
"from sklearn.metrics import roc_curve, auc, integrate\n",
"\n",
"rocs = {}\n",
"# sig_key = \"HHbbVV\"\n",
"bg_keys = [\"TT\", \"QCD\"]\n",
"bg_skip = 4\n",
"bg_skip = 1\n",
"\n",
"\n",
"for cutstr in cut_labels:\n",
" # print(cutstr)\n",
" rocs[cutstr] = {}\n",
" for sig_key in tqdm(nonres_sig_keys + res_sig_keys):\n",
" # for sig_key in tqdm(nonres_sig_keys + res_sig_keys):\n",
" for sig_key in tqdm(nonres_sig_keys):\n",
" rocs[cutstr][sig_key] = {}\n",
" sig_cut = cuts_dict[sig_key][cutstr]\n",
" bg_cuts = [cuts_dict[bg_key][cutstr] for bg_key in bg_keys]\n",
Expand All @@ -465,7 +347,6 @@
" for bg_key, bg_cut in zip(bg_keys, bg_cuts)\n",
" ],\n",
" )\n",
" # print(weights[np.sum(sig_cut):])\n",
"\n",
" for t, pvars in plot_vars.items():\n",
" score_label = pvars[\"score_label\"]\n",
Expand Down Expand Up @@ -526,10 +407,22 @@
"]\n",
"\n",
"sig_splits = [\n",
" [\"HHbbVV\"] + [f\"X[{mX}]->H(bb)Y[{mY}](WW)\" for (mX, mY) in mps] for mps in sig_split_points\n",
" [\"HHbbVV\"] # + [f\"X[{mX}]->H(bb)Y[{mY}](WW)\" for (mX, mY) in mps] for mps in sig_split_points\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"roc = rocs[cutstr][sig_key][t]\n",
"roc_auc = integrate.trapz(y=roc[\"tpr\"], x=roc[\"fpr\"])\n",
"print(\"AUC:\", roc_auc)\n",
"plotting.rocCurve(roc[\"fpr\"], roc[\"tpr\"], show=True, plot_dir=plot_dir, name=\"THVV\", auc=roc_auc)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
Loading

0 comments on commit bc5e762

Please sign in to comment.