tflops callback and flag: ESM2 #633

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions ci/benchmarks/perf/esm2_pretrain.yaml
@@ -62,4 +62,5 @@ script: |-
 --accumulate-grad-batches=${acc_grad} \
 --pipeline-model-parallel-size=${pp} \
 --tensor-model-parallel-size=${tp} \
+--log-tflops-per-sec-per-gpu \
 --disable-checkpointing;
92 changes: 57 additions & 35 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import argparse
+from dataclasses import asdict
 from pathlib import Path
 from typing import List, Optional, Sequence, get_args
 
@@ -24,6 +25,7 @@
 from nemo.collections import llm
 from nemo.lightning import resume
 from nemo.lightning.pytorch import callbacks as nl_callbacks
+from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule
 
 from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
@@ -98,6 +100,7 @@ def main(
     overlap_param_gather: bool = True,
     average_in_collective: bool = True,
     grad_reduce_in_fp32: bool = False,
+    log_tflops_per_sec_per_gpu: bool = False,
 ) -> None:
     """Train an ESM2 model on UR data.
 
@@ -159,6 +162,7 @@
         overlap_param_gather (bool): overlap parameter gather
         average_in_collective (bool): average in collective
         grad_reduce_in_fp32 (bool): gradient reduction in fp32
+        log_tflops_per_sec_per_gpu (bool): Enable the FLOPs measurement callback to log teraFLOPs per second per GPU
     """
     # Create the result directory if it does not exist.
     result_dir.mkdir(parents=True, exist_ok=True)
@@ -210,41 +214,6 @@ def main(
         )
     )
 
-    callbacks = [
-        PerplexityLoggingCallback(log_train=False, log_val=True),
-        RichModelSummary(max_depth=4),
-        LearningRateMonitor(),
-        nl_callbacks.PreemptionCallback(),
-    ]
-    if nsys_profiling:
-        if nsys_end_step is None:
-            nsys_end_step = num_steps
-        callbacks.append(
-            nl_callbacks.NsysCallback(
-                start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
-            )
-        )
-
-    trainer = nl.Trainer(
-        devices=devices,
-        max_steps=num_steps,
-        accelerator="gpu",
-        strategy=strategy,
-        limit_val_batches=limit_val_batches,  # This controls upsampling and downsampling
-        val_check_interval=val_check_interval,
-        log_every_n_steps=log_every_n_steps,
-        num_nodes=num_nodes,
-        callbacks=callbacks,
-        plugins=nl.MegatronMixedPrecision(
-            precision=precision,
-            params_dtype=get_autocast_dtype(precision),
-            pipeline_dtype=get_autocast_dtype(precision),
-            grad_reduce_in_fp32=grad_reduce_in_fp32,
-            autocast_enabled=False,
-        ),
-        enable_checkpointing=create_checkpoint_callback,
-    )
-
     tokenizer = get_tokenizer()
 
     # Initialize the data module.
@@ -303,6 +272,50 @@ def main(
         ),
     )
 
+    callbacks = [
+        PerplexityLoggingCallback(log_train=False, log_val=True),
+        RichModelSummary(max_depth=4),
+        LearningRateMonitor(),
+        nl_callbacks.PreemptionCallback(),
+    ]
+    if nsys_profiling:
+        if nsys_end_step is None:
+            nsys_end_step = num_steps
+        callbacks.append(
+            nl_callbacks.NsysCallback(
+                start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
+            )
+        )
+
+    if log_tflops_per_sec_per_gpu:
+        # Add callback that logs the tera-FLOPS per second per GPU during training.
+        flop_meas_callback = FLOPsMeasurementCallback(
+            asdict(esm2_config),
+            data,
+            "bert",
+        )
+        callbacks.append(flop_meas_callback)
+
+    trainer = nl.Trainer(
+        devices=devices,
+        max_steps=num_steps,
+        accelerator="gpu",
+        strategy=strategy,
+        limit_val_batches=limit_val_batches,  # This controls upsampling and downsampling
+        val_check_interval=val_check_interval,
+        log_every_n_steps=log_every_n_steps,
+        num_nodes=num_nodes,
+        callbacks=callbacks,
+        plugins=nl.MegatronMixedPrecision(
+            precision=precision,
+            params_dtype=get_autocast_dtype(precision),
+            pipeline_dtype=get_autocast_dtype(precision),
+            grad_reduce_in_fp32=grad_reduce_in_fp32,
+            autocast_enabled=False,
+        ),
+        enable_checkpointing=create_checkpoint_callback,
+    )
+
     # Configure our custom Checkpointer
     if create_checkpoint_callback:
         checkpoint_callback = nl_callbacks.ModelCheckpoint(
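
For context on what the callback added above reports: a teraFLOPs-per-second-per-GPU figure is, broadly, an estimate of the floating-point operations in one training step divided by the measured step time and the number of participating GPUs. The sketch below shows only that arithmetic; the function name and inputs are hypothetical and this is not the internal implementation of FLOPsMeasurementCallback.

def tflops_per_sec_per_gpu(model_flops_per_step: float, step_time_s: float, num_gpus: int) -> float:
    """Illustrative arithmetic behind a teraFLOPs/sec/GPU readout (hypothetical helper)."""
    # Convert total step FLOPs to a per-GPU, per-second rate, expressed in teraFLOPs (1e12).
    return model_flops_per_step / (step_time_s * num_gpus * 1e12)


# Example: a step estimated at 5.0e15 FLOPs that takes 2.0 s on 8 GPUs corresponds to
# 5.0e15 / (2.0 * 8 * 1e12) = 312.5 TFLOPs/sec/GPU.
print(tflops_per_sec_per_gpu(5.0e15, 2.0, 8))  # 312.5
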
@@ -398,6 +411,7 @@ def train_esm2_entrypoint():
         overlap_param_gather=not args.no_overlap_param_gather,
         average_in_collective=not args.no_average_in_collective,
         grad_reduce_in_fp32=args.grad_reduce_in_fp32,
+        log_tflops_per_sec_per_gpu=args.log_tflops_per_sec_per_gpu,
     )
 
 
@@ -739,6 +753,14 @@ def get_parser():
         action="store_true",
         default=False,
     )
+
+    parser.add_argument(
+        "--log-tflops-per-sec-per-gpu",
+        action="store_true",
+        default=False,
+        help="Enable the FLOPs measurement callback to log teraFLOPs per second per GPU",
+    )
+
     return parser
 
 
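A small usage note on the parser change above: argparse maps the dashed option name to an underscored attribute, which is how the entrypoint reads args.log_tflops_per_sec_per_gpu and forwards it to main(). A self-contained sketch using a bare ArgumentParser (not the real get_parser(), which defines many other required arguments):

import argparse

# Standalone parser mirroring only the flag added in this PR.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--log-tflops-per-sec-per-gpu",
    action="store_true",
    default=False,
    help="Enable the FLOPs measurement callback to log teraFLOPs per second per GPU",
)

# Passing the flag flips the underscored attribute to True; omitting it leaves the default False.
args = parser.parse_args(["--log-tflops-per-sec-per-gpu"])
print(args.log_tflops_per_sec_per_gpu)  # True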