diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py
index fea599c5..529697c9 100644
--- a/fast_llm/engine/checkpoint/external.py
+++ b/fast_llm/engine/checkpoint/external.py
@@ -238,10 +238,12 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
         exported_config = {}
         for converter in cls._get_config_converters():
             try:
-                values = [
-                    converter.export_params(cls._get_fast_llm_attribute(config, fast_llm_name))
-                    for fast_llm_name in converter.fast_llm_names
-                ]
+                values = converter.export_params(
+                    tuple(
+                        cls._get_fast_llm_attribute(config, fast_llm_name)
+                        for fast_llm_name in converter.fast_llm_names
+                    )
+                )
                 for export_name, value in zip(converter.export_names, values, strict=True):
                     if value is not MISSING:
                         set_nested_dict_value(exported_config, export_name, value)
diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py
index 55f595a6..13f9f938 100644
--- a/fast_llm/models/gpt/conversion.py
+++ b/fast_llm/models/gpt/conversion.py
@@ -299,9 +299,9 @@ def __post_init__(self):
     def export_params(self, fast_llm_values):
         rope_type, *parameters = fast_llm_values
         if rope_type == RotaryEmbeddingType.default:
-            return None
+            return (None,)
         elif rope_type == RotaryEmbeddingType.llama3:
-            return {key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("llama3",) + parameters, strict=True)}
+            return ({key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("llama3", *parameters), strict=True)},)
         else:
             raise ValueError(f"Unsupported rotary scaling type: {rope_type}")
 
@@ -311,7 +311,7 @@ def import_params(self, export_values):
             return (RotaryEmbeddingType.default,) + (DEFAULT,) * 4
         elif rope_type == RotaryEmbeddingType.llama3:
             # TODO: Is it safe to assume all values are provided?
-            return {self._HUGGINGFACE_NAMES[0]: "llama3", **{export_value[key] for key in self._HUGGINGFACE_NAMES[1:]}}
+            return ("llama3", *[export_value[key] for key in self._HUGGINGFACE_NAMES[1:]])
         else:
             raise ValueError(f"Unsupported rotary scaling type: {rope_type}")
 
diff --git a/tests/common.py b/tests/common.py
index b96a2204..139e0838 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -30,6 +30,7 @@
 FORCE_REUSE_RESULTS = int(os.environ.get("FORCE_REUSE_RESULTS", 0)) != 0
 REUSE_RESULTS = FORCE_REUSE_RESULTS or int(os.environ.get("REUSE_RESULTS", 0)) != 0
 _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13))
+TEST_MODEL = os.environ.get("MODEL", "llama")
 
 ARTIFACT_PATH = "runs/0/artifacts"
 
@@ -111,14 +112,14 @@
 ]
 CONFIG_SC2_COMMON = CONFIG_SC2_FAST_LLM + ["model.distributed.training_dtype=bf16"]
 
-CONFIG_MISTRAL_MEGATRON = CONFIG_SC2_MEGATRON + [
+CONFIG_LLAMA_MEGATRON = CONFIG_SC2_MEGATRON + [
     "--swiglu",
     "--disable-bias-linear",
     "--normalization=RMSNorm",
     "--ffn-hidden-size=1024",
     "--untie-embeddings-and-output-weights",
 ]
-CONFIG_MISTRAL_FAST_LLM = CONFIG_SC2_FAST_LLM + [
+CONFIG_LLAMA_FAST_LLM = CONFIG_SC2_FAST_LLM + [
     "model.base_model.transformer.gated=True",
     "model.base_model.transformer.activation_type=silu",
     "model.base_model.transformer.add_linear_biases=False",
@@ -126,29 +127,25 @@
     "model.base_model.transformer.ffn_hidden_size=1024",
     "model.base_model.tie_word_embeddings=False",
 ]
-CONFIG_MISTRAL_COMMON = CONFIG_MISTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"]
+CONFIG_LLAMA_COMMON = CONFIG_LLAMA_FAST_LLM + ["model.distributed.training_dtype=bf16"]
 
-CONFIG_MIXTRAL_MEGATRON = CONFIG_MISTRAL_MEGATRON + [
+# Megatron does not support Llama3-style Rotary Embeddings
+CONFIG_LLAMA3_MEGATRON = None
+CONFIG_LLAMA3_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
+    "model.base_model.transformer.rotary.type=llama3",
+]
+CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"]
+
+CONFIG_MIXTRAL_MEGATRON = CONFIG_LLAMA_MEGATRON + [
     "--num-experts=4",
     "--moe-router-topk=4",
 ]
-CONFIG_MIXTRAL_FAST_LLM = CONFIG_MISTRAL_FAST_LLM + [
+CONFIG_MIXTRAL_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
     "model.base_model.transformer.num_experts=4",
     "model.base_model.transformer.num_experts_per_token=4",
 ]
 CONFIG_MIXTRAL_COMMON = CONFIG_MIXTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"]
 
-CONFIG_LLAMA3_MEGATRON = None  # Megatron does not support Llama3-style Rotary Embeddings
-CONFIG_LLAMA3_FAST_LLM = CONFIG_SC2_FAST_LLM + [
-    "model.base_model.transformer.gated=True",
-    "model.base_model.transformer.activation_type=silu",
-    "model.base_model.transformer.add_linear_biases=False",
-    "model.base_model.transformer.normalization.type=rms_norm",
-    "model.base_model.transformer.rotary.type=llama3",
-    "model.base_model.tie_word_embeddings=False",
-]
-CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"]
-
 _CONFIGS = {
     "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None),
     "sc1": ("gpt", HuggingfaceGPTModelForCausalLM, CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None),
@@ -159,11 +156,25 @@
         CONFIG_SC2_COMMON,
         Starcoder2GPTHuggingfaceCheckpointFormat,
     ),
+    "llama": (
+        "gpt",
+        CONFIG_LLAMA_FAST_LLM,
+        CONFIG_LLAMA_MEGATRON,
+        CONFIG_LLAMA_COMMON,
+        LlamaGPTHuggingfaceCheckpointFormat,
+    ),
+    "llama3": (
+        "gpt",
+        CONFIG_LLAMA3_FAST_LLM,
+        CONFIG_LLAMA3_MEGATRON,
+        CONFIG_LLAMA3_COMMON,
+        LlamaGPTHuggingfaceCheckpointFormat,
+    ),
     "mistral": (
         "gpt",
-        CONFIG_MISTRAL_FAST_LLM,
-        CONFIG_MISTRAL_MEGATRON,
-        CONFIG_MISTRAL_COMMON,
+        CONFIG_LLAMA_FAST_LLM,
+        CONFIG_LLAMA_MEGATRON,
+        CONFIG_LLAMA_COMMON,
         MistralGPTHuggingfaceCheckpointFormat,
     ),
     "mixtral": (
@@ -173,18 +184,9 @@
         CONFIG_MIXTRAL_COMMON,
         MixtralGPTHuggingfaceCheckpointFormat,
     ),
-    "llama3": (
-        "gpt",
-        CONFIG_LLAMA3_FAST_LLM,
-        CONFIG_LLAMA3_MEGATRON,
-        CONFIG_LLAMA3_COMMON,
-        LlamaGPTHuggingfaceCheckpointFormat,
-    ),
 }
 
-TEST_MODEL = os.environ.get("MODEL", "mistral")
-
 TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL]
 
 
diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py
index 354b5188..37f63b2d 100644
--- a/tests/test_match_megatron.py
+++ b/tests/test_match_megatron.py
@@ -3,8 +3,8 @@
 from tests.common import (
     CONFIG_GPT2_FAST_LLM,
     CONFIG_GPT2_MEGATRON,
-    CONFIG_MISTRAL_FAST_LLM,
-    CONFIG_MISTRAL_MEGATRON,
+    CONFIG_LLAMA_FAST_LLM,
+    CONFIG_LLAMA_MEGATRON,
     CONFIG_MIXTRAL_FAST_LLM,
     CONFIG_MIXTRAL_MEGATRON,
     CONFIG_SC1_FAST_LLM,
@@ -100,7 +100,7 @@ def test_gpt2_match_meg():
 def test_mistral_meg():
     # Mistral with Megatron.
     # No linear bias, swiglu activation, RMSNorm
-    run_test_script("test_mistral_meg", CONFIG_MISTRAL_MEGATRON + ["--micro-batch-size=8"], is_megatron=True)
+    run_test_script("test_mistral_meg", CONFIG_LLAMA_MEGATRON + ["--micro-batch-size=8"], is_megatron=True)
 
 
 @pytest.mark.depends(on=["test_mistral_meg"])
@@ -108,7 +108,7 @@ def test_mistral_match_meg():
     # Mistral with Fast-LLM.
     run_test_script(
         "test_mistral_match_meg",
-        CONFIG_MISTRAL_FAST_LLM + ["model.base_model.use_megatron_initialization=True"],
+        CONFIG_LLAMA_FAST_LLM + ["model.base_model.use_megatron_initialization=True"],
         compare="test_mistral_meg",
         config=CompareConfig(
             ignore_tensors=[
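
Note on the converter change, for reviewers: the external.py hunk stops calling export_params once per Fast-LLM attribute and instead calls it once with a tuple of all attributes, so the conversion.py hunks make the rotary scaling converter return and accept tuples. Below is a minimal standalone sketch (not part of the patch) showing how this tuple-in/tuple-out contract round-trips the llama3 rope parameters. The class name, the DEFAULT stand-in, and the HuggingFace key names are illustrative assumptions, not the exact fast_llm definitions.

# Hedged sketch: simplified stand-in for the rotary scaling converter.
DEFAULT = object()  # placeholder for fast_llm's DEFAULT sentinel (assumption)

# Assumed HuggingFace rope_scaling keys; the first is the type, the rest are parameters.
_HUGGINGFACE_NAMES = (
    "rope_type",
    "factor",
    "low_freq_factor",
    "high_freq_factor",
    "original_max_position_embeddings",
)


class RotaryScalingConverterSketch:
    def export_params(self, fast_llm_values: tuple) -> tuple:
        # One tuple in (all Fast-LLM values), one tuple out (all export values).
        rope_type, *parameters = fast_llm_values
        if rope_type == "default":
            return (None,)
        elif rope_type == "llama3":
            # Single export value: the whole rope_scaling dict, wrapped in a tuple.
            return ({key: value for key, value in zip(_HUGGINGFACE_NAMES, ("llama3", *parameters), strict=True)},)
        raise ValueError(f"Unsupported rotary scaling type: {rope_type}")

    def import_params(self, export_values: tuple) -> tuple:
        # Mirror image: unpack the single export value, return a flat tuple of Fast-LLM values.
        (export_value,) = export_values
        if export_value is None:
            return ("default",) + (DEFAULT,) * 4
        elif export_value.get(_HUGGINGFACE_NAMES[0]) == "llama3":
            return ("llama3", *[export_value[key] for key in _HUGGINGFACE_NAMES[1:]])
        raise ValueError(f"Unsupported rotary scaling: {export_value}")


# Round trip: Fast-LLM values -> HuggingFace rope_scaling dict -> Fast-LLM values.
converter = RotaryScalingConverterSketch()
exported = converter.export_params(("llama3", 8.0, 1.0, 4.0, 8192))
assert converter.import_params(exported) == ("llama3", 8.0, 1.0, 4.0, 8192)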