[WIP] Training on AMD / ROCm #302

Closed. Wants to merge 58 commits.

Commits (58)
3d7f3c6  rocm (Mar 22, 2024)
0f9f2d3  up (Mar 22, 2024)
0dc757d  Merge remote-tracking branch 'origin/main' into rocm (Jul 22, 2024)
4c11393  update lumi scripts (Jul 22, 2024)
23395b3  remove correct dir (jpata, Jul 24, 2024)
4d2bfc7  update samples (jpata, Jul 24, 2024)
8c07fba  add samples (jpata, Jul 24, 2024)
86a0395  up (jpata, Jul 24, 2024)
3867972  fix supervised key (jpata, Jul 24, 2024)
4bd6c5f  resubmit training (jpata, Jul 24, 2024)
308e154  add missing (jpata, Jul 24, 2024)
38b882e  Merge remote-tracking branch 'origin/fixes_20240724' into rocm (Jul 24, 2024)
c15cf3e  Merge (Jul 24, 2024)
16bfb44  add mpgun (jpata, Jul 29, 2024)
0549b33  plt target and gen separately (jpata, Aug 1, 2024)
f7ac402  separate submission scripts (jpata, Aug 2, 2024)
bbde4a5  add stats (jpata, Aug 2, 2024)
96b1e42  fix softmax bug (jpata, Aug 4, 2024)
6d425d3  Merge remote-tracking branch 'origin/fixes_20240724' into rocm (Aug 4, 2024)
c2da035  Merge remote-tracking branch 'origin/fixes_20240724' into rocm (Aug 4, 2024)
78701f8  add CLD sample (jpata, Aug 8, 2024)
fc26057  add CLD sample (jpata, Aug 8, 2024)
f9683dd  add finetuning script (jpata, Aug 8, 2024)
321e20c  add genjob script (jpata, Aug 8, 2024)
35a3a70  Merge remote-tracking branch 'origin/fixes_20240724' into rocm (Aug 8, 2024)
ad3c825  add cld (Aug 9, 2024)
b552648  use common backbone in model (jpata, Sep 4, 2024)
40befa2  revert common backbone (jpata, Sep 4, 2024)
5d18024  log correction for pt and e (jpata, Sep 4, 2024)
546f44a  fix inference outputs for 0 elems (jpata, Sep 6, 2024)
cc926d0  binned regression (jpata, Sep 6, 2024)
4c170f7  fix (jpata, Sep 6, 2024)
725ed38  fix (jpata, Sep 6, 2024)
9bf4fa2  use only energy bins (jpata, Sep 6, 2024)
9008621  fix bin loss (jpata, Sep 7, 2024)
23322c4  add decoding tokens (jpata, Sep 9, 2024)
a3c94a0  prepare cms tfds 2.3 (jpata, Sep 9, 2024)
7ca47c3  remove mamba, bins configurable (jpata, Sep 10, 2024)
8c96c78  up (jpata, Sep 11, 2024)
5034203  key_padding_mask (jpata, Sep 13, 2024)
bced81a  save attention (jpata, Sep 13, 2024)
bf1364d  fixes for distributed (jpata, Sep 13, 2024)
948b8e8  up (jpata, Sep 14, 2024)
c1acc52  up (jpata, Sep 14, 2024)
7987761  mass term (jpata, Sep 14, 2024)
bb6336c  add jet eta plot (jpata, Sep 14, 2024)
26f70cd  change mass loss coef (jpata, Sep 14, 2024)
8b0432f  runs of sep14 (jpata, Sep 14, 2024)
8c0bdb2  restrict to pos mass (jpata, Sep 15, 2024)
f1c36a0  disable layernorm (jpata, Sep 15, 2024)
949dcf1  change jet ptcut (jpata, Sep 15, 2024)
28406b1  split apart reg losses (jpata, Sep 16, 2024)
81ac447  Merge remote-tracking branch 'origin/fixes_sep6' into rocm (Sep 16, 2024)
fd17174  up (Sep 16, 2024)
640d617  up (Sep 16, 2024)
4ddb48c  up (Sep 16, 2024)
caa3ef6  up (Sep 16, 2024)
5719b54  up (Sep 16, 2024)

Files changed (changes from all commits)

7 changes: 3 additions & 4 deletions mlpf/data_cms/genjob_pu55to75.sh
@@ -75,11 +75,10 @@ ls -lrt
echo "process.RandomNumberGeneratorService.generator.initialSeed = $SEED" >> step2_phase1_new.py
cmsRun step2_phase1_new.py > /dev/null
cmsRun step3_phase1_new.py > /dev/null
#cmsRun $CMSSWDIR/src/Validation/RecoParticleFlow/test/pfanalysis_ntuple.py
mv pfntuple.root pfntuple_${SEED}.root
# python3 ${MLPF_PATH}/mlpf/data_cms/postprocessing2.py --input pfntuple_${SEED}.root --outpath ./
# bzip2 -z pfntuple_${SEED}.pkl
# cp *.pkl.bz2 $OUTDIR/$SAMPLE/raw/
python3 ${MLPF_PATH}/mlpf/data_cms/postprocessing2.py --input pfntuple_${SEED}.root --outpath ./
bzip2 -z pfntuple_${SEED}.pkl
cp *.pkl.bz2 $OUTDIR/$SAMPLE/raw/

#copy ROOT outputs
#cp step2_phase1_new.root $OUTDIR/$SAMPLE/root/step2_${SEED}.root
6 changes: 4 additions & 2 deletions mlpf/data_cms/postprocessing_jobs.py
@@ -33,14 +33,16 @@ def write_script(infiles, outfiles):


samples = [
"/local/joosep/mlpf/cms/20240823_simcluster/pu55to75/TTbar_14TeV_TuneCUETP8M1_cfi",
"/local/joosep/mlpf/cms/20240823_simcluster/nopu/TTbar_14TeV_TuneCUETP8M1_cfi",
# "/local/joosep/mlpf/cms/20240823_simcluster/pu55to75/TTbar_14TeV_TuneCUETP8M1_cfi",
# "/local/joosep/mlpf/cms/20240823_simcluster/pu55to75/QCDForPF_14TeV_TuneCUETP8M1_cfi",
]

ichunk = 1
for sample in samples:
infiles = list(glob.glob(f"{sample}/root/pfntuple*.root"))
for infiles_chunk in chunks(infiles, 10):
outfiles_chunk = [inf.replace(".root", ".pkl.bz2").replace("/root/", "/raw2/") for inf in infiles_chunk]
outfiles_chunk = [inf.replace(".root", ".pkl.bz2").replace("/root/", "/raw/") for inf in infiles_chunk]
os.makedirs(os.path.dirname(outfiles_chunk[0]), exist_ok=True)
scr = write_script(infiles_chunk, outfiles_chunk)
ofname = f"jobscripts/postproc_{ichunk}.sh"
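
As a rough illustration of the pattern above, the sketch below (not part of the PR) chunks the input ROOT files and rewrites their paths into the raw/ output directory. The `chunks` helper here is a hypothetical stand-in for the one the script actually imports, assumed to yield fixed-size slices of a list.

```python
# Hypothetical sketch of the chunking and path-rewriting pattern used in
# postprocessing_jobs.py; `chunks` is an assumed stand-in, not the repo's code.
import glob
import os


def chunks(seq, size):
    # yield consecutive slices of at most `size` items
    for i in range(0, len(seq), size):
        yield seq[i : i + size]


sample = "/local/joosep/mlpf/cms/20240823_simcluster/nopu/TTbar_14TeV_TuneCUETP8M1_cfi"
infiles = sorted(glob.glob(f"{sample}/root/pfntuple*.root"))
for infiles_chunk in chunks(infiles, 10):
    # outputs go next to root/ as bzipped pickles in raw/
    outfiles_chunk = [f.replace(".root", ".pkl.bz2").replace("/root/", "/raw/") for f in infiles_chunk]
    os.makedirs(os.path.dirname(outfiles_chunk[0]), exist_ok=True)
```
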
10 changes: 5 additions & 5 deletions mlpf/data_cms/prepare_args.py
@@ -6,15 +6,15 @@
outdir = "/local/joosep/mlpf/cms/20240823_simcluster"

samples = [
# ("TTbar_14TeV_TuneCUETP8M1_cfi", 105000, 110010, "genjob_pu55to75.sh", outdir + "/pu55to75"),
("TTbar_14TeV_TuneCUETP8M1_cfi", 100000, 120010, "genjob_pu55to75.sh", outdir + "/pu55to75"),
# ("ZTT_All_hadronic_14TeV_TuneCUETP8M1_cfi", 200000, 220010, "genjob_pu55to75.sh", outdir + "/pu55to75"),
("QCDForPF_14TeV_TuneCUETP8M1_cfi", 300000, 305000, "genjob_pu55to75.sh", outdir + "/pu55to75"),
("QCDForPF_14TeV_TuneCUETP8M1_cfi", 300000, 320010, "genjob_pu55to75.sh", outdir + "/pu55to75"),
# ("SMS-T1tttt_mGl-1500_mLSP-100_TuneCP5_14TeV_pythia8_cfi", 500000, 520010, "genjob_pu55to75.sh", outdir + "/pu55to75"),
# ("ZpTT_1500_14TeV_TuneCP5_cfi", 600000, 620010, "genjob_pu55to75.sh", outdir + "/pu55to75"),
# ("VBF_TuneCP5_14TeV_pythia8_cfi", 700000, 720010, "genjob_pu55to75.sh", outdir + "/pu55to75"),
# ("VBF_TuneCP5_14TeV_pythia8_cfi", 700000, 705010, "genjob_pu55to75.sh", outdir + "/pu55to75"),

# ("TTbar_14TeV_TuneCUETP8M1_cfi", 702000, 705000, "genjob_nopu.sh", outdir + "/nopu"),
# ("MultiParticlePFGun50_cfi", 800000, 820000, "genjob_nopu.sh", outdir + "/nopu"),
# ("MultiParticlePFGun50_cfi", 800000, 805000, "genjob_nopu.sh", outdir + "/nopu"),
# ("VBF_TuneCP5_14TeV_pythia8_cfi", 900000, 920010, "genjob_nopu.sh", outdir + "/nopu"),
# ("QCDForPF_14TeV_TuneCUETP8M1_cfi", 1000000,1020010, "genjob_nopu.sh", outdir + "/nopu"),

@@ -36,6 +36,6 @@
os.makedirs(this_outdir + "/" + samp + "/root", exist_ok=True)

for seed in range(seed0, seed1):
p = this_outdir + "/" + samp + "/raw2/pfntuple_{}.pkl.bz2".format(seed)
p = this_outdir + "/" + samp + "/root/pfntuple_{}.root".format(seed)
if not os.path.isfile(p):
print(f"sbatch --mem-per-cpu 8G --partition main --time 20:00:00 --cpus-per-task 1 scripts/tallinn/cmssw-el8.sh mlpf/data_cms/{script} {samp} {seed}")
3 changes: 2 additions & 1 deletion mlpf/heptfds/cms_pf/qcd.py
@@ -21,7 +21,7 @@
class CmsPfQcd(tfds.core.GeneratorBasedBuilder):
"""DatasetBuilder for cms_pf_qcd dataset."""

VERSION = tfds.core.Version("2.1.0")
VERSION = tfds.core.Version("2.3.0")
RELEASE_NOTES = {
"1.3.0": "12_2_0_pre2 generation with updated caloparticle/trackingparticle",
"1.3.1": "Remove PS again",
@@ -33,6 +33,7 @@ class CmsPfQcd(tfds.core.GeneratorBasedBuilder):
"1.7.1": "Increase stats to 400k events",
"2.0.0": "New truth def based primarily on CaloParticles",
"2.1.0": "Additional stats",
"2.3.0": "Split CaloParticles along tracks",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
rsync -r --progress lxplus.cern.ch:/eos/user/j/jpata/mlpf/tensorflow_datasets/cms/cms_pf_qcd ~/tensorflow_datasets/
3 changes: 2 additions & 1 deletion mlpf/heptfds/cms_pf/ttbar.py
@@ -21,7 +21,7 @@
class CmsPfTtbar(tfds.core.GeneratorBasedBuilder):
"""DatasetBuilder for cms_pf dataset."""

VERSION = tfds.core.Version("2.2.0")
VERSION = tfds.core.Version("2.3.0")
RELEASE_NOTES = {
"1.0.0": "Initial release.",
"1.1.0": "Add muon type, fix electron GSF association",
@@ -38,6 +38,7 @@ class CmsPfTtbar(tfds.core.GeneratorBasedBuilder):
"2.0.0": "New truth def based primarily on CaloParticles",
"2.1.0": "Additional stats",
"2.2.0": "Split CaloParticles along tracks",
"2.3.0": "Increase stats",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
rsync -r --progress lxplus.cern.ch:/eos/user/j/jpata/mlpf/tensorflow_datasets/cms/cms_pf_ttbar ~/tensorflow_datasets/
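
Since both dataset builders bump VERSION to 2.3.0, downstream code would typically rebuild the datasets and then request the new version string explicitly. A hedged sketch of that, assuming the builder modules are importable and register datasets under names derived from the class names (cms_pf_ttbar, cms_pf_qcd):

```python
# Hedged sketch: consuming the bumped TFDS version.
# (Typically after rebuilding, e.g. `tfds build mlpf/heptfds/cms_pf/ttbar.py`.)
import tensorflow_datasets as tfds

import mlpf.heptfds.cms_pf.ttbar  # noqa: F401  registers CmsPfTtbar (assumed importable)

# Request 2.3.0 explicitly; this fails if only an older version was built locally.
ds = tfds.load("cms_pf_ttbar:2.3.0", split="train", data_dir="~/tensorflow_datasets")
```
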
8 changes: 4 additions & 4 deletions mlpf/jet_utils.py
@@ -1,24 +1,24 @@
import numpy as np

import numba
# import numba
import awkward
import vector


@numba.njit
# @numba.njit
def deltaphi(phi1, phi2):
diff = phi1 - phi2
return np.arctan2(np.sin(diff), np.cos(diff))


@numba.njit
# @numba.njit
def deltar(eta1, phi1, eta2, phi2):
deta = eta1 - eta2
dphi = deltaphi(phi1, phi2)
return np.sqrt(deta**2 + dphi**2)


@numba.njit
# @numba.njit
def match_jets(jets1, jets2, deltaR_cut):
iev = len(jets1)
jet_inds_1_ev = []
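
With the numba decorators commented out, these helpers run as plain NumPy. The small sanity check below (not part of the PR) illustrates that deltaphi wraps the angular difference into (-pi, pi], so two phi values near opposite ends of the range are treated as close:

```python
# Standalone sketch of the plain-NumPy deltaphi/deltar behavior shown above.
import numpy as np


def deltaphi(phi1, phi2):
    diff = phi1 - phi2
    # arctan2(sin, cos) wraps the difference into (-pi, pi]
    return np.arctan2(np.sin(diff), np.cos(diff))


def deltar(eta1, phi1, eta2, phi2):
    deta = eta1 - eta2
    dphi = deltaphi(phi1, phi2)
    return np.sqrt(deta**2 + dphi**2)


print(deltaphi(3.1, -3.1))          # ~ -0.083, not ~6.2
print(deltar(0.0, 3.1, 0.1, -3.1))  # ~ 0.13
```
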
71 changes: 71 additions & 0 deletions mlpf/plotting/plot_utils.py
@@ -530,6 +530,16 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None,

plt.figure()
b = np.linspace(0, 1000, 100)

pt = awkward.to_numpy(awkward.flatten(yvals["jets_target_pt"]))
plt.hist(
pt,
bins=b,
histtype="step",
lw=2,
label="Target",
)

pt = awkward.to_numpy(awkward.flatten(yvals["jets_cand_pt"]))
plt.hist(
pt,
@@ -580,6 +590,67 @@ def plot_jets(yvals, epoch=None, cp_dir=None, comet_experiment=None, title=None,
)



plt.figure()
b = np.linspace(-5, 5, 100)
eta = awkward.to_numpy(awkward.flatten(yvals["jets_target_eta"]))
plt.hist(
eta,
bins=b,
histtype="step",
lw=2,
label="Target",
)

eta = awkward.to_numpy(awkward.flatten(yvals["jets_cand_eta"]))
plt.hist(
eta,
bins=b,
histtype="step",
lw=2,
label="PF",
)

eta = awkward.to_numpy(awkward.flatten(yvals["jets_pred_eta"]))
plt.hist(
eta,
bins=b,
histtype="step",
lw=2,
label="MLPF",
)

eta = awkward.to_numpy(awkward.flatten(yvals["jets_gen_eta"]))
plt.hist(
eta,
bins=b,
histtype="step",
lw=2,
label="Truth",
)

plt.xlabel("jet $\eta$")
plt.ylabel("Jets / bin")
plt.yscale("log")
plt.legend(loc="best")
if title:
plt.title(title)
ax = plt.gca()
ylim = ax.get_ylim()
ax.set_ylim(ylim[0], 10 * ylim[1])

if dataset:
EXPERIMENT_LABELS[dataset](ax)
if sample:
sample_label(ax, sample)

save_img(
"jet_eta.png",
epoch,
cp_dir=cp_dir,
comet_experiment=comet_experiment,
)

def plot_jet_ratio(
yvals,
epoch=None,
25 changes: 23 additions & 2 deletions mlpf/pyg/PFDataset.py
@@ -70,6 +70,23 @@ def __getitem__(self, item):
ret["ygen"][:, 0][(ret["X"][:, 0] == 10) & (ret["ygen"][:, 0] == 7)] = 2
ret["ygen"][:, 0][(ret["X"][:, 0] == 11) & (ret["ygen"][:, 0] == 7)] = 2

# set pt for HO which would otherwise be 0
msk_ho = ret["X"][:, 0] == 10
eta = ret["X"][:, 2][msk_ho]
e = ret["X"][:, 5][msk_ho]
ret["X"][:, 1][msk_ho] = np.sqrt(e**2 - (np.tanh(eta) * e) ** 2)

# transform pt -> log(pt / elem pt), same for energy
ret["ygen"][:, 6] = np.log(ret["ygen"][:, 6] / ret["X"][:, 5])
ret["ygen"][:, 6][np.isnan(ret["ygen"][:, 6])] = 0.0
ret["ygen"][:, 6][np.isinf(ret["ygen"][:, 6])] = 0.0
ret["ygen"][:, 6][ret["ygen"][:, 0] == 0] = 0

ret["ygen"][:, 2] = np.log(ret["ygen"][:, 2] / ret["X"][:, 1])
ret["ygen"][:, 2][np.isnan(ret["ygen"][:, 2])] = 0.0
ret["ygen"][:, 2][np.isinf(ret["ygen"][:, 2])] = 0.0
ret["ygen"][:, 2][ret["ygen"][:, 0] == 0] = 0

return ret

def __len__(self):
@@ -214,10 +231,14 @@ def get_interleaved_dataloaders(world_size, rank, config, use_cuda, use_ray):
dataset.append(ds)
dataset = torch.utils.data.ConcatDataset(dataset)

shuffle = split == "train"
if world_size > 1:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
else:
sampler = torch.utils.data.SequentialSampler(dataset)

# build dataloaders
batch_size = config[f"{split}_dataset"][config["dataset"]][type_]["batch_size"] * config["gpu_batch_multiplier"]
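
The HO line above uses the identity sqrt(E^2 - (E tanh eta)^2) = E / cosh(eta), i.e. transverse momentum recovered from energy and pseudorapidity, and the regression targets then become log-ratios relative to the input element. A standalone worked sketch of the same transform; the column indices are inferred from the diff and should be treated as assumptions, not the project's exact layout:

```python
# Illustrative sketch of the target transform in PFDataset.__getitem__
# (assumed columns: X = [type, pt, eta, ..., energy at col 5];
#  ygen = [cls, ..., pt at col 2, ..., energy at col 6]).
import numpy as np

X = np.array([[10.0, 0.0, 1.2, 0.0, 0.0, 50.0]])   # one HO element, pt initially 0
ygen = np.zeros((1, 7))
ygen[0, 0] = 2      # target class
ygen[0, 2] = 30.0   # target pt
ygen[0, 6] = 60.0   # target energy

# HO pt from energy and eta: sqrt(E^2 - (E*tanh(eta))^2) == E / cosh(eta)
msk_ho = X[:, 0] == 10
X[msk_ho, 1] = X[msk_ho, 5] / np.cosh(X[msk_ho, 2])

# regression targets become log-ratios w.r.t. the input element
ygen[:, 2] = np.log(ygen[:, 2] / X[:, 1])   # log(pt_target / pt_elem)
ygen[:, 6] = np.log(ygen[:, 6] / X[:, 5])   # log(E_target / E_elem)
ygen[~np.isfinite(ygen)] = 0.0              # guard against nan/inf
ygen[ygen[:, 0] == 0, 2] = 0.0              # no target particle -> zero targets
ygen[ygen[:, 0] == 0, 6] = 0.0
```
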
32 changes: 24 additions & 8 deletions mlpf/pyg/inference.py
@@ -42,13 +42,26 @@ def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_m
batch = batch.to(rank)
ypred = model(batch.X, batch.mask)

# transform log (pt/elempt) -> pt
pred_cls = torch.argmax(ypred[0], axis=-1)
ypred[2][..., 0] = torch.exp(ypred[2][..., 0]) * batch.X[..., 1]
batch.ygen[..., 2] = torch.exp(batch.ygen[..., 2]) * batch.X[..., 1]

# transform log (E/elemE) -> E
ypred[2][..., 4] = torch.exp(ypred[2][..., 4]) * batch.X[..., 5]
batch.ygen[..., 6] = torch.exp(batch.ygen[..., 6]) * batch.X[..., 5]

ypred[2][..., 0][pred_cls == 0] = 0
ypred[2][..., 4][pred_cls == 0] = 0
batch.ygen[..., 2][batch.ygen[..., 0] == 0] = 0
batch.ygen[..., 6][batch.ygen[..., 0] == 0] = 0

# convert all outputs to float32 in case running in float16 or bfloat16
ypred = tuple([y.to(torch.float32) for y in ypred])

ygen = unpack_target(batch.ygen.to(torch.float32))
ycand = unpack_target(batch.ycand.to(torch.float32))
ygen = unpack_target(batch.ygen.to(torch.float32), model)
ycand = unpack_target(batch.ycand.to(torch.float32), model)
ypred = unpack_predictions(ypred)

genjets_msk = batch.genjets[:, :, 0].cpu() != 0
genjets = awkward.unflatten(batch.genjets.cpu().to(torch.float64)[genjets_msk], torch.sum(genjets_msk, axis=1))
genjets = vector.awk(
@@ -79,15 +92,18 @@ def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_m
jets_coll = {}
for typ, ydata in zip(["cand", "target"], [ycand, ygen]):
clsid = awkward.unflatten(ydata["cls_id"], counts)
pt = awkward.unflatten(ydata["pt"], counts)
eta = awkward.unflatten(ydata["eta"], counts)
phi = awkward.unflatten(ydata["phi"], counts)
e = awkward.unflatten(ydata["energy"], counts)
msk = clsid != 0
p4 = awkward.unflatten(ydata["p4"], counts)
vec = vector.awk(
awkward.zip(
{
"pt": p4[msk][:, :, 0],
"eta": p4[msk][:, :, 1],
"phi": p4[msk][:, :, 2],
"e": p4[msk][:, :, 3],
"pt": pt[msk],
"eta": eta[msk],
"phi": phi[msk],
"e": e[msk],
}
)
)
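
At inference, the log-ratio regression is inverted back to absolute pt and energy, and kinematics are zeroed for elements predicted as class 0 ("no particle"). A minimal sketch of that inversion; the regression-output and feature indices are assumptions taken from the diff above, not a definitive API:

```python
# Hedged sketch of undoing the log-ratio regression at inference time
# (assumed: regression output 0 = log pt ratio, 4 = log E ratio;
#  X col 1 = element pt, X col 5 = element energy).
import torch


def invert_regression(pred_logits, pred_reg, X):
    # pred_logits: [batch, elems, n_classes]; pred_reg: [batch, elems, >=5]
    pred_cls = torch.argmax(pred_logits, dim=-1)
    pt = torch.exp(pred_reg[..., 0]) * X[..., 1]   # log(pt / elem_pt) -> pt
    e = torch.exp(pred_reg[..., 4]) * X[..., 5]    # log(E / elem_E)  -> E
    # elements predicted as "no particle" carry no kinematics
    pt = torch.where(pred_cls == 0, torch.zeros_like(pt), pt)
    e = torch.where(pred_cls == 0, torch.zeros_like(e), e)
    return pred_cls, pt, e
```
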