Add mistral model type support (#384)
arm-diaz authored Aug 26, 2024
1 parent fe01050 commit 7791f5d
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions 3.test_cases/10.FSDP/model_utils/train_utils.py
@@ -162,6 +162,22 @@ def get_model_config(args):
            num_experts_per_tok=2,
            num_local_experts=8,
        )
    elif "mistral" in args.model_type:
        from transformers import MistralConfig
        model_config = MistralConfig(
            vocab_size=args.vocab_size,
            hidden_size=args.hidden_width,
            intermediate_size=args.intermediate_size,
            num_hidden_layers=args.num_layers,
            num_attention_heads=args.num_heads,
            num_key_value_heads=args.num_key_value_heads,
            hidden_act="silu",
            max_position_embeddings=args.max_context_width,
            initializer_range=args.initializer_range,
            rms_norm_eps=1e-5,
            use_cache=False,
            tie_word_embeddings=False
        )
    else:
        raise NotImplementedError(f"Model {args.model_type} not implemented")
    return model_config
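For context, a minimal sketch of what this new branch produces when the training script is launched with a mistral model type. The import path and every argument value below are illustrative assumptions, not defaults taken from this repository:

from types import SimpleNamespace

from model_utils.train_utils import get_model_config  # assumed import path

# Placeholder hyperparameters for illustration only.
args = SimpleNamespace(
    model_type="mistral",
    vocab_size=32000,
    hidden_width=4096,
    intermediate_size=14336,
    num_layers=32,
    num_heads=32,
    num_key_value_heads=8,
    max_context_width=32768,
    initializer_range=0.02,
)

config = get_model_config(args)  # returns a transformers.MistralConfig
print(config.model_type, config.num_key_value_heads)  # mistral 8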
@@ -231,6 +247,11 @@ def get_transformer_layer(model_type="gpt2"):

        transformer_layer = MixtralDecoderLayer

    elif model_type == "mistral":
        from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

        transformer_layer = MistralDecoderLayer

    else:
        raise NotImplementedError(f"Model type {model_type} not implemented")
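The decoder-layer class returned above is the natural target for FSDP's transformer auto-wrap policy. A hedged sketch of how a caller might combine the two (the wrap-policy APIs are standard PyTorch, but the exact wiring in this test case is an assumption):

import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from model_utils.train_utils import get_transformer_layer  # assumed import path

# Resolve the per-layer class for the requested model type.
layer_cls = get_transformer_layer(model_type="mistral")  # MistralDecoderLayer

# Wrap each decoder layer as its own FSDP unit.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={layer_cls},
)

# Requires an initialized process group and a constructed model, e.g.:
# model = FSDP(model, auto_wrap_policy=auto_wrap_policy)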

