Commit

fix style
Signed-off-by: dimapihtar <[email protected]>
dimapihtar committed Feb 4, 2025
1 parent 1ff755d commit 0cbcafb
17 changes: 17 additions & 0 deletions nemo/collections/llm/gpt/model/mistral.py
@@ -37,6 +37,10 @@

@dataclass
class MistralConfig7B(GPTConfig):
"""
Mistral 7B config.
"""

normalization: str = "RMSNorm"
activation_func: Callable = F.silu
position_embedding_type: str = "rope"
@@ -96,6 +100,7 @@ class MistralNeMoConfig123B(MistralConfig7B):


class MistralModel(GPTModel):
""" """
def __init__(
self,
config: Annotated[Optional[MistralConfig7B], Config[MistralConfig7B]] = None,
@@ -110,6 +115,7 @@ def __init__(

@io.model_importer(MistralModel, "hf")
class HFMistralImporter(io.ModelConnector["MistralForCausalLM", MistralModel]):
""" """
def init(self) -> MistralModel:
return MistralModel(self.config, tokenizer=self.tokenizer)

@@ -130,6 +136,7 @@ def apply(self, output_path: Path) -> Path:
return output_path

def convert_state(self, source, target):
""" """
mapping = {
"model.embed_tokens.weight": "embedding.word_embeddings.weight",
"model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight",
@@ -144,12 +151,14 @@ def convert_state(self, source, target):

@property
def tokenizer(self) -> "AutoTokenizer":
""" """
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

return AutoTokenizer(self.save_hf_tokenizer_assets(str(self)))

@property
def config(self) -> MistralConfig7B:
""" """
from transformers import MistralConfig

source = MistralConfig.from_pretrained(str(self))
@@ -190,6 +199,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size):

@io.model_exporter(MistralModel, "hf")
class HFMistralExporter(io.ModelConnector[MistralModel, "MistralForCausalLM"]):
""" """
def init(self, dtype=torch.bfloat16) -> "MistralForCausalLM":
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import no_init_weights
@@ -213,6 +223,7 @@ def apply(self, output_path: Path) -> Path:
return output_path

def convert_state(self, source, target):
""" """
mapping = {
"decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
"decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
@@ -230,10 +241,12 @@ def convert_state(self, source, target):

@property
def tokenizer(self):
""" """
return io.load_context(str(self)).model.tokenizer.tokenizer

@property
def config(self) -> "MistralConfig":
""" """
source: MistralConfig7B = io.load_context(str(self)).model.config

from transformers import MistralConfig as HfMistralConfig
@@ -263,6 +276,7 @@ def config(self) -> "MistralConfig":
target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv(ctx: io.TransformCTX, q, k, v):
""" """
megatron_config = ctx.target.config

head_num = megatron_config.num_attention_heads
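
The rest of _import_qkv is collapsed in this view; it reshapes the separate q, k, v matrices into the fused linear_qkv weight named by target_key above. A toy sketch of a grouped [q .. q, k, v] layout, given here as an assumption for illustration rather than a copy of the hidden lines:

import torch

# Toy sizes for illustration only -- not Mistral's real dimensions.
head_num, num_query_groups, head_size, hidden_size = 8, 2, 4, 16
heads_per_group = head_num // num_query_groups

q = torch.randn(head_num * head_size, hidden_size)
k = torch.randn(num_query_groups * head_size, hidden_size)
v = torch.randn(num_query_groups * head_size, hidden_size)

# Group the q heads with their shared k and v head, then stack
# [q .. q, k, v] per group along dim 0 to form the fused weight.
q = q.view(num_query_groups, heads_per_group * head_size, hidden_size)
k = k.view(num_query_groups, head_size, hidden_size)
v = v.view(num_query_groups, head_size, hidden_size)
linear_qkv = torch.cat((q, k, v), dim=1).reshape(-1, hidden_size)
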
@@ -305,6 +319,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v):
),
)
def _export_qkv(ctx: io.TransformCTX, linear_qkv):
""" """
megatron_config = ctx.source.config

head_num = megatron_config.num_attention_heads
@@ -337,6 +352,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
target_key="decoder.layers.*.mlp.linear_fc1.weight",
)
def _import_linear_fc1(down, gate):
""" """
return torch.cat((down, gate), axis=0)


@@ -345,6 +361,7 @@ def _import_linear_fc1(down, gate):
target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
)
def _export_linear_fc1(linear_fc1):
""" """
gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

return gate_proj, up_proj
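
_import_linear_fc1 and _export_linear_fc1 are intended as inverses: one stacks the two MLP input projections into a single linear_fc1 weight, the other chunks it back into gate_proj and up_proj. A minimal round-trip sketch with toy sizes (assuming the importer receives the projections in the same gate-then-up order the exporter emits):

import torch

hidden_size, ffn_hidden_size = 16, 64  # toy sizes for illustration
gate_proj = torch.randn(ffn_hidden_size, hidden_size)
up_proj = torch.randn(ffn_hidden_size, hidden_size)

# Import direction: fuse the two projections along dim 0, as _import_linear_fc1 does.
linear_fc1 = torch.cat((gate_proj, up_proj), dim=0)

# Export direction: split the fused weight back apart, as _export_linear_fc1 does.
gate_back, up_back = torch.chunk(linear_fc1, 2, dim=0)
assert torch.equal(gate_back, gate_proj) and torch.equal(up_back, up_proj)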