diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py
index f12035e776b..0ef2c9e3200 100644
--- a/ludwig/models/llm.py
+++ b/ludwig/models/llm.py
@@ -71,7 +71,11 @@ def update(self, modules: Dict[str, torch.nn.Module]) -> None:
         self.obj.update(modules)
 
 
-def load_pretrained_from_config(config_obj: LLMModelConfig, weights_save_path: Optional[str] = None) -> PreTrainedModel:
+def load_pretrained_from_config(
+    config_obj: LLMModelConfig,
+    model_config: Optional[AutoConfig] = None,
+    weights_save_path: Optional[str] = None,
+) -> PreTrainedModel:
     load_kwargs = {}
     if config_obj.quantization:
         # Apply quantization configuration at model load time
@@ -79,6 +83,19 @@ def load_pretrained_from_config(config_obj: LLMModelConfig, weights_save_path: O
         load_kwargs["quantization_config"] = config_obj.quantization.to_bitsandbytes()
         load_kwargs["device_map"] = "auto"
 
+    if config_obj.model_parameters:
+        # Add any model-specific parameters to the load kwargs
+        for param_name, param_value in config_obj.model_parameters.to_dict().items():
+            # Not all parameters are supported by all models, so we only add the parameter to the load kwargs
+            # if it is supported by the model.
+            if param_value is None:
+                continue
+
+            if hasattr(model_config, param_name):
+                load_kwargs[param_name] = param_value
+            else:
+                logger.warning(f"Parameter {param_name} is not supported by {config_obj.base_model}. Skipping.")
+
     logger.info("Loading large language model...")
     pretrained_model_name_or_path = weights_save_path or config_obj.base_model
     model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
@@ -105,7 +122,7 @@ def __init__(
         self.model_name = self.config_obj.base_model
         self.model_config = AutoConfig.from_pretrained(self.config_obj.base_model)
 
-        self.model = load_pretrained_from_config(self.config_obj)
+        self.model = load_pretrained_from_config(self.config_obj, model_config=self.model_config)
         self.curr_device = next(self.model.parameters()).device
         logger.info("Done.")
 
@@ -585,7 +602,9 @@ def load(self, save_path):
             self.model = PeftModel.from_pretrained(self.model, weights_save_path)
         elif self.config_obj.trainer.type != "none":
-            self.model = load_pretrained_from_config(self.config_obj, weights_save_path)
+            self.model = load_pretrained_from_config(
+                self.config_obj, model_config=self.model_config, weights_save_path=weights_save_path
+            )
         else:
             logger.info("Skipped loading LLM without weight adjustments.")
diff --git a/ludwig/schema/llms/model_parameters.py b/ludwig/schema/llms/model_parameters.py
new file mode 100644
index 00000000000..2e8bbc4e958
--- /dev/null
+++ b/ludwig/schema/llms/model_parameters.py
@@ -0,0 +1,87 @@
+from typing import Optional
+
+from ludwig.api_annotations import DeveloperAPI
+from ludwig.error import ConfigValidationError
+from ludwig.schema import utils as schema_utils
+from ludwig.schema.utils import ludwig_dataclass
+
+
+@DeveloperAPI
+@ludwig_dataclass
+class RoPEScalingConfig(schema_utils.BaseMarshmallowConfig):
+    """RoPE scaling (rotary position embeddings) to extend the context length of LLMs like LLaMA, GPT-NeoX,
+    or Falcon.
+
+    This parameter is a dictionary containing the scaling configuration
+    for the RoPE embeddings. It currently supports two scaling strategies: linear and dynamic. The
+    scaling factor must be a float greater than 1. The expected format is {'type': strategy name,
+    'factor': scaling factor}.
+    """
+
+    def __post_init__(self):
+        # Both `type` and `factor` must be set whenever `rope_scaling` is configured.
+        if not self.type:
+            raise ConfigValidationError(
+                f"`rope_scaling`'s `type` field must be one of ['linear', 'dynamic'], got {self.type}"
+            )
+
+        if not self.factor:
+            raise ConfigValidationError(
+                f"When using `rope_scaling`, `factor` must be specified and be > 1. Got {self.factor}."
+            )
+
+    type: Optional[str] = schema_utils.StringOptions(
+        options=["linear", "dynamic"],
+        default=None,
+        allow_none=True,
+        description="Currently supports two strategies: linear and dynamic scaling.",
+    )
+
+    factor: Optional[float] = schema_utils.FloatRange(
+        default=None,
+        allow_none=True,
+        min=1.0,
+        min_inclusive=False,
+        description="The scaling factor for RoPE embeddings.",
+    )
+
+
+@DeveloperAPI
+class RoPEScalingConfigField(schema_utils.DictMarshmallowField):
+    def __init__(self):
+        super().__init__(RoPEScalingConfig, default_missing=True)
+
+    def _jsonschema_type_mapping(self):
+        return schema_utils.unload_jsonschema_from_marshmallow_class(RoPEScalingConfig, title="rope_scaling")
+
+
+@DeveloperAPI
+@ludwig_dataclass
+class ModelParametersConfig(schema_utils.BaseMarshmallowConfig):
+    rope_scaling: RoPEScalingConfig = RoPEScalingConfigField().get_default_field()
+
+    def to_dict(self):
+        config = {}
+        if self.rope_scaling:
+            config["rope_scaling"] = self.rope_scaling.to_dict()
+        return config
+
+
+@DeveloperAPI
+class ModelParametersConfigField(schema_utils.DictMarshmallowField):
+    def __init__(self):
+        super().__init__(ModelParametersConfig, default_missing=True)
+
+    def _jsonschema_type_mapping(self):
+        return {
+            "oneOf": [
+                {"type": "null", "title": "disabled", "description": "Skip configurable model parameters."},
+                {
+                    **schema_utils.unload_jsonschema_from_marshmallow_class(ModelParametersConfig),
+                    "title": "enabled",
+                    "description": "Set model parameter options.",
+                },
+            ],
+            "title": "Model Parameters",
+            "description": "Configurable model parameters for LLMs.",
+        }
diff --git a/ludwig/schema/model_types/llm.py b/ludwig/schema/model_types/llm.py
index e9b735352c5..95ddb29bc69 100644
--- a/ludwig/schema/model_types/llm.py
+++ b/ludwig/schema/model_types/llm.py
@@ -13,6 +13,7 @@ from ludwig.schema.hyperopt import HyperoptConfig, HyperoptField
 from ludwig.schema.llms.base_model import BaseModelDataclassField
 from ludwig.schema.llms.generation import LLMGenerationConfig, LLMGenerationConfigField
+from ludwig.schema.llms.model_parameters import ModelParametersConfig, ModelParametersConfigField
 from ludwig.schema.llms.peft import AdapterDataclassField, BaseAdapterConfig
 from ludwig.schema.llms.prompt import PromptConfig, PromptConfigField
 from ludwig.schema.llms.quantization import QuantizationConfig, QuantizationConfigField
@@ -50,3 +51,4 @@ class LLMModelConfig(ModelConfig):
     adapter: Optional[BaseAdapterConfig] = AdapterDataclassField()
     quantization: Optional[QuantizationConfig] = QuantizationConfigField().get_default_field()
+    model_parameters: Optional[ModelParametersConfig] = ModelParametersConfigField().get_default_field()
diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index 239b42eb981..ce852cd1e07 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -428,6 +428,32 @@ def test_lora_wrap_on_init():
     assert isinstance(model.model, PeftModel)
 
 
+def test_llama_rope_scaling():
+    config = {
+        MODEL_TYPE: MODEL_LLM,
+        BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
+        INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
+        OUTPUT_FEATURES: [text_feature(name="output")],
+        TRAINER: {
+            TYPE: "finetune",
+            BATCH_SIZE: 8,
+            EPOCHS: 2,
+        },
+        "model_parameters": {
+            "rope_scaling": {
+                "type": "dynamic",
+                "factor": 2.0,
+            }
+        },
+    }
+    config_obj = ModelConfig.from_dict(config)
+    model = LLM(config_obj)
+
+    assert model.model.config.rope_scaling
+    assert model.model.config.rope_scaling["type"] == "dynamic"
+    assert model.model.config.rope_scaling["factor"] == 2.0
+
+
 def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool:
     # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
     for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
diff --git a/tests/ludwig/schema/test_model_config.py b/tests/ludwig/schema/test_model_config.py
index 35cc0468db8..b06b4c7b0e1 100644
--- a/tests/ludwig/schema/test_model_config.py
+++ b/tests/ludwig/schema/test_model_config.py
@@ -1,6 +1,6 @@
 import os
 from tempfile import TemporaryDirectory
-from typing import Optional
+from typing import Any, Dict, Optional, Union
 
 import pytest
 import yaml
@@ -855,3 +855,42 @@ def test_llm_quantization_config(bits: Optional[int], expected_qconfig: Optional
     config_obj = ModelConfig.from_dict(config)
 
     assert config_obj.quantization == expected_qconfig
+
+
+@pytest.mark.parametrize(
+    "rope_scaling_config",
+    [
+        ({"type": "linear"}),
+        ({"factor": 2.0}),
+        ({"type": "linear", "factor": 1.0}),
+    ],
+)
+def test_llm_rope_scaling_failure_modes(
+    rope_scaling_config: Union[None, Dict[str, Any]],
+):
+    config = {
+        MODEL_TYPE: MODEL_LLM,
+        BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
+        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
+        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
+        "model_parameters": {
+            "rope_scaling": rope_scaling_config,
+        },
+    }
+
+    with pytest.raises(ConfigValidationError):
+        ModelConfig.from_dict(config)
+
+
+def test_llm_model_parameters_no_rope_scaling():
+    config = {
+        MODEL_TYPE: MODEL_LLM,
+        BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
+        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
+        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
+        "model_parameters": {},
+    }
+
+    config_obj = ModelConfig.from_dict(config)
+    assert config_obj.model_parameters.rope_scaling is None
+    assert config_obj.model_parameters.to_dict() == {}
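
For illustration, a minimal, self-contained sketch of the parameter-filtering behavior that `load_pretrained_from_config` gains in this diff. `FakeLlamaConfig` and `filter_model_parameters` are hypothetical stand-ins, not part of Ludwig or transformers; the point is only that an entry from `model_parameters` is forwarded as a load kwarg when the Hugging Face model config exposes an attribute with that name, and is skipped otherwise.

from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FakeLlamaConfig:
    """Hypothetical stand-in for a transformers AutoConfig; LLaMA-style configs expose `rope_scaling`."""

    rope_scaling: Optional[Dict[str, Any]] = None


def filter_model_parameters(model_config: Any, model_parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the parameters the model config actually exposes, mirroring the hasattr check above."""
    load_kwargs: Dict[str, Any] = {}
    for param_name, param_value in model_parameters.items():
        if param_value is None:
            continue
        if hasattr(model_config, param_name):
            load_kwargs[param_name] = param_value
        else:
            # In the real code path this is a logger.warning(...); the parameter is simply dropped.
            print(f"Parameter {param_name} is not supported by this model. Skipping.")
    return load_kwargs


if __name__ == "__main__":
    kwargs = filter_model_parameters(
        FakeLlamaConfig(),
        {"rope_scaling": {"type": "dynamic", "factor": 2.0}, "made_up_param": 123},
    )
    # Only `rope_scaling` survives and would be passed on to AutoModelForCausalLM.from_pretrained.
    assert kwargs == {"rope_scaling": {"type": "dynamic", "factor": 2.0}}

The same check is what lets a config with `model_parameters: {rope_scaling: {type: dynamic, factor: 2.0}}` reach a LLaMA-style config (which defines `rope_scaling`) while being skipped, with a warning, for base models whose config has no such attribute.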