diff --git a/fast_llm/config.py b/fast_llm/config.py
index d8ae570c..1934caf2 100644
--- a/fast_llm/config.py
+++ b/fast_llm/config.py
@@ -17,6 +17,7 @@
 _AUTO_VALIDATE = True
 
 MISSING = Tag("")
+DEFAULT = Tag("")
 
 
 class NoAutoValidate:
@@ -347,6 +348,10 @@ def _validate(self):
             if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR:  # noqa
                 continue
             value = getattr(self, name)
+            if value is DEFAULT:
+                # Replace the value with its default.
+                # We still need to validate because some fields have invalid defaults.
+                value = field.default
             new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False)
             setattr(self, name, new_value)
         for name in getattr(self, "_unknown_fields", {}):
@@ -603,7 +608,9 @@ def _add_field_to_args(
                 field_value = field_value.__fast_llm_serialize__()
             if isinstance(value, enum.Enum):
                 field_value = field_value.value
-            elif not isinstance(value, int | float | bool | str | None):
+            # Tag is not actually serializable, but needs to be kept as-is for config processing,
+            # and should be absent for valid configs.
+            elif not isinstance(value, int | float | bool | str | Tag | None):
                 field_value = str(field_value)
             if format_ == _ConfigDictFormat.tuple:
                 field_value = {(): field_value}
diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py
index 733b3833..529697c9 100644
--- a/fast_llm/engine/checkpoint/external.py
+++ b/fast_llm/engine/checkpoint/external.py
@@ -9,6 +9,7 @@
 import torch
 
 from fast_llm import __version__
+from fast_llm.config import MISSING
 from fast_llm.engine.base_model.config import BaseModelArchitectureConfig
 from fast_llm.engine.checkpoint.config import (
     CheckpointLoadConfig,
@@ -24,65 +25,104 @@
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class ParamConverter:
-    fast_llm_name: tuple[str, ...] | None
-    export_name: tuple[str, ...] | str | None
+@dataclasses.dataclass(kw_only=True)
+class ParamConverter(abc.ABC):
+    fast_llm_names: tuple[tuple[str, ...], ...] = ()  # Array of fast-llm names, in nested (tuple) format.
+    export_names: tuple[tuple[str, ...], ...] = ()  # Array of export names, in nested (tuple) format.
 
-    def export_param(self, fast_llm_value):
-        return fast_llm_value
+    @abc.abstractmethod
+    def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
+        pass
+
+    @abc.abstractmethod
+    def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
+        pass
+
+
+@dataclasses.dataclass(kw_only=True)
+class RenameParamConverter(ParamConverter):
 
-    def import_param(self, export_value):
-        return export_value
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 1)
+        Assert.eq(len(self.export_names), 1)
 
+    def export_params(self, fast_llm_values):
+        return fast_llm_values
 
-@dataclasses.dataclass
+    def import_params(self, export_values):
+        return export_values
+
+
+# def __repr__(self):
+#     return f"RenameParamConverter({'.'.join(self.fast_llm_names[0])} <--> {'.'.join(self.export_names[0])})"
+
+
+@dataclasses.dataclass(kw_only=True)
 class ConstantImportParamConverter(ParamConverter):
-    fast_llm_value: typing.Any
+    fast_llm_value: typing.Any = MISSING
+
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 1)
+        Assert.eq(len(self.export_names), 0)
 
-    def export_param(self, fast_llm_value):
-        Assert.eq(fast_llm_value, self.fast_llm_value)
+    def export_params(self, fast_llm_values):
+        Assert.eq(fast_llm_values[0], self.fast_llm_value)
+        return ()
 
-    def import_param(self, export_value):
-        return self.fast_llm_value
+    def import_params(self, export_values):
+        return (self.fast_llm_value,)
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class ConstantExportParamConverter(ParamConverter):
-    export_value: typing.Any
+    export_value: typing.Any = MISSING
 
-    def export_param(self, fast_llm_value):
-        return self.export_value
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 0)
+        Assert.eq(len(self.export_names), 1)
 
-    def import_param(self, export_value):
-        Assert.eq(export_value, self.export_value)
+    def export_params(self, fast_llm_values):
+        return (self.export_value,)
+
+    def import_params(self, export_values):
+        Assert.eq(export_values[0], self.export_value)
+        return ()
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class IgnoreImportParamConverter(ParamConverter):
-    ignore_export_value: typing.Any
+    ignore_export_value: typing.Any = MISSING
 
-    def export_param(self, fast_llm_value):
-        pass
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 0)
+        Assert.eq(len(self.export_names), 1)
 
-    def import_param(self, export_value):
-        if export_value is not self.ignore_export_value:
+    def export_params(self, fast_llm_values):
+        return (MISSING,)
+
+    def import_params(self, export_values):
+        if export_values[0] not in (self.ignore_export_value, MISSING):
             logger.warning(
-                f"The configuration parameter `{self.export_name}={export_value}` is ignored during conversion."
+                f"The configuration parameter `{self.export_names[0]}={export_values[0]}` is ignored during conversion."
                 f" If you intend to use it in Fast-LLM, make sure to set it explicitly in the model configuration."
             )
+        return ()
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class MappedConfigParamConverter(ParamConverter):
-    fast_llm_value: typing.Callable
-    export_value: typing.Callable
+    fast_llm_value: typing.Callable = lambda x: x
+    export_value: typing.Callable = lambda x: x
+
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 1)
+        Assert.eq(len(self.export_names), 1)
 
-    def export_param(self, fast_llm_value):
-        return self.export_value(fast_llm_value)
+    def export_params(self, fast_llm_values):
+        return (self.export_value(fast_llm_values[0]),)
 
-    def import_param(self, export_value):
-        return self.fast_llm_value(export_value)
+    def import_params(self, export_values):
+        return (self.fast_llm_value(export_values[0]),)
 
 
 class WeightConverter:
@@ -197,13 +237,18 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
         # TODO v0.3: not used in this class
         exported_config = {}
         for converter in cls._get_config_converters():
-            value = converter.export_param(
-                None
-                if converter.fast_llm_name is None
-                else cls._get_fast_llm_attribute(config, converter.fast_llm_name)  # Noqa
-            )
-            if converter.export_name is not None:
-                set_nested_dict_value(exported_config, converter.export_name, value)
+            try:
+                values = converter.export_params(
+                    tuple(
+                        cls._get_fast_llm_attribute(config, fast_llm_name)
+                        for fast_llm_name in converter.fast_llm_names
+                    )
+                )
+                for export_name, value in zip(converter.export_names, values, strict=True):
+                    if value is not MISSING:
+                        set_nested_dict_value(exported_config, export_name, value)
+            except Exception as e:
+                raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args)
 
         return exported_config  # Noqa
 
@@ -214,12 +259,25 @@ def _import_config(
         kwargs = {}
         for converter in cls._get_config_converters():
             try:
-                value = None if converter.export_name is None else get_nested_dict_value(config, converter.export_name)
-            except KeyError:
-                value = None
-            value = converter.import_param(value)
-            if converter.fast_llm_name is not None:
-                kwargs[converter.fast_llm_name] = value
+                values = ()
+                for export_name in converter.export_names:
+                    try:
+                        value = get_nested_dict_value(config, export_name)
+                    except KeyError:
+                        value = MISSING
+                    values = values + (value,)
+                values = converter.import_params(values)
+                for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True):
+                    if value is MISSING:
+                        # Missing values need to be handled in dedicated converters,
+                        # because implicit / default values may not match.
+                        # TODO: Different behavior from other uses of MISSING. Use different tag?
+                        raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}")
+                    if fast_llm_name in kwargs:
+                        raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}")
+                    kwargs[fast_llm_name] = value
+            except Exception as e:
+                raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args)
 
         config_class = cls._model_class.get_base_model_config_class()
         if architecture_only:
@@ -335,7 +393,11 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str:
     @classmethod
     @abc.abstractmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
-        return [ConstantExportParamConverter(None, "model_type", cls.get_huggingface_model_type())]
+        return [
+            ConstantExportParamConverter(
+                export_names=(("model_type",),), export_value=cls.get_huggingface_model_type()
+            )
+        ]
 
     @classmethod
     def _load_config(cls, directory: pathlib.Path | str) -> dict:
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index 3e688d1c..ceb31bd4 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -123,15 +123,6 @@ def complex_format(self):
         return self.enabled and not self.triton
 
     def _validate(self):
-        # These happen during conversion.
-        if self.scale_factor is None:
-            self.scale_factor = 8.0
-        if self.low_frequency_factor is None:
-            self.low_frequency_factor = 1.0
-        if self.high_frequency_factor is None:
-            self.high_frequency_factor = 4.0
-        if self.original_context_length is None:
-            self.original_context_length = 8192
         super()._validate()
         if self.triton and not TritonConfig.TRITON_ENABLED:
             warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.")
diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py
index 578892ae..13f9f938 100644
--- a/fast_llm/models/gpt/conversion.py
+++ b/fast_llm/models/gpt/conversion.py
@@ -1,8 +1,10 @@
 import abc
+import dataclasses
 import typing
 
 import torch
 
+from fast_llm.config import DEFAULT
 from fast_llm.engine.checkpoint.config import CheckpointFormat
 from fast_llm.engine.checkpoint.external import (
     AutoStateDictCheckpointHandler,
@@ -13,6 +15,7 @@
     IgnoreWeightConverter,
     MappedConfigParamConverter,
     ParamConverter,
+    RenameParamConverter,
     SplitWeightConverter,
     WeightConverter,
 )
@@ -31,6 +34,7 @@
 )
 from fast_llm.models.gpt.model import GPTModel
 from fast_llm.tensor import SafeTensorSlice
+from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
     pass
@@ -112,21 +116,44 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
         return super()._create_config_converters() + [
-            ConstantImportParamConverter(("use_position_embeddings",), None, False),
-            ParamConverter(("transformer", "rotary", "theta"), "rope_theta"),
+            ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),)
+            ),
             MappedConfigParamConverter(
-                ("transformer", "activation_type"),
-                "hidden_act",
-                ActivationType.from_hf_name,
-                lambda activation_type: activation_type.hf_name,
+                fast_llm_names=(("transformer", "activation_type"),),
+                export_names=(("hidden_act",),),
+                fast_llm_value=ActivationType.from_hf_name,
+                export_value=lambda activation_type: activation_type.hf_name,
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "num_layers"),),
+                export_names=(("num_hidden_layers",),),
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "hidden_size"),),
+                export_names=(("hidden_size",),),
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "num_attention_heads"),),
+                export_names=(("num_attention_heads",),),
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "head_groups"),),
+                export_names=(("num_key_value_heads",),),
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "ffn_hidden_size"),),
+                export_names=(("intermediate_size",),),
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("vocab_size",),),
+                export_names=(("vocab_size",),),
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("tie_word_embeddings",),),
+                export_names=(("tie_word_embeddings",),),
             ),
-            ParamConverter(("transformer", "num_layers"), "num_hidden_layers"),
-            ParamConverter(("transformer", "hidden_size"), "hidden_size"),
-            ParamConverter(("transformer", "num_attention_heads"), "num_attention_heads"),
-            ParamConverter(("transformer", "head_groups"), "num_key_value_heads"),
-            ParamConverter(("transformer", "ffn_hidden_size"), "intermediate_size"),
-            ParamConverter(("vocab_size",), "vocab_size"),
-            ParamConverter(("tie_word_embeddings",), "tie_word_embeddings"),
         ]
 
     def _create_weight_converters(self) -> list[WeightConverter]:
@@ -214,12 +241,18 @@ class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler)
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(None, "architectures", ["Starcoder2ForCausalLM"]),
-            ConstantImportParamConverter(("transformer", "rotary", "type"), None, RotaryEmbeddingType.default),
-            ConstantImportParamConverter(("transformer", "normalization", "type"), None, NormalizationType.layer_norm),
-            ParamConverter(("transformer", "normalization", "epsilon"), "norm_epsilon"),
-            ConstantImportParamConverter(("transformer", "gated"), None, False),
-            ConstantImportParamConverter(("transformer", "add_linear_biases"), None, True),
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.default
+            ),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.layer_norm
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),)
+            ),
+            ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=False),
+            ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=True),
         ]
 
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
@@ -238,33 +271,49 @@ class CommonLlamaHuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler)
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
         return super()._create_config_converters() + [
-            ConstantImportParamConverter(("transformer", "normalization", "type"), None, NormalizationType.rms_norm),
-            ParamConverter(("transformer", "normalization", "epsilon"), "rms_norm_eps"),
-            ConstantImportParamConverter(("transformer", "gated"), None, True),
-            ConstantImportParamConverter(("transformer", "add_linear_biases"), None, False),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
+            ),
+            ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
+            ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False),
         ]
 
 
-def export_rotary_scaling_type(fast_llm_value: RotaryEmbeddingType):
-    match fast_llm_value:
-        case RotaryEmbeddingType.default:
-            return "default"
-        case RotaryEmbeddingType.llama3:
-            return "llama3"
-        case _:
-            raise ValueError(f"Unsupported rotary scaling type: {fast_llm_value}")
-
-
-def import_rotary_scaling_type(export_value):
-    if export_value is None:
-        return RotaryEmbeddingType.default
-    match export_value:
-        case "default":
-            return RotaryEmbeddingType.default
-        case "llama3":
-            return RotaryEmbeddingType.llama3
-        case _:
-            raise ValueError(f"Unsupported rotary scaling type: {export_value}")
+@dataclasses.dataclass
+class RopeScalingParamConverter(ParamConverter):
+    _HUGGINGFACE_NAMES = (
+        "rope_type",
+        "factor",
+        "low_freq_factor",
+        "high_freq_factor",
+        "original_max_position_embeddings",
+    )
+
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 5)
+        Assert.eq(len(self.export_names), 1)
+
+    def export_params(self, fast_llm_values):
+        rope_type, *parameters = fast_llm_values
+        if rope_type == RotaryEmbeddingType.default:
+            return (None,)
+        elif rope_type == RotaryEmbeddingType.llama3:
+            return ({key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("llama3", *parameters), strict=True)},)
+        else:
+            raise ValueError(f"Unsupported rotary scaling type: {rope_type}")
+
+    def import_params(self, export_values):
+        (export_value,) = export_values
+        if export_value is None or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default":
+            return (RotaryEmbeddingType.default,) + (DEFAULT,) * 4
+        elif rope_type == RotaryEmbeddingType.llama3:
+            # TODO: Is it safe to assume all values are provided?
+            return ("llama3", *[export_value[key] for key in self._HUGGINGFACE_NAMES[1:]])
+        else:
+            raise ValueError(f"Unsupported rotary scaling type: {rope_type}")
 
 
 class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler):
@@ -273,31 +322,19 @@ class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler)
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(None, "architectures", ["LlamaForCausalLM"]),
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["LlamaForCausalLM"]),
             # TODO: Llama supports biases
-            ConstantExportParamConverter(None, "attention_bias", False),
-            ConstantExportParamConverter(None, "mlp_bias", False),
-            MappedConfigParamConverter(
-                ("transformer", "rotary", "type"),
-                ("rope_scaling", "rope_type"),
-                import_rotary_scaling_type,
-                export_rotary_scaling_type,
-            ),
-            ParamConverter(
-                ("transformer", "rotary", "scale_factor"),
-                ("rope_scaling", "factor"),
-            ),
-            ParamConverter(
-                ("transformer", "rotary", "low_frequency_factor"),
-                ("rope_scaling", "low_freq_factor"),
-            ),
-            ParamConverter(
-                ("transformer", "rotary", "high_frequency_factor"),
-                ("rope_scaling", "high_freq_factor"),
-            ),
-            ParamConverter(
-                ("transformer", "rotary", "original_context_length"),
-                ("rope_scaling", "original_max_position_embeddings"),
+            ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
+            ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
+            RopeScalingParamConverter(
+                fast_llm_names=(
+                    ("transformer", "rotary", "type"),
+                    ("transformer", "rotary", "scale_factor"),
+                    ("transformer", "rotary", "low_frequency_factor"),
+                    ("transformer", "rotary", "high_frequency_factor"),
+                    ("transformer", "rotary", "original_context_length"),
+                ),
+                export_names=(("rope_scaling",),),
             ),
         ]
 
@@ -325,9 +362,11 @@ class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(None, "architectures", ["MistralForCausalLM"]),
-            ConstantImportParamConverter(("transformer", "rotary", "type"), None, RotaryEmbeddingType.default),
-            IgnoreImportParamConverter(None, "sliding_window", None),
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MistralForCausalLM"]),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.default
+            ),
+            IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None),
         ]
 
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
@@ -350,12 +389,20 @@ class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle
    @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(None, "architectures", ["MixtralForCausalLM"]),
-            ConstantImportParamConverter(("transformer", "rotary", "type"), None, RotaryEmbeddingType.default),
-            ConstantImportParamConverter(("transformer", "expert_routing_type"), None, RoutingType.topk),
-            ParamConverter(("transformer", "num_experts"), "num_local_experts"),
-            ParamConverter(("transformer", "num_experts_per_token"), "num_experts_per_tok"),
-            IgnoreImportParamConverter(None, "sliding_window", None),
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MixtralForCausalLM"]),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.default
+            ),
+            ConstantImportParamConverter(
+                fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "num_experts"),), export_names=(("num_local_experts",),)
+            ),
+            RenameParamConverter(
+                fast_llm_names=(("transformer", "num_experts_per_token"),), export_names=(("num_experts_per_tok",),)
+            ),
+            IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None),
         ]
 
     def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
diff --git a/tests/common.py b/tests/common.py
index b96a2204..139e0838 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -30,6 +30,7 @@
 FORCE_REUSE_RESULTS = int(os.environ.get("FORCE_REUSE_RESULTS", 0)) != 0
 REUSE_RESULTS = FORCE_REUSE_RESULTS or int(os.environ.get("REUSE_RESULTS", 0)) != 0
 _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13))
+TEST_MODEL = os.environ.get("MODEL", "llama")
 
 ARTIFACT_PATH = "runs/0/artifacts"
 
@@ -111,14 +112,14 @@
 ]
 CONFIG_SC2_COMMON = CONFIG_SC2_FAST_LLM + ["model.distributed.training_dtype=bf16"]
 
-CONFIG_MISTRAL_MEGATRON = CONFIG_SC2_MEGATRON + [
+CONFIG_LLAMA_MEGATRON = CONFIG_SC2_MEGATRON + [
     "--swiglu",
     "--disable-bias-linear",
     "--normalization=RMSNorm",
     "--ffn-hidden-size=1024",
     "--untie-embeddings-and-output-weights",
 ]
-CONFIG_MISTRAL_FAST_LLM = CONFIG_SC2_FAST_LLM + [
+CONFIG_LLAMA_FAST_LLM = CONFIG_SC2_FAST_LLM + [
     "model.base_model.transformer.gated=True",
     "model.base_model.transformer.activation_type=silu",
     "model.base_model.transformer.add_linear_biases=False",
@@ -126,29 +127,25 @@
     "model.base_model.transformer.ffn_hidden_size=1024",
     "model.base_model.tie_word_embeddings=False",
 ]
-CONFIG_MISTRAL_COMMON = CONFIG_MISTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"]
+CONFIG_LLAMA_COMMON = CONFIG_LLAMA_FAST_LLM + ["model.distributed.training_dtype=bf16"]
 
-CONFIG_MIXTRAL_MEGATRON = CONFIG_MISTRAL_MEGATRON + [
+# Megatron does not support Llama3-style Rotary Embeddings
+CONFIG_LLAMA3_MEGATRON = None
+CONFIG_LLAMA3_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
+    "model.base_model.transformer.rotary.type=llama3",
+]
+CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"]
+
+CONFIG_MIXTRAL_MEGATRON = CONFIG_LLAMA_MEGATRON + [
     "--num-experts=4",
     "--moe-router-topk=4",
 ]
-CONFIG_MIXTRAL_FAST_LLM = CONFIG_MISTRAL_FAST_LLM + [
+CONFIG_MIXTRAL_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
     "model.base_model.transformer.num_experts=4",
     "model.base_model.transformer.num_experts_per_token=4",
 ]
 CONFIG_MIXTRAL_COMMON = CONFIG_MIXTRAL_FAST_LLM + ["model.distributed.training_dtype=bf16"]
 
-CONFIG_LLAMA3_MEGATRON = None  # Megatron does not support Llama3-style Rotary Embeddings
-CONFIG_LLAMA3_FAST_LLM = CONFIG_SC2_FAST_LLM + [
-    "model.base_model.transformer.gated=True",
-    "model.base_model.transformer.activation_type=silu",
-    "model.base_model.transformer.add_linear_biases=False",
-    "model.base_model.transformer.normalization.type=rms_norm",
-    "model.base_model.transformer.rotary.type=llama3",
-    "model.base_model.tie_word_embeddings=False",
-]
-CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"]
-
 _CONFIGS = {
     "gpt2": ("gpt", CONFIG_GPT2_FAST_LLM, CONFIG_GPT2_MEGATRON, CONFIG_GPT2_COMMON, None),
     "sc1": ("gpt", CONFIG_SC1_FAST_LLM, CONFIG_SC1_MEGATRON, CONFIG_SC1_COMMON, None),
@@ -159,11 +156,25 @@
         CONFIG_SC2_COMMON,
         Starcoder2GPTHuggingfaceCheckpointFormat,
     ),
+    "llama": (
+        "gpt",
+        CONFIG_LLAMA_FAST_LLM,
+        CONFIG_LLAMA_MEGATRON,
+        CONFIG_LLAMA_COMMON,
+        LlamaGPTHuggingfaceCheckpointFormat,
+    ),
+    "llama3": (
+        "gpt",
+        CONFIG_LLAMA3_FAST_LLM,
+        CONFIG_LLAMA3_MEGATRON,
+        CONFIG_LLAMA3_COMMON,
+        LlamaGPTHuggingfaceCheckpointFormat,
+    ),
     "mistral": (
         "gpt",
-        CONFIG_MISTRAL_FAST_LLM,
-        CONFIG_MISTRAL_MEGATRON,
-        CONFIG_MISTRAL_COMMON,
+        CONFIG_LLAMA_FAST_LLM,
+        CONFIG_LLAMA_MEGATRON,
+        CONFIG_LLAMA_COMMON,
         MistralGPTHuggingfaceCheckpointFormat,
     ),
     "mixtral": (
@@ -173,18 +184,9 @@
         CONFIG_MIXTRAL_COMMON,
         MixtralGPTHuggingfaceCheckpointFormat,
     ),
-    "llama3": (
-        "gpt",
-        CONFIG_LLAMA3_FAST_LLM,
-        CONFIG_LLAMA3_MEGATRON,
-        CONFIG_LLAMA3_COMMON,
-        LlamaGPTHuggingfaceCheckpointFormat,
-    ),
 }
 
-TEST_MODEL = os.environ.get("MODEL", "mistral")
-
 TEST_MODEL_TYPE, CONFIG_FAST_LLM, CONFIG_GPT2, CONFIG_COMMON, HUGGINGFACE_CHECKPOINT_FORMAT = _CONFIGS[TEST_MODEL]
diff --git a/tests/test_match_megatron.py b/tests/test_match_megatron.py
index 354b5188..37f63b2d 100644
--- a/tests/test_match_megatron.py
+++ b/tests/test_match_megatron.py
@@ -3,8 +3,8 @@
 from tests.common import (
     CONFIG_GPT2_FAST_LLM,
     CONFIG_GPT2_MEGATRON,
-    CONFIG_MISTRAL_FAST_LLM,
-    CONFIG_MISTRAL_MEGATRON,
+    CONFIG_LLAMA_FAST_LLM,
+    CONFIG_LLAMA_MEGATRON,
     CONFIG_MIXTRAL_FAST_LLM,
     CONFIG_MIXTRAL_MEGATRON,
     CONFIG_SC1_FAST_LLM,
@@ -100,7 +100,7 @@ def test_gpt2_match_meg():
 def test_mistral_meg():
     # Mistral with Megatron.
     # No linear bias, swiglu activation, RMSNorm
-    run_test_script("test_mistral_meg", CONFIG_MISTRAL_MEGATRON + ["--micro-batch-size=8"], is_megatron=True)
+    run_test_script("test_mistral_meg", CONFIG_LLAMA_MEGATRON + ["--micro-batch-size=8"], is_megatron=True)
 
 
 @pytest.mark.depends(on=["test_mistral_meg"])
@@ -108,7 +108,7 @@ def test_mistral_match_meg():
     # Mistral with Fast-LLM.
     run_test_script(
         "test_mistral_match_meg",
-        CONFIG_MISTRAL_FAST_LLM + ["model.base_model.use_megatron_initialization=True"],
+        CONFIG_LLAMA_FAST_LLM + ["model.base_model.use_megatron_initialization=True"],
         compare="test_mistral_meg",
         config=CompareConfig(
             ignore_tensors=[
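Reviewer note, not part of the patch: a minimal sketch of how the reworked converter interface introduced above composes, assuming the names and module path from this diff (`RenameParamConverter`, `export_params`, `import_params` in `fast_llm.engine.checkpoint.external`).

```python
# Illustrative only: round-trips a single value through the tuple-based converter API.
from fast_llm.engine.checkpoint.external import RenameParamConverter

converter = RenameParamConverter(
    fast_llm_names=(("transformer", "num_layers"),),
    export_names=(("num_hidden_layers",),),
)

# Both directions now take and return tuples, one entry per configured name.
# This is what lets a single converter (e.g. RopeScalingParamConverter) map several
# Fast-LLM fields onto one exported value such as `rope_scaling`.
assert converter.export_params((24,)) == (24,)
assert converter.import_params((24,)) == (24,)
```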