Log checkpoint saves at start and finish #12018

Draft
wants to merge 3 commits into base: main
24 changes: 24 additions & 0 deletions examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -46,18 +46,42 @@ def main(cfg) -> None:
logging.info(f"Continual training: loading weights from {cfg.model.restore_from_path}")
from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel

logging.info(
f'Checkpoint load from path {cfg.model.restore_from_path} starts - logging'
)
print(
f'Checkpoint load from path {cfg.model.restore_from_path} starts - print'
)
model_cfg = MegatronGPTSFTModel.merge_cfg_with(cfg.model.restore_from_path, cfg)
model = MegatronGPTModel.restore_from(
restore_path=cfg.model.restore_from_path,
override_config_path=model_cfg,
trainer=trainer,
save_restore_connector=NLPSaveRestoreConnector(),
)
logging.info(
f'Checkpoint load from path {cfg.model.restore_from_path} ends - logging'
)
print(
f'Checkpoint load from path {cfg.model.restore_from_path} ends - print'
)
elif cfg.model.get("restore_from_ckpt") is not None:
# Option 2: Restore both model weights and optimizer states from a PTL checkpoint
logging.info(f"Continual training: loading weights and optimizer states from {cfg.model.restore_from_ckpt}")
logging.info(
f'Checkpoint load from ckpt {cfg.model.restore_from_ckpt} starts - logging'
)
print(
f'Checkpoint load from ckpt {cfg.model.restore_from_ckpt} starts - print'
)
trainer.ckpt_path = Path(cfg.model.restore_from_ckpt)
model = MegatronGPTModel(cfg.model, trainer)
logging.info(
f'Checkpoint load from ckpt {cfg.model.restore_from_ckpt} ends - logging'
)
print(
f'Checkpoint load from ckpt {cfg.model.restore_from_ckpt} ends - print'
)

# Start new pretraining or resume from a checkpoint if it exists
else:
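The hunk above brackets each checkpoint load with paired logging.info and print calls marking start and finish. A minimal sketch of the same start/finish pattern factored into a reusable context manager, purely for illustration — the helper name and the timing are assumptions, not part of this PR:

import logging
import time
from contextlib import contextmanager

@contextmanager
def log_start_finish(action: str):
    # Emit start/finish markers on both the logger and stdout, mirroring the diff above.
    logging.info(f'{action} starts - logging')
    print(f'{action} starts - print')
    t0 = time.monotonic()
    try:
        yield
    finally:
        elapsed = time.monotonic() - t0
        logging.info(f'{action} ends after {elapsed:.1f}s - logging')
        print(f'{action} ends after {elapsed:.1f}s - print')

# Hypothetical usage mirroring the change above:
# with log_start_finish(f'Checkpoint load from path {cfg.model.restore_from_path}'):
#     model = MegatronGPTModel.restore_from(...)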
@@ -1787,7 +1787,19 @@ def setup(self, stage=None):

resume_checkpoint_path = self.trainer.ckpt_path
if resume_checkpoint_path and not self.continue_training:
logging.info(
f'Extract consumed samples from ckpt {resume_checkpoint_path} starts - logging'
)
print(
f'Extract consumed samples from ckpt {resume_checkpoint_path} starts - print'
)
init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path)
logging.info(
f'Extract consumed samples from ckpt {resume_checkpoint_path} ends - logging'
)
print(
f'Extract consumed samples from ckpt {resume_checkpoint_path} ends - print'
)
else:
init_consumed_samples = 0
self.init_consumed_samples = init_consumed_samples
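The new messages here surround _extract_consumed_samples_from_ckpt, which recovers the consumed-sample count from the resume checkpoint path. A rough, hypothetical sketch of that kind of lookup — the 'consumed_samples=' filename convention is an assumption for illustration and is not shown in this diff:

import re

def extract_consumed_samples(ckpt_path: str) -> int:
    # Assumed convention: checkpoint names carry a 'consumed_samples=<number>' token.
    match = re.search(r'consumed_samples=([0-9]+(?:\.[0-9]+)?)', str(ckpt_path))
    return int(float(match.group(1))) if match else 0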
12 changes: 12 additions & 0 deletions nemo/core/connectors/save_restore_connector.py
@@ -257,6 +257,12 @@ def restore_from(
"""
# Get path where the command is executed - the artifacts will be "retrieved" there
# (original .nemo behavior)
logging.info(
f'Connector restores from {restore_path} starts - logging'
)
print(
f'Connector restores from {restore_path} starts - print'
)
loaded_params = self.load_config_and_state_dict(
calling_cls,
restore_path,
@@ -273,6 +279,12 @@
state_dict = self.modify_state_dict(conf, state_dict)
self.load_instance_with_state_dict(instance, state_dict, strict)
logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.')
logging.info(
f'Connector restores from {restore_path} ends - logging'
)
print(
f'Connector restores from {restore_path} ends - print'
)
return instance

def extract_state_dict_from(self, restore_path: str, save_dir: str, split_by_module: bool = False):
14 changes: 14 additions & 0 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -483,6 +483,12 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
# sync save calls the finalization function immediately after save.
finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
if self.async_save:
logging.info(
f'Checkpoint async save for step {trainer.global_step} starts. - logging'
)
print(
f'Checkpoint async save for step {trainer.global_step} starts. - print'
)
checkpoint_io = trainer.strategy.checkpoint_io
from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO

@@ -493,6 +499,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
self.deferred_ckpts_to_remove.append([])
else:
storage_options = None

trainer.save_checkpoint(filepath, save_weights_only, storage_options=storage_options)

if self.always_save_context and is_global_rank_zero():
@@ -501,8 +508,15 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
if self.async_save:
self._last_checkpoint_saved = filepath
logging.info(f'Scheduled async checkpoint save for {filepath}')
logging.info(
f'Checkpoint async save for step {trainer.global_step} ends. - logging'
)
print(
f'Checkpoint async save for step {trainer.global_step} ends. - print'
)
else:
finalize_fn()
logging.info(f'Checkpoint save for step {trainer.global_step} ends')

def _get_finalize_save_checkpoint_callback(
self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int
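Since the async branch only schedules the write, the start/end messages added here bracket the scheduling call rather than the save itself; completion is reported later by finalize_fn. A simplified, hypothetical sketch of that schedule-then-finalize pattern using plain concurrent.futures — it stands in for, and does not reproduce, NeMo's AsyncFinalizableCheckpointIO:

import logging
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)

def schedule_async_save(save_fn, finalize_fn, step: int):
    # Submit the checkpoint write and attach the finalization callback.
    logging.info(f'Checkpoint async save for step {step} starts.')
    future = _executor.submit(save_fn)
    future.add_done_callback(lambda _: finalize_fn())
    # Scheduling is done here; the actual write may still be running in the worker thread.
    logging.info(f'Checkpoint async save for step {step} ends.')
    return future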