From 45e92b84b86fcc0ce6daac268b3292a13ad8936b Mon Sep 17 00:00:00 2001
From: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
Date: Thu, 6 Feb 2025 08:14:15 +0200
Subject: [PATCH] add cp_comm_type param to Mistral config (#12049)

* add cp_comm_type param for Mistral config

Signed-off-by: dimapihtar

* Apply isort and black reformatting

Signed-off-by: dimapihtar

* fix style

Signed-off-by: dimapihtar

* Apply isort and black reformatting

Signed-off-by: dimapihtar

---------

Signed-off-by: dimapihtar
Signed-off-by: dimapihtar
Co-authored-by: dimapihtar
---
 nemo/collections/llm/gpt/model/mistral.py | 27 ++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/llm/gpt/model/mistral.py b/nemo/collections/llm/gpt/model/mistral.py
index 0aa611b4454e..a964f1efeb63 100644
--- a/nemo/collections/llm/gpt/model/mistral.py
+++ b/nemo/collections/llm/gpt/model/mistral.py
@@ -37,6 +37,10 @@
 @dataclass
 class MistralConfig7B(GPTConfig):
+    """
+    Mistral 7B config.
+    """
+
     normalization: str = "RMSNorm"
     activation_func: Callable = F.silu
     position_embedding_type: str = "rope"
@@ -56,6 +60,7 @@ class MistralConfig7B(GPTConfig):
     init_method_std: float = 0.02
     layernorm_epsilon: float = 1e-5
     window_size: List[int] = field(default_factory=lambda: [4096, 0])
+    cp_comm_type: str = "a2a"
@@ -70,6 +75,7 @@ class MistralNeMoConfig12B(MistralConfig7B):
     seq_length: int = 4096  # but "max_position_embeddings": 1024000,

     window_size: List[int] = None
+    cp_comm_type: str = None
     rotary_percent: float = 1.0
     rotary_base: float = 1000000.0
@@ -88,11 +94,14 @@ class MistralNeMoConfig123B(MistralConfig7B):
     seq_length: int = 4096  # but "max_position_embeddings": 131072,

     window_size: List[int] = None
+    cp_comm_type: str = None
     rotary_percent: float = 1.0
     rotary_base: float = 1000000.0


 class MistralModel(GPTModel):
+    """ """
+
     def __init__(
         self,
         config: Annotated[Optional[MistralConfig7B], Config[MistralConfig7B]] = None,
@@ -107,6 +116,8 @@ def __init__(
 @io.model_importer(MistralModel, "hf")
 class HFMistralImporter(io.ModelConnector["MistralForCausalLM", MistralModel]):
+    """ """
+
     def init(self) -> MistralModel:
         return MistralModel(self.config, tokenizer=self.tokenizer)
@@ -127,6 +138,7 @@ def apply(self, output_path: Path) -> Path:
         return output_path

     def convert_state(self, source, target):
+        """ """
         mapping = {
             "model.embed_tokens.weight": "embedding.word_embeddings.weight",
             "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight",
@@ -141,12 +153,14 @@ def convert_state(self, source, target):
     @property
     def tokenizer(self) -> "AutoTokenizer":
+        """ """
         from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

         return AutoTokenizer(self.save_hf_tokenizer_assets(str(self)))

     @property
     def config(self) -> MistralConfig7B:
+        """ """
         from transformers import MistralConfig

         source = MistralConfig.from_pretrained(str(self))
@@ -157,9 +171,10 @@ def make_vocab_size_divisible_by(mistral_vocab_size):
                 base //= 2
             return base

-        window_size = None
+        window_size, cp_comm_type = (None, None)
         if getattr(source, 'sliding_window', None) is not None:
             window_size = [source.sliding_window, 0]
+            cp_comm_type = 'a2a'
         output = MistralConfig7B(
             seq_length=source.sliding_window,
             num_layers=source.num_hidden_layers,
@@ -175,6 +190,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size):
             gated_linear_unit=True,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
             window_size=window_size,
+            cp_comm_type=cp_comm_type,
             share_embeddings_and_output_weights=False,
             fp16=(dtype_from_hf(source) == torch.float16),
             bf16=(dtype_from_hf(source) == torch.bfloat16),
@@ -186,6 +202,8 @@ def make_vocab_size_divisible_by(mistral_vocab_size):
 @io.model_exporter(MistralModel, "hf")
 class HFMistralExporter(io.ModelConnector[MistralModel, "MistralForCausalLM"]):
+    """ """
+
     def init(self, dtype=torch.bfloat16) -> "MistralForCausalLM":
         from transformers import AutoModelForCausalLM
         from transformers.modeling_utils import no_init_weights
@@ -209,6 +227,7 @@ def apply(self, output_path: Path) -> Path:
         return output_path

     def convert_state(self, source, target):
+        """ """
         mapping = {
             "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
             "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
@@ -226,10 +245,12 @@ def convert_state(self, source, target):
     @property
     def tokenizer(self):
+        """ """
         return io.load_context(str(self)).model.tokenizer.tokenizer

     @property
     def config(self) -> "MistralConfig":
+        """ """
         source: MistralConfig7B = io.load_context(str(self)).model.config

         from transformers import MistralConfig as HfMistralConfig
@@ -259,6 +280,7 @@ def config(self) -> "MistralConfig":
     target_key="decoder.layers.*.self_attention.linear_qkv.weight",
 )
 def _import_qkv(ctx: io.TransformCTX, q, k, v):
+    """ """
     megatron_config = ctx.target.config

     head_num = megatron_config.num_attention_heads
@@ -301,6 +323,7 @@
     ),
 )
 def _export_qkv(ctx: io.TransformCTX, linear_qkv):
+    """ """
     megatron_config = ctx.source.config

     head_num = megatron_config.num_attention_heads
@@ -333,6 +356,7 @@
     target_key="decoder.layers.*.mlp.linear_fc1.weight",
 )
 def _import_linear_fc1(down, gate):
+    """ """
     return torch.cat((down, gate), axis=0)
@@ -341,6 +365,7 @@ def _import_linear_fc1(down, gate):
     target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
 )
 def _export_linear_fc1(linear_fc1):
+    """ """
     gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

     return gate_proj, up_proj
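
Reviewer note (not part of the patch): a minimal usage sketch of the new cp_comm_type field, assuming the patched nemo.collections.llm module and its Megatron-Core dependency are importable in your environment; the "p2p" override at the end is an illustrative Megatron-Core communication mode, not a value this patch sets.

    # Sketch only: exercises the defaults introduced by this patch.
    from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralNeMoConfig12B

    # MistralConfig7B keeps sliding-window attention, so the patch pairs it with
    # all-to-all ("a2a") context-parallel communication by default.
    cfg_7b = MistralConfig7B()
    print(cfg_7b.window_size, cfg_7b.cp_comm_type)    # [4096, 0] a2a

    # The Mistral-NeMo variants disable the sliding window, so the patch leaves
    # cp_comm_type as None and applies no communication override.
    cfg_12b = MistralNeMoConfig12B()
    print(cfg_12b.window_size, cfg_12b.cp_comm_type)  # None None

    # Both remain plain dataclass fields and can be overridden explicitly.
    cfg_custom = MistralConfig7B(cp_comm_type="p2p")  # assumed valid Megatron-Core mode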