Fix, add llama config
jlamypoirier committed Dec 17, 2024
1 parent 1dc5def commit b234331
Showing 4 changed files with 43 additions and 39 deletions.
10 changes: 6 additions & 4 deletions fast_llm/engine/checkpoint/external.py
@@ -238,10 +238,12 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
exported_config = {}
for converter in cls._get_config_converters():
try:
values = [
converter.export_params(cls._get_fast_llm_attribute(config, fast_llm_name))
for fast_llm_name in converter.fast_llm_names
]
values = converter.export_params(
tuple(
cls._get_fast_llm_attribute(config, fast_llm_name)
for fast_llm_name in converter.fast_llm_names
)
)
for export_name, value in zip(converter.export_names, values, strict=True):
if value is not MISSING:
set_nested_dict_value(exported_config, export_name, value)
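The hunk above changes the converter contract in `_export_config`: instead of calling `export_params` once per Fast-LLM attribute, the whole tuple of attribute values is passed in a single call and a tuple of export values comes back. A minimal sketch of that tuple-in/tuple-out shape, using a simplified stand-in class rather than the real fast_llm converter types (the class and field names below are hypothetical):

```python
import dataclasses
import typing


@dataclasses.dataclass
class SimpleParamConverter:
    # Mirrors the attributes the diff relies on: fast_llm_names and export_names.
    fast_llm_names: tuple[str, ...]
    export_names: tuple[str, ...]

    def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
        # One call receives all Fast-LLM values together, so a converter that
        # combines several attributes into a single exported entry (or splits
        # one attribute across several) no longer needs a 1:1 name mapping.
        return fast_llm_values


converter = SimpleParamConverter(("transformer.ffn_hidden_size",), ("intermediate_size",))
values = converter.export_params((1024,))
for export_name, value in zip(converter.export_names, values, strict=True):
    print(export_name, value)  # intermediate_size 1024
```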
6 changes: 3 additions & 3 deletions fast_llm/models/gpt/conversion.py
@@ -299,9 +299,9 @@ def __post_init__(self):
def export_params(self, fast_llm_values):
rope_type, *parameters = fast_llm_values
if rope_type == RotaryEmbeddingType.default:
return None
return (None,)
elif rope_type == RotaryEmbeddingType.llama3:
return {key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("llama3",) + parameters, strict=True)}
return ({key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("llama3", *parameters), strict=True)},)
else:
raise ValueError(f"Unsupported rotary scaling type: {rope_type}")

@@ -311,7 +311,7 @@ def import_params(self, export_values):
return (RotaryEmbeddingType.default,) + (DEFAULT,) * 4
elif rope_type == RotaryEmbeddingType.llama3:
# TODO: Is it safe to assume all values are provided?
return {self._HUGGINGFACE_NAMES[0]: "llama3", **{export_value[key] for key in self._HUGGINGFACE_NAMES[1:]}}
return ("llama3", *[export_value[key] for key in self._HUGGINGFACE_NAMES[1:]])
else:
raise ValueError(f"Unsupported rotary scaling type: {rope_type}")

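With that contract, the rotary converter's `export_params` now returns a one-element tuple (the whole Hugging Face `rope_scaling` dict maps to a single export name) and `import_params` returns a flat tuple starting with the rope type, replacing the old dict return that double-starred a set comprehension. A hedged round-trip sketch, assuming `_HUGGINGFACE_NAMES` lists the `rope_scaling` keys in this order; the exact key set is an assumption, not copied from the repo:

```python
# Assumed ordering: rope type first, then the four llama3 scaling parameters.
_HUGGINGFACE_NAMES = (
    "rope_type",
    "factor",
    "low_freq_factor",
    "high_freq_factor",
    "original_max_position_embeddings",
)


def export_llama3(parameters):
    # One-element tuple: the converter maps onto a single export name,
    # whose value is the whole rope_scaling dict.
    return ({key: value for key, value in zip(_HUGGINGFACE_NAMES, ("llama3", *parameters), strict=True)},)


def import_llama3(export_value):
    # Flat tuple: rope type followed by the scaling parameters in order.
    return ("llama3", *[export_value[key] for key in _HUGGINGFACE_NAMES[1:]])


(exported,) = export_llama3((8.0, 1.0, 4.0, 8192))
assert import_llama3(exported) == ("llama3", 8.0, 1.0, 4.0, 8192)
```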
58 changes: 30 additions & 28 deletions tests/common.py
@@ -30,6 +30,7 @@
FORCE_REUSE_RESULTS = int(os.environ.get("FORCE_REUSE_RESULTS", 0)) != 0
REUSE_RESULTS = FORCE_REUSE_RESULTS or int(os.environ.get("REUSE_RESULTS", 0)) != 0
_LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13))
TEST_MODEL = os.environ.get("MODEL", "llama")

ARTIFACT_PATH = "runs/0/artifacts"

@@ -111,44 +112,40 @@
]
CONFIG_SC2_COMMON = CONFIG_SC2_FAST_LLM + ["model.distributed.training_dtype=bf16"]

CONFIG_MISTRAL_MEGATRON = CONFIG_SC2_MEGATRON + [
CONFIG_LLAMA_MEGATRON = CONFIG_SC2_MEGATRON + [
"--swiglu",
"--disable-bias-linear",
"--normalization=RMSNorm",
"--ffn-hidden-size=1024",
"--untie-embeddings-and-output-weights",
]
CONFIG_MISTRAL_FAST_LLM = CONFIG_SC2_FAST_LLM + [
CONFIG_LLAMA_FAST_LLM = CONFIG_SC2_FAST_LLM + [
"model.base_model.transformer.gated=True",
"model.base_model.transformer.activation_type=silu",
"model.base_model.transformer.add_linear_biases=False",
"model.base_model.transformer.normalization.type=rms_norm",
"model.base_model.transformer.ffn_hidden_size=1024",
"model.base_model.tie_word_embeddings=False",
]
CONFIG_MISTRAL_COMMON = CONFIG_MISTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"]
CONFIG_LLAMA_COMMON = CONFIG_LLAMA_FAST_LLM + ["model.distributed.training_dtype=bf16"]

CONFIG_MIXTRAL_MEGATRON = CONFIG_MISTRAL_MEGATRON + [
# Megatron does not support Llama3-style Rotary Embeddings
CONFIG_LLAMA3_MEGATRON = None
CONFIG_LLAMA3_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
"model.base_model.transformer.rotary.type=llama3",
]
CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"]

CONFIG_MIXTRAL_MEGATRON = CONFIG_LLAMA_MEGATRON + [
"--num-experts=4",
"--moe-router-topk=4",
]
CONFIG_MIXTRAL_FAST_LLM = CONFIG_MISTRAL_FAST_LLM + [
CONFIG_MIXTRAL_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
"model.base_model.transformer.num_experts=4",
"model.base_model.transformer.num_experts_per_token=4",
]
CONFIG_MIXTRAL_COMMON = CONFIG_MIXTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"]

CONFIG_LLAMA3_MEGATRON = None # Megatron does not support Llama3-style Rotary Embeddings
CONFIG_LLAMA3_FAST_LLM = CONFIG_SC2_FAST_LLM + [
"model.base_model.transformer.gated=True",
"model.base_model.transformer.activation_type=silu",
"model.base_model.transformer.add_linear_biases=False",
"model.base_model.transformer.normalization.type=rms_norm",
"model.base_model.transformer.rotary.type=llama3",
"model.base_model.tie_word_embeddings=False",
]
CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"]

_CONFIGS = {
"gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None),
"sc1": ("gpt", HuggingfaceGPTModelForCausalLM, CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None),
@@ -159,11 +156,25 @@
CONFIG_SC2_COMMON,
Starcoder2GPTHuggingfaceCheckpointFormat,
),
"llama": (
"gpt",
CONFIG_LLAMA_FAST_LLM,
CONFIG_LLAMA_MEGATRON,
CONFIG_LLAMA_COMMON,
LlamaGPTHuggingfaceCheckpointFormat,
),
"llama3": (
"gpt",
CONFIG_LLAMA3_FAST_LLM,
CONFIG_LLAMA3_MEGATRON,
CONFIG_LLAMA3_COMMON,
LlamaGPTHuggingfaceCheckpointFormat,
),
"mistral": (
"gpt",
CONFIG_MISTRAL_FAST_LLM,
CONFIG_MISTRAL_MEGATRON,
CONFIG_MISTRAL_COMMON,
CONFIG_LLAMA_FAST_LLM,
CONFIG_LLAMA_MEGATRON,
CONFIG_LLAMA_COMMON,
MistralGPTHuggingfaceCheckpointFormat,
),
"mixtral": (
@@ -173,18 +184,9 @@
CONFIG_MIXTRAL_COMMON,
MixtralGPTHuggingfaceCheckpointFormat,
),
"llama3": (
"gpt",
CONFIG_LLAMA3_FAST_LLM,
CONFIG_LLAMA3_MEGATRON,
CONFIG_LLAMA3_COMMON,
LlamaGPTHuggingfaceCheckpointFormat,
),
}


TEST_MODEL = os.environ.get("MODEL", "mistral")

TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL]


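In tests/common.py the Mistral-named Megatron and Fast-LLM argument lists are renamed to llama, llama3 variants are derived from them (with no Megatron counterpart, since Megatron lacks Llama3-style rotary embeddings), and the default test model switches from mistral to llama via the MODEL environment variable. A toy sketch of the selection pattern, with placeholder argument lists standing in for the real config constants:

```python
import os

# Placeholder registry; the real _CONFIGS maps model names to the actual
# CONFIG_*_FAST_LLM / CONFIG_*_MEGATRON / CONFIG_*_COMMON lists and checkpoint formats.
_CONFIGS = {
    "llama": ("gpt", ["llama-fast-llm-args"], ["llama-megatron-args"], ["llama-common-args"], "LlamaCheckpointFormat"),
    "llama3": ("gpt", ["llama3-fast-llm-args"], None, ["llama3-common-args"], "LlamaCheckpointFormat"),
}

TEST_MODEL = os.environ.get("MODEL", "llama")
TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL]
print(TEST_MODEL_TYPE, HUGGINGFACE_CHECKPOINT_FORMAT)
```

Setting MODEL=llama3 in the environment before running the tests selects the llama3 tuple instead of the default llama one.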
8 changes: 4 additions & 4 deletions tests/test_match_megatron.py
@@ -3,8 +3,8 @@
from tests.common import (
CONFIG_GPT2_FAST_LLM,
CONFIG_GPT2_MEGATRON,
CONFIG_MISTRAL_FAST_LLM,
CONFIG_MISTRAL_MEGATRON,
CONFIG_LLAMA_FAST_LLM,
CONFIG_LLAMA_MEGATRON,
CONFIG_MIXTRAL_FAST_LLM,
CONFIG_MIXTRAL_MEGATRON,
CONFIG_SC1_FAST_LLM,
@@ -100,15 +100,15 @@ def test_gpt2_match_meg():
def test_mistral_meg():
# Mistral with Megatron.
# No linear bias, swiglu activation, RMSNorm
run_test_script("test_mistral_meg", CONFIG_MISTRAL_MEGATRON + ["--micro-batch-size=8"], is_megatron=True)
run_test_script("test_mistral_meg", CONFIG_LLAMA_MEGATRON + ["--micro-batch-size=8"], is_megatron=True)


@pytest.mark.depends(on=["test_mistral_meg"])
def test_mistral_match_meg():
# Mistral with Fast-LLM.
run_test_script(
"test_mistral_match_meg",
CONFIG_MISTRAL_FAST_LLM + ["model.base_model.use_megatron_initialization=True"],
CONFIG_LLAMA_FAST_LLM + ["model.base_model.use_megatron_initialization=True"],
compare="test_mistral_meg",
config=CompareConfig(
ignore_tensors=[
