
Add RoPE scaling to increase context length up to 8K for training or inference. (#3477)

Co-authored-by: Travis Addair <[email protected]>
arnavgarg1 and tgaddair authored Jul 30, 2023
1 parent f1ad0df commit f8708a3
Showing 5 changed files with 177 additions and 4 deletions.
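
For orientation, the new surface area is a `model_parameters.rope_scaling` section in the Ludwig LLM config. The sketch below mirrors the integration test added in this commit rather than documenting a separate API; with `factor: 2.0`, a Llama-style base model with a 4,096-token window would stretch to roughly 8,192 positions, which appears to be the "8K" the commit title refers to.

```python
# Minimal sketch of a Ludwig LLM config enabling RoPE scaling,
# mirroring the integration test added in this commit.
from ludwig.api import LudwigModel

config = {
    "model_type": "llm",
    "base_model": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
    "input_features": [{"name": "input", "type": "text"}],
    "output_features": [{"name": "output", "type": "text"}],
    "model_parameters": {
        "rope_scaling": {
            "type": "dynamic",  # "linear" or "dynamic"
            "factor": 2.0,      # float, must be > 1
        }
    },
}

model = LudwigModel(config)  # rope_scaling is applied when the HF model is loaded
```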
25 changes: 22 additions & 3 deletions ludwig/models/llm.py
@@ -71,14 +71,31 @@ def update(self, modules: Dict[str, torch.nn.Module]) -> None:
self.obj.update(modules)


def load_pretrained_from_config(config_obj: LLMModelConfig, weights_save_path: Optional[str] = None) -> PreTrainedModel:
def load_pretrained_from_config(
config_obj: LLMModelConfig,
model_config: Optional[AutoConfig] = None,
weights_save_path: Optional[str] = None,
) -> PreTrainedModel:
load_kwargs = {}
if config_obj.quantization:
# Apply quantization configuration at model load time
load_kwargs["torch_dtype"] = getattr(torch, config_obj.quantization.bnb_4bit_compute_dtype)
load_kwargs["quantization_config"] = config_obj.quantization.to_bitsandbytes()
load_kwargs["device_map"] = "auto"

if config_obj.model_parameters:
# Add any model specific parameters to the load kwargs
for param_name, param_value in config_obj.model_parameters.to_dict().items():
# Not all parameters are supported by all models, so we only add the parameter to the load kwargs
# if it is supported by the model.
if param_value is None:
continue

if hasattr(model_config, param_name):
load_kwargs[param_name] = param_value
else:
logger.warning(f"Parameter {param_name} is not supported by {config_obj.base_model}. Skipping.")

logger.info("Loading large language model...")
pretrained_model_name_or_path = weights_save_path or config_obj.base_model
model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **load_kwargs)
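
The net effect of the `model_parameters` handling above: any configured parameter that exists on the model's `AutoConfig` (as `rope_scaling` does for Llama-family configs in recent `transformers` releases, roughly 4.31+) is forwarded to `from_pretrained` as a config override, while parameters the architecture does not define are skipped with a warning. A minimal sketch of the same mechanism using `transformers` directly, with the tiny Llama checkpoint used by the tests in this diff:

```python
from transformers import AutoConfig, AutoModelForCausalLM

base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
model_config = AutoConfig.from_pretrained(base_model)

load_kwargs = {}
rope_scaling = {"type": "dynamic", "factor": 2.0}

# Forward the override only if this architecture's config actually defines it.
if hasattr(model_config, "rope_scaling"):
    load_kwargs["rope_scaling"] = rope_scaling

model = AutoModelForCausalLM.from_pretrained(base_model, **load_kwargs)
print(model.config.rope_scaling)  # expected: {'type': 'dynamic', 'factor': 2.0}
```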
@@ -105,7 +122,7 @@ def __init__(
self.model_name = self.config_obj.base_model
self.model_config = AutoConfig.from_pretrained(self.config_obj.base_model)

self.model = load_pretrained_from_config(self.config_obj)
self.model = load_pretrained_from_config(self.config_obj, model_config=self.model_config)
self.curr_device = next(self.model.parameters()).device
logger.info("Done.")

@@ -585,7 +602,9 @@ def load(self, save_path):

self.model = PeftModel.from_pretrained(self.model, weights_save_path)
elif self.config_obj.trainer.type != "none":
self.model = load_pretrained_from_config(self.config_obj, weights_save_path)
self.model = load_pretrained_from_config(
self.config_obj, model_config=self.model_config, weights_save_path=weights_save_path
)
else:
logger.info("Skipped loading LLM without weight adjustments.")

87 changes: 87 additions & 0 deletions ludwig/schema/llms/model_parameters.py
@@ -0,0 +1,87 @@
from typing import Optional

from ludwig.api_annotations import DeveloperAPI
from ludwig.error import ConfigValidationError
from ludwig.schema import utils as schema_utils
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@ludwig_dataclass
class RoPEScalingConfig(schema_utils.BaseMarshmallowConfig):
"""Dynamic RoPE-scaling (rotary position embeddings) to extend the context length of LLM like LLaMA, GPT-NeoX,
or Falcon.
This parameter is a dictionary containing the scaling configuration
for the RoPE embeddings. Currently supports three scaling strategies: linear and dynamic. Their
scaling factor must be an float greater than 1. The expected format is {'type': strategy name,
'factor': scaling factor}
"""

def __post_init__(self):
# When `rope_scaling` is specified, both `type` and `factor` must be set.
if not self.type:
raise ConfigValidationError(
f"`rope_scaling`'s `type` field must be one of ['linear', 'dynamic'], got {self.type}"
)

if not self.factor:
raise ConfigValidationError(
f"When using `rope_scaling`, `factor` must be specified and be > 1. Got {self.factor}."
)

type: Optional[str] = schema_utils.StringOptions(
options=["linear", "dynamic"],
default=None,
allow_none=True,
description="Currently supports two strategies: linear and dynamic scaling.",
)

factor: Optional[float] = schema_utils.FloatRange(
default=None,
allow_none=True,
min=1.0,
min_inclusive=False,
description="The scaling factor for RoPE embeddings.",
)


@DeveloperAPI
class RoPEScalingConfigField(schema_utils.DictMarshmallowField):
def __init__(self):
super().__init__(RoPEScalingConfig, default_missing=True)

def _jsonschema_type_mapping(self):
return schema_utils.unload_jsonschema_from_marshmallow_class(RoPEScalingConfig, title="rope_scaling")


@DeveloperAPI
@ludwig_dataclass
class ModelParametersConfig(schema_utils.BaseMarshmallowConfig):
rope_scaling: RoPEScalingConfig = RoPEScalingConfigField().get_default_field()

def to_dict(self):
config = {}
if self.rope_scaling:
config["rope_scaling"] = self.rope_scaling.to_dict()
return config


@DeveloperAPI
class ModelParametersConfigField(schema_utils.DictMarshmallowField):
def __init__(self):
super().__init__(ModelParametersConfig, default_missing=True)

def _jsonschema_type_mapping(self):
return {
"oneOf": [
{"type": "null", "title": "disabled", "description": "Skip configurable model parameters."},
{
**schema_utils.unload_jsonschema_from_marshmallow_class(ModelParametersConfig),
"title": "enabled",
"description": "Set model parameters options.",
},
],
"title": "Model Parameters",
"description": "Configurable model parameters for LLMs.",
}
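
A rough sketch of the validation and serialization behavior defined in this file, assuming the dataclasses can be constructed directly (the unit tests below exercise the same checks through `ModelConfig.from_dict`):

```python
from ludwig.error import ConfigValidationError
from ludwig.schema.llms.model_parameters import ModelParametersConfig, RoPEScalingConfig

# A fully specified rope_scaling block passes __post_init__ validation.
scaling = RoPEScalingConfig(type="dynamic", factor=2.0)
print(scaling.to_dict())  # expected to contain both 'type' and 'factor'

# Omitting `factor` (or `type`) trips the __post_init__ checks.
try:
    RoPEScalingConfig(type="linear", factor=None)
except ConfigValidationError as err:
    print(err)

# ModelParametersConfig.to_dict() only emits explicitly configured sections,
# so an empty model_parameters block adds nothing to the model load kwargs.
print(ModelParametersConfig().to_dict())  # {}
```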
2 changes: 2 additions & 0 deletions ludwig/schema/model_types/llm.py
@@ -13,6 +13,7 @@
from ludwig.schema.hyperopt import HyperoptConfig, HyperoptField
from ludwig.schema.llms.base_model import BaseModelDataclassField
from ludwig.schema.llms.generation import LLMGenerationConfig, LLMGenerationConfigField
from ludwig.schema.llms.model_parameters import ModelParametersConfig, ModelParametersConfigField
from ludwig.schema.llms.peft import AdapterDataclassField, BaseAdapterConfig
from ludwig.schema.llms.prompt import PromptConfig, PromptConfigField
from ludwig.schema.llms.quantization import QuantizationConfig, QuantizationConfigField
@@ -50,3 +51,4 @@ class LLMModelConfig(ModelConfig):

adapter: Optional[BaseAdapterConfig] = AdapterDataclassField()
quantization: Optional[QuantizationConfig] = QuantizationConfigField().get_default_field()
model_parameters: Optional[ModelParametersConfig] = ModelParametersConfigField().get_default_field()
26 changes: 26 additions & 0 deletions tests/integration_tests/test_llm.py
@@ -428,6 +428,32 @@ def test_lora_wrap_on_init():
assert isinstance(model.model, PeftModel)


def test_llama_rope_scaling():
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
OUTPUT_FEATURES: [text_feature(name="output")],
TRAINER: {
TYPE: "finetune",
BATCH_SIZE: 8,
EPOCHS: 2,
},
"model_parameters": {
"rope_scaling": {
"type": "dynamic",
"factor": 2.0,
}
},
}
config_obj = ModelConfig.from_dict(config)
model = LLM(config_obj)

assert model.model.config.rope_scaling
assert model.model.config.rope_scaling["type"] == "dynamic"
assert model.model.config.rope_scaling["factor"] == 2.0


def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool:
# Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
41 changes: 40 additions & 1 deletion tests/ludwig/schema/test_model_config.py
@@ -1,6 +1,6 @@
import os
from tempfile import TemporaryDirectory
from typing import Optional
from typing import Any, Dict, Optional, Union

import pytest
import yaml
@@ -855,3 +855,42 @@ def test_llm_quantization_config(bits: Optional[int], expected_qconfig: Optional
config_obj = ModelConfig.from_dict(config)

assert config_obj.quantization == expected_qconfig


@pytest.mark.parametrize(
"rope_scaling_config",
[
({"type": "linear"}),
({"factor": 2.0}),
({"type": "linear", "factor": 1.0}),
],
)
def test_llm_rope_scaling_failure_modes(
rope_scaling_config: Union[None, Dict[str, Any]],
):
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
"model_parameters": {
"rope_scaling": rope_scaling_config,
},
}

with pytest.raises(ConfigValidationError):
ModelConfig.from_dict(config)


def test_llm_model_parameters_no_rope_scaling():
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
"model_parameters": {},
}

config_obj = ModelConfig.from_dict(config)
assert config_obj.model_parameters.rope_scaling is None
assert config_obj.model_parameters.to_dict() == {}
