Add support for loading NeMo 2.0 checkpoints
Signed-off-by: Hemil Desai <[email protected]>
hemildesai committed Nov 20, 2024
1 parent 5075f94 commit 06e7170
Showing 1 changed file with 139 additions and 34 deletions.
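In summary, the commit teaches NeMo-Aligner's checkpoint-loading utilities to recognize the NeMo 2.0 on-disk layout (a checkpoint directory containing context/ and weights/ subdirectories plus context/io.json), translate its model config into a NeMo 1.x-style config dict, and remap sharded-tensor keys via a new replace_sharded_tensor_key="module" connector argument. A rough usage sketch under those assumptions (the model class, trainer, and override config below are placeholders, and the load_from_nemo argument order is illustrative, not taken from this diff):

    # Hypothetical caller; only the two helper names come from this commit.
    from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo

    restore_path = "/ckpts/my_model"  # a .nemo archive or a NeMo 2.0 checkpoint dir

    model_cfg = load_and_override_model_config(restore_path, override_cfg)
    model = load_from_nemo(MegatronGPTModel, model_cfg, trainer, restore_path=restore_path)

Both helpers now handle either checkpoint format transparently, so callers do not need to branch on the layout themselves.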
173 changes: 139 additions & 34 deletions nemo_aligner/utils/utils.py
@@ -24,38 +24,39 @@
from copy import deepcopy
from dataclasses import replace
from functools import partial, wraps
-from typing import Iterator, List
+from typing import Any, Iterator, List, Optional
from unittest.mock import patch

import torch
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory
from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator
+from omegaconf import DictConfig, OmegaConf
+from torch.masked import as_masked_tensor

from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin
from nemo.utils import AppState, logging
from nemo.utils.exp_manager import NeMoModelCheckpoint
-from omegaconf import DictConfig, OmegaConf
-from torch.masked import as_masked_tensor

from nemo_aligner.models.nlp.gpt.gpt_reward_model import GPTRewardModel


class CustomSaveRestoreConnector(NLPSaveRestoreConnector):
"""A save connector that will ask the Reward model to not try to load
-    the rm head if load_base_model_only is True
+    the rm head if load_base_model_only is True
"""

-    def __init__(self, *args, load_base_model_only=False, **kwargs):
+    def __init__(self, *args, load_base_model_only=False, replace_sharded_tensor_key: Optional[str] = None, **kwargs):
super().__init__(*args, **kwargs)
self.__load_base_model_only = load_base_model_only
+        self.__replace_sharded_tensor_key = replace_sharded_tensor_key

def restore_from(self, *args, **kwargs):
if not self.__load_base_model_only:
-            return super().restore_from(*args, **kwargs)
+            return super().restore_from(*args, replace_sharded_tensor_key=self.__replace_sharded_tensor_key, **kwargs)

with patch.object(GPTRewardModel, "return_rm_head_in_state_dict", False):
-            output = super().restore_from(*args, **kwargs)
+            output = super().restore_from(*args, replace_sharded_tensor_key=self.__replace_sharded_tensor_key, **kwargs)

return output
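For context, the new keyword is what load_from_nemo (below) passes when it detects a NeMo 2.0 checkpoint. A minimal sketch of that wiring, mirroring the diff (the flag value and path handling are placeholders):

    # Sketch of the connector setup for a NeMo 2.0 checkpoint, per load_from_nemo below.
    connector = CustomSaveRestoreConnector(
        load_base_model_only=False,           # placeholder flag
        replace_sharded_tensor_key="module",  # 2.0 sharded-tensor keys are remapped under "module"
    )
    connector.model_weights_ckpt = "weights"  # 2.0 checkpoints store weights in a "weights" subdir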

@@ -85,18 +86,30 @@ def load_from_nemo(
load_base_model_only=False,
return_updated_cfg=False,
):
"""load a model using nemo checkpoint
"""
connector = CustomSaveRestoreConnector(load_base_model_only=load_base_model_only)
"""load a model using nemo checkpoint"""
assert os.path.exists(restore_path), f"tried to load from {restore_path=} but it does not exist"

+    is_2_0_ckpt = load_2_0_checkpoint_model_config(restore_path) is not None
+    if is_2_0_ckpt:
+        replace_sharded_tensor_key = "module"
+    else:
+        replace_sharded_tensor_key = None
+
+    connector = CustomSaveRestoreConnector(load_base_model_only=load_base_model_only, replace_sharded_tensor_key=replace_sharded_tensor_key)
+
+    if is_2_0_ckpt:
+        connector.model_weights_ckpt = "weights"

# if we gave it a directory, then load as if it was extracted already
if os.path.isdir(restore_path):
connector.model_extracted_dir = restore_path

if modify_config_fn is not None:
origin_cfg = cls.restore_from(
-            restore_path=restore_path, trainer=trainer, return_config=True, save_restore_connector=connector,
+            restore_path=restore_path,
+            trainer=trainer,
+            return_config=True,
+            save_restore_connector=connector,
)
model_cfg = modify_config_fn(origin_cfg, model_cfg, add_cfg_to_tree=False)

@@ -111,8 +124,7 @@ def load_from_nemo(


def load_checkpoint_model_config(restore_path):
"""load only the model config from a checkpoint
"""
"""load only the model config from a checkpoint"""
config_name_in_ckpt = NLPSaveRestoreConnector()._model_config_yaml
assert os.path.exists(restore_path), f"tried to load from {restore_path=} but it does not exist"

@@ -131,11 +143,102 @@ def load_checkpoint_model_config(restore_path):
return cfg


+def load_2_0_checkpoint_model_config(restore_path: str):
+    from nemo.lightning import io
+    from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
+    from nemo.lightning.io.pl import ckpt_to_weights_subdir
+
+    if (
+        os.path.isdir(ckpt_to_context_subdir(restore_path))
+        and os.path.isdir(ckpt_to_weights_subdir(restore_path, is_saving=False))
+        and os.path.isfile(os.path.join(ckpt_to_context_subdir(restore_path), "io.json"))
+    ):
+        config = io.load_context(restore_path, subpath="model.config")
+        tokenizer_cfg = OmegaConf.load(os.path.join(ckpt_to_context_subdir(restore_path), "model.yaml")).tokenizer
+
+        def get_tokenizer_args(tokenizer_cfg):
+            if "AutoTokenizer" in tokenizer_cfg._target_:
+                tokenizer_type = "huggingface"
+                tokenizer_name = tokenizer_cfg.pretrained_model_name
+                if os.path.isfile(os.path.join(ckpt_to_context_subdir(restore_path), tokenizer_name)) or os.path.isdir(
+                    os.path.join(ckpt_to_context_subdir(restore_path), tokenizer_name)
+                ):
+                    tokenizer_name = os.path.join(ckpt_to_context_subdir(restore_path), "nemo_tokenizer")
+                elif not os.path.isfile(tokenizer_name):
+                    raise FileNotFoundError(f"Tokenizer file {tokenizer_name} not found")
+
+                return {
+                    "library": tokenizer_type,
+                    "type": tokenizer_name,
+                    "use_fast": True,
+                }
+            elif "SentencePieceTokenizer" in tokenizer_cfg._target_:
+                tokenizer_type = "sentencepiece"
+                tokenizer_name = tokenizer_cfg.model_path
+                if os.path.isfile(os.path.join(ckpt_to_context_subdir(restore_path), tokenizer_name)) or os.path.isdir(
+                    os.path.join(ckpt_to_context_subdir(restore_path), tokenizer_name)
+                ):
+                    tokenizer_name = os.path.join(ckpt_to_context_subdir(restore_path), tokenizer_name)
+                elif not os.path.isfile(tokenizer_name):
+                    raise FileNotFoundError(f"Tokenizer file {tokenizer_name} not found")
+
+                return {"library": tokenizer_type, "type": None, "model": tokenizer_name}
+            else:
+                raise ValueError(f"Unknown tokenizer type: {tokenizer_cfg}")
+
+        tokenizer_args = get_tokenizer_args(tokenizer_cfg)
+
+        config_dict = {}
+        for k, v in config.__dict__.items():
+            if isinstance(v, (float, int, str, bool)):
+                config_dict[k] = v
+            elif k == "activation_func":
+                config_dict["activation"] = v.__name__
+
+        if config_dict["activation"] == "silu":
+            config_dict["activation"] = "fast-swiglu"
+
+        config_dict["encoder_seq_length"] = config_dict["seq_length"]
+
+        config_dict["mcore_gpt"] = True
+        config_dict["max_position_embeddings"] = config_dict.get("seq_length")
+        config_dict["tokenizer"] = tokenizer_args
+
+        try:
+            strategy: dict[str, Any] = io.load_context(restore_path, subpath="trainer.strategy").__dict__
+            config_dict["gradient_as_bucket_view"] = strategy.get("gradient_as_bucket_view", True)
+            # TODO: Add any other parameters required from strategy here
+        except Exception:
+            # Default to True based on default values in https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/llm/recipes
+            config_dict["gradient_as_bucket_view"] = True
+
+        try:
+            precision_plugin: dict[str, Any] = io.load_context(restore_path, subpath="trainer.plugins").__dict__
+            config_dict["fp16"] = precision_plugin.get("fp16", False)
+            config_dict["bf16"] = precision_plugin.get("bf16", True)
+            # TODO: Add any other parameters required from precision plugin here
+        except Exception:
+            # Default to fp16=False, bf16=True based on default values in https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/llm/recipes
+            config_dict["fp16"] = False
+            config_dict["bf16"] = True
+
+        if not os.path.isfile(os.path.join(restore_path, "model_config.yaml")):
+            OmegaConf.save(config=OmegaConf.create(config_dict), f=os.path.join(restore_path, "model_config.yaml"))
+
+        return config_dict
+
+    return None
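The helper doubles as a format detector: it returns a plain config dict for a NeMo 2.0 checkpoint and None for anything else, which is how the callers below choose a code path. A minimal sketch of that contract (the path is a placeholder):

    cfg = load_2_0_checkpoint_model_config("/ckpts/my_model")
    if cfg is not None:
        # NeMo 2.0 layout found (context/, weights/, context/io.json); cfg is a
        # 1.x-style dict, e.g. cfg["encoder_seq_length"], cfg["tokenizer"]["library"].
        ...
    else:
        cfg = load_checkpoint_model_config("/ckpts/my_model")  # fall back to the 1.x path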


def load_and_override_model_config(restore_path, model_cfg_to_overwrite, remove_meta_info=True):
"""load the config in the model checkpoint and then overwrite it
-    with whatever is provided
+    with whatever is provided
"""
-    checkpoint_cfg = load_checkpoint_model_config(restore_path)
+    checkpoint_cfg_2_0 = load_2_0_checkpoint_model_config(restore_path)
+    if checkpoint_cfg_2_0 is not None:
+        checkpoint_cfg = checkpoint_cfg_2_0
+    else:
+        checkpoint_cfg = load_checkpoint_model_config(restore_path)

if remove_meta_info:
checkpoint_cfg.pop("target", None)
@@ -264,8 +367,7 @@ def select_log_probs(full_log_probs, indices):


def dist_adam_load_state_bucket_into_device(state_bucket, device):
"""put the state bucket onto a device
"""
"""put the state bucket onto a device"""
attrs_to_offload = ["params_shard", "param_remainders_shard", "exp_avg_shard", "exp_avg_sq_shard"]

for attr in attrs_to_offload:
@@ -276,8 +378,7 @@ def dist_adam_load_state_bucket_into_device(state_bucket, device):

@contextmanager
def offload_distributed_adam(state_dict, force_clear_memory=False):
"""context manager to offload distributed adam states
"""
"""context manager to offload distributed adam states"""
# off load onto cpu
for state_bucket in state_dict["state"]["buckets"]:
dist_adam_load_state_bucket_into_device(state_bucket, device="cpu")
@@ -302,7 +403,15 @@ def offload_distributed_adam(state_dict, force_clear_memory=False):

def batch_pad_to_fixed_len(batch, max_batch_len, pad_token):
batch_pad = torch.stack(
-        [torch.cat([seq, torch.full((max_batch_len - len(seq),), pad_token, dtype=seq.dtype),]) for seq in batch]
+        [
+            torch.cat(
+                [
+                    seq,
+                    torch.full((max_batch_len - len(seq),), pad_token, dtype=seq.dtype),
+                ]
+            )
+            for seq in batch
+        ]
)

return batch_pad
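The padding helper above is only reflowed, not changed; as a quick illustration of what it computes (assumed example, not part of the commit):

    batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    padded = batch_pad_to_fixed_len(batch, max_batch_len=4, pad_token=0)
    # padded == tensor([[1, 2, 3, 0],
    #                   [4, 5, 0, 0]])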
@@ -317,8 +426,7 @@ def collate_with_batch_max_sequence_length(
eod_mask_loss,
generate_masks_and_position_ids,
):
"""collate function that batches by max sequence length
"""
"""collate function that batches by max sequence length"""
texts = [item["text"] for item in data_batch]
loss_multipliers = torch.as_tensor([item["loss_multiplier"] for item in data_batch]).view(len(data_batch), 1)
lengths = torch.as_tensor([item["length"] for item in data_batch])
@@ -367,8 +475,7 @@ def clear_memory():


def retrieve_model_state_dict_in_cpu(model, megatron_amp_O2=True):
"""get a copy of the model states in CPU
"""
"""get a copy of the model states in CPU"""
cpu_dict = {}

for name, item in model.state_dict().items():
@@ -387,7 +494,7 @@ def retrieve_model_state_dict_in_cpu(model, megatron_amp_O2=True):
@torch.no_grad()
def copy_model_states_to_cpu(model, cpu_dict=None, megatron_amp_O2=True, sync=True, alias_non_tensor=False):
"""This function mutates the cpu_dict object to throw the model states into preallocated tensors(if they exist)
-    for non tensors it will do a deepcopy, unless alias_non_tensor is True
+    for non tensors it will do a deepcopy, unless alias_non_tensor is True
"""
if cpu_dict is None:
cpu_dict = {}
Expand Down Expand Up @@ -416,7 +523,7 @@ def copy_model_states_to_cpu(model, cpu_dict=None, megatron_amp_O2=True, sync=Tr
@torch.no_grad()
def swap_dict(resident_model, cpu_weights, offload_onto_cpu=True, megatron_amp_O2=True):
"""swap the state dict with a specified state dict, and offload the current state dict onto CPU
-    if needed
+    if needed
"""
offloaded_weights = {}

@@ -429,8 +536,7 @@ def swap_dict(resident_model, cpu_weights, offload_onto_cpu=True, megatron_amp_O2=True):

@contextmanager
def cpu_weight_swap(resident_model, cpu_weights, megatron_amp_O2=True):
"""swap the weights into GPU, and then swap it out once return
"""
"""swap the weights into GPU, and then swap it out once return"""
cpu_dict = swap_dict(resident_model, cpu_weights, megatron_amp_O2=megatron_amp_O2)
try:
yield
@@ -441,8 +547,7 @@ def cpu_weight_swap(resident_model, cpu_weights, megatron_amp_O2=True):

@contextmanager
def adapter_control(model):
"""Temporarily disable adapters and re-enable them after the operation
"""
"""Temporarily disable adapters and re-enable them after the operation"""
try:
# Disable adapters before yielding control
for _, module in model.named_modules():
@@ -458,7 +563,7 @@ def convert_to_amp_o2_format(state_dict):

def convert_to_amp_o2_format(state_dict):
"""when amp_o2 is enabled, the model gets wrapped in a Float16Module which changes
-    the keys and how it loads need to add module onto it
+    the keys and how it loads need to add module onto it
"""
new_state_dict = {}

@@ -473,11 +578,11 @@ def get_iterator_k_split_list(batch: List[str], num_microbatches: int) -> Iterator:
def get_iterator_k_split_list(batch: List[str], num_microbatches: int) -> Iterator:
"""
Generate an iterator to split a list into microbatches of equal size.
Args:
batch (List[str]): The list to be split into microbatches.
num_microbatches (int): The number of microbatches to split the list into.
Returns:
Iterator: An iterator that yields the microbatches.
"""
