Fix llama conversion, improve parameter conversion #94

Merged: 3 commits, Dec 17, 2024
9 changes: 8 additions & 1 deletion fast_llm/config.py
@@ -17,6 +17,7 @@
 _AUTO_VALIDATE = True
 
 MISSING = Tag("<MISSING>")
+DEFAULT = Tag("<DEFAULT>")
 
 
 class NoAutoValidate:
@@ -347,6 +348,10 @@ def _validate(self):
             if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR:  # noqa
                 continue
             value = getattr(self, name)
+            if value is DEFAULT:
+                # Replace the value with its default.
+                # We still need to validate because some fields have invalid defaults.
+                value = field.default
             new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False)
             setattr(self, name, new_value)
         for name in getattr(self, "_unknown_fields", {}):
@@ -603,7 +608,9 @@ def _add_field_to_args(
                 field_value = field_value.__fast_llm_serialize__()
             if isinstance(value, enum.Enum):
                 field_value = field_value.value
-            elif not isinstance(value, int | float | bool | str | None):
+            # Tag is not actually serializable, but needs to be kept as-is for config processing,
+            # and should be absent for valid configs.
+            elif not isinstance(value, int | float | bool | str | Tag | None):
                 field_value = str(field_value)
             if format_ == _ConfigDictFormat.tuple:
                 field_value = {(): field_value}
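
The `DEFAULT` tag acts as a sentinel: a field explicitly set to `DEFAULT` is swapped for its declared default during validation instead of being validated as-is. A self-contained sketch of that pattern (stand-in classes only, not Fast-LLM's actual `Config` machinery):

```python
import dataclasses


class Tag:
    """Minimal stand-in for fast_llm.config.Tag (illustrative only)."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return self.name


DEFAULT = Tag("<DEFAULT>")


@dataclasses.dataclass
class ExampleConfig:
    scale_factor: float = 8.0  # hypothetical field

    def validate(self) -> None:
        # Same idea as Config._validate in this PR: a field explicitly set to DEFAULT
        # is replaced by its declared default before normal validation runs.
        for field in dataclasses.fields(self):
            if getattr(self, field.name) is DEFAULT:
                setattr(self, field.name, field.default)


config = ExampleConfig(scale_factor=DEFAULT)
config.validate()
assert config.scale_factor == 8.0
```

Presumably this is what lets conversion code hand back "use the default" for a parameter without hard-coding the value itself.
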
158 changes: 110 additions & 48 deletions fast_llm/engine/checkpoint/external.py
@@ -9,6 +9,7 @@
 import torch
 
 from fast_llm import __version__
+from fast_llm.config import MISSING
 from fast_llm.engine.base_model.config import BaseModelArchitectureConfig
 from fast_llm.engine.checkpoint.config import (
     CheckpointLoadConfig,
@@ -24,65 +25,104 @@
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class ParamConverter:
-    fast_llm_name: tuple[str, ...] | None
-    export_name: tuple[str, ...] | str | None
+@dataclasses.dataclass(kw_only=True)
+class ParamConverter(abc.ABC):
+    fast_llm_names: tuple[tuple[str, ...], ...] = ()  # Array of fast-llm names, in nested (tuple) format.
+    export_names: tuple[tuple[str, ...], ...] = ()  # Array of export names, in nested (tuple) format.
 
-    def export_param(self, fast_llm_value):
-        return fast_llm_value
+    @abc.abstractmethod
+    def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
+        pass
 
-    def import_param(self, export_value):
-        return export_value
+    @abc.abstractmethod
+    def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
+        pass
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
+class RenameParamConverter(ParamConverter):
+
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 1)
+        Assert.eq(len(self.export_names), 1)
+
+    def export_params(self, fast_llm_values):
+        return fast_llm_values
+
+    def import_params(self, export_values):
+        return export_values
+
+
+    # def __repr__(self):
+    #     return f"RenameParamConverter({'.'.join(self.fast_llm_names[0])} <--> {'.'.join(self.export_names[0])})"
+
+
+@dataclasses.dataclass(kw_only=True)
 class ConstantImportParamConverter(ParamConverter):
-    fast_llm_value: typing.Any
+    fast_llm_value: typing.Any = MISSING
 
-    def export_param(self, fast_llm_value):
-        Assert.eq(fast_llm_value, self.fast_llm_value)
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 1)
+        Assert.eq(len(self.export_names), 0)
+
+    def export_params(self, fast_llm_values):
+        Assert.eq(fast_llm_values[0], self.fast_llm_value)
+        return ()
 
-    def import_param(self, export_value):
-        return self.fast_llm_value
+    def import_params(self, export_values):
+        return (self.fast_llm_value,)
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class ConstantExportParamConverter(ParamConverter):
-    export_value: typing.Any
+    export_value: typing.Any = MISSING
 
-    def export_param(self, fast_llm_value):
-        return self.export_value
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 0)
+        Assert.eq(len(self.export_names), 1)
+
+    def export_params(self, fast_llm_values):
+        return (self.export_value,)
 
-    def import_param(self, export_value):
-        Assert.eq(export_value, self.export_value)
+    def import_params(self, export_values):
+        Assert.eq(export_values[0], self.export_value)
+        return ()
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class IgnoreImportParamConverter(ParamConverter):
-    ignore_export_value: typing.Any
+    ignore_export_value: typing.Any = MISSING
 
-    def export_param(self, fast_llm_value):
-        pass
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 0)
+        Assert.eq(len(self.export_names), 1)
+
+    def export_params(self, fast_llm_values):
+        return (MISSING,)
 
-    def import_param(self, export_value):
-        if export_value is not self.ignore_export_value:
+    def import_params(self, export_values):
+        if export_values[0] not in (self.ignore_export_value, MISSING):
             logger.warning(
-                f"The configuration parameter `{self.export_name}={export_value}` is ignored during conversion."
+                f"The configuration parameter `{self.export_names[0]}={export_values[0]}` is ignored during conversion."
                 f" If you intend to use it in Fast-LLM, make sure to set it explicitly in the model configuration."
             )
+        return ()
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(kw_only=True)
 class MappedConfigParamConverter(ParamConverter):
-    fast_llm_value: typing.Callable
-    export_value: typing.Callable
+    fast_llm_value: typing.Callable = lambda x: x
+    export_value: typing.Callable = lambda x: x
+
+    def __post_init__(self):
+        Assert.eq(len(self.fast_llm_names), 1)
+        Assert.eq(len(self.export_names), 1)
 
-    def export_param(self, fast_llm_value):
-        return self.export_value(fast_llm_value)
+    def export_params(self, fast_llm_values):
+        return (self.export_value(fast_llm_values[0]),)
 
-    def import_param(self, export_value):
-        return self.fast_llm_value(export_value)
+    def import_params(self, export_values):
+        return (self.fast_llm_value(export_values[0]),)
 
 
 class WeightConverter:
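
To make the new interface concrete, here is a hedged sketch of how these converters behave; the configuration keys below are illustrative and not taken from this PR:

```python
from fast_llm.engine.checkpoint.external import (
    ConstantImportParamConverter,
    RenameParamConverter,
)

# Converters now take tuples of nested names on both sides and map tuples of values.
hidden_size = RenameParamConverter(
    fast_llm_names=(("transformer", "hidden_size"),),  # illustrative Fast-LLM path
    export_names=(("hidden_size",),),                  # illustrative export-side key
)
assert hidden_size.export_params((4096,)) == (4096,)
assert hidden_size.import_params((4096,)) == (4096,)

# A constant-import converter consumes no export-side keys and always yields its constant.
tied_embeddings = ConstantImportParamConverter(
    fast_llm_names=(("tie_word_embeddings",),),  # illustrative
    fast_llm_value=False,
)
assert tied_embeddings.import_params(()) == (False,)
assert tied_embeddings.export_params((False,)) == ()
```

Since every converter now maps a tuple of Fast-LLM values to a tuple of export values and back, one-to-many and many-to-one parameter mappings fit the same interface.
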
@@ -197,13 +237,18 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
         # TODO v0.3: not used in this class
         exported_config = {}
         for converter in cls._get_config_converters():
-            value = converter.export_param(
-                None
-                if converter.fast_llm_name is None
-                else cls._get_fast_llm_attribute(config, converter.fast_llm_name)  # Noqa
-            )
-            if converter.export_name is not None:
-                set_nested_dict_value(exported_config, converter.export_name, value)
+            try:
+                values = converter.export_params(
+                    tuple(
+                        cls._get_fast_llm_attribute(config, fast_llm_name)
+                        for fast_llm_name in converter.fast_llm_names
+                    )
+                )
+                for export_name, value in zip(converter.export_names, values, strict=True):
+                    if value is not MISSING:
+                        set_nested_dict_value(exported_config, export_name, value)
+            except Exception as e:
+                raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args)
 
         return exported_config  # Noqa
@@ -214,12 +259,25 @@ def _import_config(
         kwargs = {}
         for converter in cls._get_config_converters():
             try:
-                value = None if converter.export_name is None else get_nested_dict_value(config, converter.export_name)
-            except KeyError:
-                value = None
-            value = converter.import_param(value)
-            if converter.fast_llm_name is not None:
-                kwargs[converter.fast_llm_name] = value
+                values = ()
+                for export_name in converter.export_names:
+                    try:
+                        value = get_nested_dict_value(config, export_name)
+                    except KeyError:
+                        value = MISSING
+                    values = values + (value,)
+                values = converter.import_params(values)
+                for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True):
+                    if value is MISSING:
+                        # Missing values need to be handled in dedicated converters,
+                        # because implicit / default values may not match.
+                        # TODO: Different behavior from other uses of MISSING. Use different tag?
+                        raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}")
+                    if fast_llm_name in kwargs:
+                        raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}")
+                    kwargs[fast_llm_name] = value
+            except Exception as e:
+                raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args)
 
         config_class = cls._model_class.get_base_model_config_class()
         if architecture_only:
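
One behavioral detail worth calling out: an export-side key that is absent from the loaded config now reaches the converter as `MISSING` rather than `None`, and any `MISSING` still present after conversion is a hard error. A hedged illustration with `IgnoreImportParamConverter` (the key name is made up):

```python
from fast_llm.config import MISSING
from fast_llm.engine.checkpoint.external import IgnoreImportParamConverter

ignore_bias = IgnoreImportParamConverter(
    export_names=(("attention_bias",),),  # illustrative export-side key
    ignore_export_value=False,
)

# Absent key, or the expected ignore value: dropped silently.
assert ignore_bias.import_params((MISSING,)) == ()
assert ignore_bias.import_params((False,)) == ()
# Any other value: still dropped, but a warning is logged.
assert ignore_bias.import_params((True,)) == ()
```
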
@@ -335,7 +393,11 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str:
     @classmethod
     @abc.abstractmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
-        return [ConstantExportParamConverter(None, "model_type", cls.get_huggingface_model_type())]
+        return [
+            ConstantExportParamConverter(
+                export_names=(("model_type",),), export_value=cls.get_huggingface_model_type()
+            )
+        ]
 
     @classmethod
     def _load_config(cls, directory: pathlib.Path | str) -> dict:
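
For completeness, a sketch of how a model-specific handler might extend this hook under the new keyword-only API; the base class below is a stand-in and the parameter paths are illustrative:

```python
from fast_llm.engine.checkpoint.external import (
    ConstantImportParamConverter,
    ParamConverter,
    RenameParamConverter,
)


class _IllustrativeBase:
    # Stand-in for the Hugging Face checkpoint-handler base class in this file.
    @classmethod
    def _create_config_converters(cls) -> list[ParamConverter]:
        return []


class IllustrativeLlamaHandler(_IllustrativeBase):
    @classmethod
    def _create_config_converters(cls) -> list[ParamConverter]:
        # Model-specific converters are appended to those of the base class.
        return super()._create_config_converters() + [
            RenameParamConverter(
                fast_llm_names=(("transformer", "num_layers"),),  # illustrative Fast-LLM path
                export_names=(("num_hidden_layers",),),           # illustrative export-side key
            ),
            ConstantImportParamConverter(
                fast_llm_names=(("use_position_embeddings",),),  # illustrative
                fast_llm_value=False,
            ),
        ]
```
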
9 changes: 0 additions & 9 deletions fast_llm/layers/transformer/config.py
@@ -123,15 +123,6 @@ def complex_format(self):
         return self.enabled and not self.triton
 
     def _validate(self):
-        # These happen during conversion.
-        if self.scale_factor is None:
-            self.scale_factor = 8.0
-        if self.low_frequency_factor is None:
-            self.low_frequency_factor = 1.0
-        if self.high_frequency_factor is None:
-            self.high_frequency_factor = 4.0
-        if self.original_context_length is None:
-            self.original_context_length = 8192
         super()._validate()
         if self.triton and not TritonConfig.TRITON_ENABLED:
             warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.")
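
The removed backfills (scale factor 8.0, frequency factors 1.0 and 4.0, original context length 8192) are exactly the kind of implicit defaults the new converter API is meant to absorb. A hypothetical converter along these lines, not the PR's actual implementation, could supply them explicitly at import time:

```python
import dataclasses

from fast_llm.engine.checkpoint.external import ParamConverter


@dataclasses.dataclass(kw_only=True)
class RopeScalingParamConverter(ParamConverter):
    # Hypothetical: maps the four Fast-LLM rotary fields to a single exported
    # `rope_scaling` dict, filling explicit defaults instead of backfilling `None`.
    def __post_init__(self):
        assert len(self.fast_llm_names) == 4  # scale, low/high frequency factors, original context length
        assert len(self.export_names) == 1

    def export_params(self, fast_llm_values):
        scale, low, high, original = fast_llm_values
        return (
            {
                "rope_type": "llama3",
                "factor": scale,
                "low_freq_factor": low,
                "high_freq_factor": high,
                "original_max_position_embeddings": original,
            },
        )

    def import_params(self, export_values):
        (rope_scaling,) = export_values
        if not isinstance(rope_scaling, dict):
            # No scaling section in the exported config: fall back to explicit defaults.
            return (8.0, 1.0, 4.0, 8192)
        return (
            rope_scaling.get("factor", 8.0),
            rope_scaling.get("low_freq_factor", 1.0),
            rope_scaling.get("high_freq_factor", 4.0),
            rope_scaling.get("original_max_position_embeddings", 8192),
        )
```
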