add cp_comm_type param to Mistral config (#12049)
* add cp_comm_type param for Mistral config

Signed-off-by: dimapihtar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <[email protected]>

* fix style

Signed-off-by: dimapihtar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <[email protected]>

---------

Signed-off-by: dimapihtar <[email protected]>
Signed-off-by: dimapihtar <[email protected]>
Co-authored-by: dimapihtar <[email protected]>
dimapihtar and dimapihtar authored Feb 6, 2025
1 parent e51ec38 commit 45e92b8
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion nemo/collections/llm/gpt/model/mistral.py
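The new field is an ordinary dataclass attribute on the Mistral configs: it defaults to "a2a" (all-to-all) for the sliding-window 7B config and to None for the NeMo 12B/123B configs, which have no sliding window. A minimal usage sketch follows, assuming the import path of the changed file; the override value shown is an illustration only and is not exercised by this commit.

from nemo.collections.llm.gpt.model.mistral import MistralConfig7B

cfg = MistralConfig7B()
print(cfg.window_size)    # [4096, 0] -> sliding-window attention
print(cfg.cp_comm_type)   # "a2a" -> context-parallel communication via all-to-all

# Like any dataclass field, the default can be overridden at construction time.
# "p2p" here is only an example value, not one validated by this commit.
cfg_override = MistralConfig7B(cp_comm_type="p2p")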
@@ -37,6 +37,10 @@

@dataclass
class MistralConfig7B(GPTConfig):
"""
Mistral 7B config.
"""

normalization: str = "RMSNorm"
activation_func: Callable = F.silu
position_embedding_type: str = "rope"
@@ -56,6 +60,7 @@ class MistralConfig7B(GPTConfig):
init_method_std: float = 0.02
layernorm_epsilon: float = 1e-5
window_size: List[int] = field(default_factory=lambda: [4096, 0])
cp_comm_type: str = "a2a"


@dataclass
@@ -70,6 +75,7 @@ class MistralNeMoConfig12B(MistralConfig7B):
seq_length: int = 4096 # but "max_position_embeddings": 1024000,

window_size: List[int] = None
cp_comm_type: str = None
rotary_percent: float = 1.0
rotary_base: float = 1000000.0

@@ -88,11 +94,14 @@ class MistralNeMoConfig123B(MistralConfig7B):
seq_length: int = 4096 # but "max_position_embeddings": 131072,

window_size: List[int] = None
cp_comm_type: str = None
rotary_percent: float = 1.0
rotary_base: float = 1000000.0


class MistralModel(GPTModel):
"""Mistral language model; a GPTModel driven by a Mistral config."""

def __init__(
self,
config: Annotated[Optional[MistralConfig7B], Config[MistralConfig7B]] = None,
@@ -107,6 +116,8 @@ def __init__(

@io.model_importer(MistralModel, "hf")
class HFMistralImporter(io.ModelConnector["MistralForCausalLM", MistralModel]):
"""Importer that converts a Hugging Face MistralForCausalLM checkpoint into a NeMo MistralModel."""

def init(self) -> MistralModel:
return MistralModel(self.config, tokenizer=self.tokenizer)

@@ -127,6 +138,7 @@ def apply(self, output_path: Path) -> Path:
return output_path

def convert_state(self, source, target):
"""Copy Hugging Face weights into the NeMo model using the name mapping and transforms below."""
mapping = {
"model.embed_tokens.weight": "embedding.word_embeddings.weight",
"model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight",
@@ -141,12 +153,14 @@ def convert_state(self, source, target):

@property
def tokenizer(self) -> "AutoTokenizer":
"""AutoTokenizer built from the saved Hugging Face tokenizer assets."""
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

return AutoTokenizer(self.save_hf_tokenizer_assets(str(self)))

@property
def config(self) -> MistralConfig7B:
"""Translate the Hugging Face MistralConfig into a MistralConfig7B."""
from transformers import MistralConfig

source = MistralConfig.from_pretrained(str(self))
@@ -157,9 +171,10 @@ def make_vocab_size_divisible_by(mistral_vocab_size):
base //= 2
return base

window_size = None
window_size, cp_comm_type = (None, None)
if getattr(source, 'sliding_window', None) is not None:
window_size = [source.sliding_window, 0]
cp_comm_type = 'a2a'
output = MistralConfig7B(
seq_length=source.sliding_window,
num_layers=source.num_hidden_layers,
@@ -175,6 +190,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size):
gated_linear_unit=True,
make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
window_size=window_size,
cp_comm_type=cp_comm_type,
share_embeddings_and_output_weights=False,
fp16=(dtype_from_hf(source) == torch.float16),
bf16=(dtype_from_hf(source) == torch.bfloat16),
@@ -186,6 +202,8 @@ def make_vocab_size_divisible_by(mistral_vocab_size):

@io.model_exporter(MistralModel, "hf")
class HFMistralExporter(io.ModelConnector[MistralModel, "MistralForCausalLM"]):
"""Exporter that converts a NeMo MistralModel into a Hugging Face MistralForCausalLM checkpoint."""

def init(self, dtype=torch.bfloat16) -> "MistralForCausalLM":
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import no_init_weights
@@ -209,6 +227,7 @@ def apply(self, output_path: Path) -> Path:
return output_path

def convert_state(self, source, target):
"""Copy NeMo weights into the Hugging Face model using the name mapping and transforms below."""
mapping = {
"decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
"decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
@@ -226,10 +245,12 @@ def convert_state(self, source, target):

@property
def tokenizer(self):
"""Tokenizer taken from the loaded NeMo model context."""
return io.load_context(str(self)).model.tokenizer.tokenizer

@property
def config(self) -> "MistralConfig":
"""Translate the NeMo MistralConfig7B into a Hugging Face MistralConfig."""
source: MistralConfig7B = io.load_context(str(self)).model.config

from transformers import MistralConfig as HfMistralConfig
@@ -259,6 +280,7 @@ def config(self) -> "MistralConfig":
target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv(ctx: io.TransformCTX, q, k, v):
"""Pack separate Q, K and V projection weights into Megatron's fused linear_qkv weight."""
megatron_config = ctx.target.config

head_num = megatron_config.num_attention_heads
@@ -301,6 +323,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v):
),
)
def _export_qkv(ctx: io.TransformCTX, linear_qkv):
"""Unpack Megatron's fused linear_qkv weight into separate Q, K and V projection weights."""
megatron_config = ctx.source.config

head_num = megatron_config.num_attention_heads
@@ -333,6 +356,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
target_key="decoder.layers.*.mlp.linear_fc1.weight",
)
def _import_linear_fc1(down, gate):
"""Concatenate the two MLP projection weights into Megatron's fused linear_fc1 weight."""
return torch.cat((down, gate), axis=0)


@@ -341,6 +365,7 @@ def _import_linear_fc1(down, gate):
target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
)
def _export_linear_fc1(linear_fc1):
"""Split Megatron's fused linear_fc1 weight into gate_proj and up_proj weights."""
gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

return gate_proj, up_proj
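
For reference, the importer-side change boils down to the following standalone sketch, which mirrors the branch added in HFMistralImporter.config above; the helper name is illustrative and does not exist in the file. If the Hugging Face config declares a sliding window, the importer sets both the window and the all-to-all context-parallel communication type, otherwise it leaves both unset.

def derive_window_and_cp_comm_type(hf_config):
    # Mirrors the new branch in HFMistralImporter.config.
    window_size, cp_comm_type = None, None
    if getattr(hf_config, "sliding_window", None) is not None:
        window_size = [hf_config.sliding_window, 0]  # (left, right) attention window
        cp_comm_type = "a2a"
    return window_size, cp_comm_type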