Fix llama conversion, improve parameter conversion (#94)
jlamypoirier authored Dec 17, 2024
1 parent a19d40b commit d8f3390
Showing 6 changed files with 275 additions and 166 deletions.
9 changes: 8 additions & 1 deletion fast_llm/config.py
@@ -17,6 +17,7 @@
_AUTO_VALIDATE = True

MISSING = Tag("<MISSING>")
DEFAULT = Tag("<DEFAULT>")


class NoAutoValidate:
@@ -347,6 +348,10 @@ def _validate(self):
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
continue
value = getattr(self, name)
if value is DEFAULT:
# Replace the value with its default.
# We still need to validate because some fields have invalid defaults.
value = field.default
new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False)
setattr(self, name, new_value)
for name in getattr(self, "_unknown_fields", {}):
@@ -603,7 +608,9 @@ def _add_field_to_args(
field_value = field_value.__fast_llm_serialize__()
if isinstance(value, enum.Enum):
field_value = field_value.value
elif not isinstance(value, int | float | bool | str | None):
# Tag is not actually serializable, but needs to be kept as-is for config processing,
# and should be absent for valid configs.
elif not isinstance(value, int | float | bool | str | Tag | None):
field_value = str(field_value)
if format_ == _ConfigDictFormat.tuple:
field_value = {(): field_value}
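
A minimal sketch of the DEFAULT behaviour added above, in plain Python rather than the actual Config machinery (the sentinel and helper names below are illustrative):

_DEFAULT = object()  # stand-in for the DEFAULT tag

def resolve_field(value, declared_default):
    # Mirror of the _validate change: a DEFAULT sentinel is swapped for the field's
    # declared default, and the result still goes through normal validation because
    # some fields ship defaults that are not themselves valid values.
    return declared_default if value is _DEFAULT else value

assert resolve_field(_DEFAULT, 8.0) == 8.0  # falls back to the default
assert resolve_field(2.5, 8.0) == 2.5       # explicit values pass through unchanged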
158 changes: 110 additions & 48 deletions fast_llm/engine/checkpoint/external.py
@@ -9,6 +9,7 @@
import torch

from fast_llm import __version__
from fast_llm.config import MISSING
from fast_llm.engine.base_model.config import BaseModelArchitectureConfig
from fast_llm.engine.checkpoint.config import (
CheckpointLoadConfig,
@@ -24,65 +25,104 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ParamConverter:
fast_llm_name: tuple[str, ...] | None
export_name: tuple[str, ...] | str | None
@dataclasses.dataclass(kw_only=True)
class ParamConverter(abc.ABC):
fast_llm_names: tuple[tuple[str, ...], ...] = () # Array of fast-llm names, in nested (tuple) format.
export_names: tuple[tuple[str, ...], ...] = () # Array of export names, in nested (tuple) format.

def export_param(self, fast_llm_value):
return fast_llm_value
@abc.abstractmethod
def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
pass

@abc.abstractmethod
def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
pass


@dataclasses.dataclass(kw_only=True)
class RenameParamConverter(ParamConverter):

def import_param(self, export_value):
return export_value
def __post_init__(self):
Assert.eq(len(self.fast_llm_names), 1)
Assert.eq(len(self.export_names), 1)

def export_params(self, fast_llm_values):
return fast_llm_values

@dataclasses.dataclass
def import_params(self, export_values):
return export_values


# def __repr__(self):
# return f"RenameParamConverter({'.'.join(self.fast_llm_names[0])} <--> {'.'.join(self.export_names[0])})"


@dataclasses.dataclass(kw_only=True)
class ConstantImportParamConverter(ParamConverter):
fast_llm_value: typing.Any
fast_llm_value: typing.Any = MISSING

def __post_init__(self):
Assert.eq(len(self.fast_llm_names), 1)
Assert.eq(len(self.export_names), 0)

def export_param(self, fast_llm_value):
Assert.eq(fast_llm_value, self.fast_llm_value)
def export_params(self, fast_llm_values):
Assert.eq(fast_llm_values[0], self.fast_llm_value)
return ()

def import_param(self, export_value):
return self.fast_llm_value
def import_params(self, export_values):
return (self.fast_llm_value,)


@dataclasses.dataclass
@dataclasses.dataclass(kw_only=True)
class ConstantExportParamConverter(ParamConverter):
export_value: typing.Any
export_value: typing.Any = MISSING

def export_param(self, fast_llm_value):
return self.export_value
def __post_init__(self):
Assert.eq(len(self.fast_llm_names), 0)
Assert.eq(len(self.export_names), 1)

def import_param(self, export_value):
Assert.eq(export_value, self.export_value)
def export_params(self, fast_llm_values):
return (self.export_value,)

def import_params(self, export_values):
Assert.eq(export_values[0], self.export_value)
return ()


@dataclasses.dataclass
@dataclasses.dataclass(kw_only=True)
class IgnoreImportParamConverter(ParamConverter):
ignore_export_value: typing.Any
ignore_export_value: typing.Any = MISSING

def export_param(self, fast_llm_value):
pass
def __post_init__(self):
Assert.eq(len(self.fast_llm_names), 0)
Assert.eq(len(self.export_names), 1)

def import_param(self, export_value):
if export_value is not self.ignore_export_value:
def export_params(self, fast_llm_values):
return (MISSING,)

def import_params(self, export_values):
if export_values[0] not in (self.ignore_export_value, MISSING):
logger.warning(
f"The configuration parameter `{self.export_name}={export_value}` is ignored during conversion."
f"The configuration parameter `{self.export_names[0]}={export_values[0]}` is ignored during conversion."
f" If you intend to use it in Fast-LLM, make sure to set it explicitly in the model configuration."
)
return ()


@dataclasses.dataclass
@dataclasses.dataclass(kw_only=True)
class MappedConfigParamConverter(ParamConverter):
fast_llm_value: typing.Callable
export_value: typing.Callable
fast_llm_value: typing.Callable = lambda x: x
export_value: typing.Callable = lambda x: x

def __post_init__(self):
Assert.eq(len(self.fast_llm_names), 1)
Assert.eq(len(self.export_names), 1)

def export_param(self, fast_llm_value):
return self.export_value(fast_llm_value)
def export_params(self, fast_llm_values):
return (self.export_value(fast_llm_values[0]),)

def import_param(self, export_value):
return self.fast_llm_value(export_value)
def import_params(self, export_values):
return (self.fast_llm_value(export_values[0]),)


class WeightConverter:
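
With the converters above now keyword-only and taking plural name tuples, a declaration reads roughly as follows (a sketch: the model_type constant mirrors the base handler converter later in this diff, while the other nested parameter names are hypothetical):

converters = [
    # One-to-one rename between a nested Fast-LLM field and a nested export field.
    RenameParamConverter(
        fast_llm_names=(("transformer", "num_layers"),),
        export_names=(("num_hidden_layers",),),
    ),
    # Field that must hold a fixed value on the Fast-LLM side and is dropped on export.
    ConstantImportParamConverter(
        fast_llm_names=(("use_position_embeddings",),),
        fast_llm_value=False,
    ),
    # Value emitted as a constant on the export side only.
    ConstantExportParamConverter(
        export_names=(("model_type",),),
        export_value="llama",
    ),
]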
@@ -197,13 +237,18 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
# TODO v0.3: not used in this class
exported_config = {}
for converter in cls._get_config_converters():
value = converter.export_param(
None
if converter.fast_llm_name is None
else cls._get_fast_llm_attribute(config, converter.fast_llm_name) # Noqa
)
if converter.export_name is not None:
set_nested_dict_value(exported_config, converter.export_name, value)
try:
values = converter.export_params(
tuple(
cls._get_fast_llm_attribute(config, fast_llm_name)
for fast_llm_name in converter.fast_llm_names
)
)
for export_name, value in zip(converter.export_names, values, strict=True):
if value is not MISSING:
set_nested_dict_value(exported_config, export_name, value)
except Exception as e:
raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args)

return exported_config # Noqa
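
The switch to tuple-valued names is what allows a single converter to map several Fast-LLM fields onto one export entry (or vice versa). A hedged sketch of such a converter, with an illustrative field layout not taken from this commit:

@dataclasses.dataclass(kw_only=True)
class PackedRopeScalingConverter(ParamConverter):  # hypothetical converter
    def __post_init__(self):
        Assert.eq(len(self.fast_llm_names), 2)
        Assert.eq(len(self.export_names), 1)

    def export_params(self, fast_llm_values):
        scale_factor, original_context_length = fast_llm_values
        # Two Fast-LLM fields packed into one exported dict entry.
        return ({"factor": scale_factor, "original_max_position_embeddings": original_context_length},)

    def import_params(self, export_values):
        (scaling,) = export_values
        return (scaling["factor"], scaling["original_max_position_embeddings"])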

@@ -214,12 +259,25 @@ def _import_config(
kwargs = {}
for converter in cls._get_config_converters():
try:
value = None if converter.export_name is None else get_nested_dict_value(config, converter.export_name)
except KeyError:
value = None
value = converter.import_param(value)
if converter.fast_llm_name is not None:
kwargs[converter.fast_llm_name] = value
values = ()
for export_name in converter.export_names:
try:
value = get_nested_dict_value(config, export_name)
except KeyError:
value = MISSING
values = values + (value,)
values = converter.import_params(values)
for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True):
if value is MISSING:
# Missing values need to be handled in dedicated converters,
# because implicit / default values may not match.
# TODO: Different behavior from other uses of MISSING. Use different tag?
raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}")
if fast_llm_name in kwargs:
raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}")
kwargs[fast_llm_name] = value
except Exception as e:
raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args)

config_class = cls._model_class.get_base_model_config_class()
if architecture_only:
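
Because a key absent from the exported config now arrives as MISSING and _import_config refuses to forward it, a converter that wants a fallback has to resolve MISSING itself. One possible pattern (hypothetical subclass, not part of this commit):

@dataclasses.dataclass(kw_only=True)
class RenameWithDefaultParamConverter(RenameParamConverter):
    default_value: typing.Any = None

    def import_params(self, export_values):
        # Map a missing export key to an explicit fallback instead of letting
        # _import_config raise on MISSING.
        return (self.default_value if export_values[0] is MISSING else export_values[0],)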
Expand Down Expand Up @@ -335,7 +393,11 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str:
@classmethod
@abc.abstractmethod
def _create_config_converters(cls) -> list[ParamConverter]:
return [ConstantExportParamConverter(None, "model_type", cls.get_huggingface_model_type())]
return [
ConstantExportParamConverter(
export_names=(("model_type",),), export_value=cls.get_huggingface_model_type()
)
]

@classmethod
def _load_config(cls, directory: pathlib.Path | str) -> dict:
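
Concrete handlers are expected to extend the base converter list via super() rather than replace it; roughly (a sketch with a stand-in base class and an illustrative rename, not the actual Llama converters):

class _SketchHandlerBase:
    # Stand-in for the handler class that defines _create_config_converters above.
    @classmethod
    def _create_config_converters(cls) -> list[ParamConverter]:
        return [
            ConstantExportParamConverter(
                export_names=(("model_type",),), export_value="llama"
            )
        ]

class _SketchLlamaHandler(_SketchHandlerBase):
    @classmethod
    def _create_config_converters(cls) -> list[ParamConverter]:
        # Extend, don't replace, the base converter list.
        return super()._create_config_converters() + [
            RenameParamConverter(
                fast_llm_names=(("vocab_size",),),
                export_names=(("vocab_size",),),
            ),
        ]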
9 changes: 0 additions & 9 deletions fast_llm/layers/transformer/config.py
@@ -123,15 +123,6 @@ def complex_format(self):
return self.enabled and not self.triton

def _validate(self):
# These happen during conversion.
if self.scale_factor is None:
self.scale_factor = 8.0
if self.low_frequency_factor is None:
self.low_frequency_factor = 1.0
if self.high_frequency_factor is None:
self.high_frequency_factor = 4.0
if self.original_context_length is None:
self.original_context_length = 8192
super()._validate()
if self.triton and not TritonConfig.TRITON_ENABLED:
warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.")
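
The fallback removed here presumably moves to the conversion side; under the assumption that the new DEFAULT tag (with MISSING imported from fast_llm.config) is what replaces it, a converter could resolve an absent rope-scaling value like this (hypothetical converter, not the actual Llama handler):

@dataclasses.dataclass(kw_only=True)
class RopeScaleFactorConverter(ParamConverter):  # hypothetical
    def __post_init__(self):
        Assert.eq(len(self.fast_llm_names), 1)
        Assert.eq(len(self.export_names), 1)

    def export_params(self, fast_llm_values):
        return (fast_llm_values[0],)

    def import_params(self, export_values):
        # A missing or null exported value becomes DEFAULT, so RotaryConfig keeps its
        # declared default (e.g. 8.0 for scale_factor) instead of being patched in _validate.
        return (DEFAULT if export_values[0] in (MISSING, None) else export_values[0],)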