Improve pretrained config
jlamypoirier committed Mar 6, 2025
1 parent 93d7e7d commit 8766e55
Showing 3 changed files with 29 additions and 33 deletions.
fast_llm/config.py (19 changes: 16 additions & 3 deletions)
@@ -125,6 +125,8 @@ def __init__(
         # Should raise an Exception in case of failure, and return the validated value.
         # Run before the default validation (type check).
         valid: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None,
+        # Option to skip (postpone) instantiation of a `Config` field.
+        auto_instantiate: bool = True,
         default=dataclasses.MISSING,
         default_factory=dataclasses.MISSING,
         init: bool = True,
@@ -152,6 +154,7 @@ def __init__(
         self.doc = doc
         self.hint = hint
         self.valid = valid
+        self.auto_instantiate = auto_instantiate


 class FieldUpdate(dict):
@@ -265,6 +268,10 @@ def wrap(cls):
     return wrap(cls)


+# A marker to prevent auto instantiation of a config.
+NoAutoInstantiate = object()
+
+
 @dataclasses.dataclass()
 class Config:
     """
@@ -712,10 +719,16 @@ def _from_dict(
                 continue
             if isinstance(field.type, type) and issubclass(field.type, Config):
-                if flat:
-                    out_arg_dict[name] = field.type._from_dict(default, False, True)
+                assert isinstance(field.default_factory, type) and issubclass(
+                    field.default_factory, field.type
+                )
+                if field.auto_instantiate:
+                    if flat:
+                        out_arg_dict[name] = field.default_factory._from_dict(default, False, True)
+                    else:
+                        out_arg_dict[name] = field.default_factory._from_dict(default.pop(name, {}), strict)
                 else:
-                    out_arg_dict[name] = field.type._from_dict(default.pop(name, {}), strict)
+                    out_arg_dict[name] = default.pop(name, {})
             elif name in default:
                 out_arg_dict[name] = default.pop(name)
             else:
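The `_from_dict` hunk above is the core mechanism of the commit: a `Config`-typed field whose `Field` sets `auto_instantiate=False` is no longer built into a config object by `_from_dict`; the raw dict is kept so the owning config can decide later how (and from which class) to instantiate it. A minimal sketch of the intended behavior, not part of the commit: `InnerConfig` and `OuterConfig` are hypothetical classes, and the exact validation timing is an assumption.

```python
# Hypothetical example classes, not from Fast-LLM; a sketch of the intended behavior only.
from fast_llm.config import Config, Field, FieldHint, config_class


@config_class()
class InnerConfig(Config):
    size: int = Field(default=1, desc="Example value.", hint=FieldHint.core)


@config_class()
class OuterConfig(Config):
    # Instantiated eagerly by `_from_dict`, as before.
    eager: InnerConfig = Field(default_factory=InnerConfig, desc="Eager sub-config.", hint=FieldHint.core)
    # Kept as a plain dict by `_from_dict`; the owner instantiates it later (e.g. in `_validate`).
    lazy: InnerConfig = Field(
        default_factory=InnerConfig,
        desc="Postponed sub-config.",
        hint=FieldHint.core,
        auto_instantiate=False,
    )


config = OuterConfig.from_dict({"eager": {"size": 2}, "lazy": {"size": 3}})
# Expected: `config.eager` is an InnerConfig, while `config.lazy` is still {"size": 3}
# until something like `OuterConfig.get_field("lazy").default_factory.from_dict(config.lazy)` runs.
```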
fast_llm/engine/multi_stage/config.py (41 changes: 12 additions & 29 deletions)
@@ -248,44 +248,22 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]:
     def from_pretrained(
         cls,
         pretrained: CheckpointLoadMetadataConfig,
-        default: typing.Self | None = None,
+        *updates: dict[str | tuple[str, ...], typing.Any] | None,
     ) -> typing.Self:
-        # TODO: Add *updates?
         assert pretrained.path is not None
-        metadata = cls.load_metadata(pretrained)
-        return cls.from_metadata(pretrained, metadata, default)
+        return cls.from_metadata(cls.load_metadata(pretrained), *updates)

     @classmethod
     def from_metadata(
         cls,
-        pretrained: CheckpointLoadMetadataConfig,
         metadata: "CheckpointMetadata",
-        default: typing.Self | None = None,
-        updates: dict[str | tuple[str, ...], typing.Any] | None = None,
+        *updates: dict[str | tuple[str, ...], typing.Any] | None,
     ) -> typing.Self:
-        # TODO: Standardize to *updates?
         # TODO v0.3: Update, remove support for older checkpoints.
         if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1, 2):
             raise ValueError(f"Invalid checkpoint version: {metadata.fast_llm_version}")
-        pretrained_config = cls.from_dict(metadata.config)
-        if not pretrained.load_config.load_architecture:
-            assert default is not None
-            config = default.to_copy()
-            config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn)
-        elif pretrained.load_config.load_fast_llm:
-            config = pretrained_config
-        else:
-            with NoAutoValidate():
-                config = cls() if default is None else default.to_copy()
-            if pretrained.load_config.load_base_model:
-                config.base_model = pretrained_config.base_model
-            else:
-                config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture())
-            config.validate()

-        if updates:
-            config = config.to_copy(updates)
-        return config
+        return cls.from_dict(metadata.config, *updates)

     @classmethod
     def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
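With this hunk, `from_pretrained` and `from_metadata` take any number of positional update dicts instead of a `default` config, and the merge is delegated to `from_dict`. A hedged usage sketch, not taken from the commit: the checkpoint path, the override key and the import path of `CheckpointLoadConfig` are assumptions, and a concrete subclass of `FastLLMModelConfig` would normally be used.

```python
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig  # assumed import location
from fast_llm.engine.multi_stage.config import FastLLMModelConfig

# Illustrative checkpoint location.
pretrained = CheckpointLoadConfig.from_dict({"path": "/checkpoints/run_0"})

# Before: cls.from_pretrained(pretrained, default=my_config)
# After: zero or more update dicts, keyed by field name or tuple path, applied by `from_dict`
# on top of the config stored in the checkpoint metadata.
config = FastLLMModelConfig.from_pretrained(
    pretrained,
    {("multi_stage", "zero_stage"): 2},  # hypothetical override
)
```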
@@ -315,7 +293,10 @@ class PretrainedFastLLMModelConfig(Config):
     _abstract = True
     # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before.
     model: FastLLMModelConfig = Field(
-        default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core
+        default_factory=FastLLMModelConfig,
+        desc="Configuration for the Fast-LLM model.",
+        hint=FieldHint.core,
+        auto_instantiate=False,
     )
     pretrained: CheckpointLoadConfig = Field(
         default_factory=CheckpointLoadConfig,
@@ -327,8 +308,10 @@ def _validate(self) -> None:
         assert self.model is not None
         self.pretrained.setup(self.model)
         self.pretrained.validate()
-        if self.pretrained.path is not None:
-            self.model = self.model.from_pretrained(self.pretrained, default=self.model)
+        if self.pretrained.path is None:
+            self.model = self.get_field("model").default_factory.from_dict(self.model)
+        else:
+            self.model = self.model.from_pretrained(self.pretrained, self.model)
         self._setup()
         super()._validate()

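For `PretrainedFastLLMModelConfig`, the combination of `auto_instantiate=False` on the `model` field and the new `_validate` branch means the `model` entry of a config dict can stay a plain dict until validation: without a pretrained path it is built with the field's default factory, and with one it is passed to `from_pretrained` as an update on top of the checkpoint config, mirroring the diff lines above. A hedged sketch, not from the commit: `PretrainedGPTModelConfig`, the field values and the exact validation timing are assumptions.

```python
from fast_llm.models.gpt.config import PretrainedGPTModelConfig  # assumed import location

# No pretrained path: on validation, the raw `model` dict is expected to be instantiated
# from the field's default factory.
config_a = PretrainedGPTModelConfig.from_dict(
    {"model": {"multi_stage": {"zero_stage": 2}}, "pretrained": {}}
)

# With a pretrained path: the same dict is expected to act as updates on top of the
# checkpoint config, via `self.model.from_pretrained(self.pretrained, self.model)`.
config_b = PretrainedGPTModelConfig.from_dict(
    {"model": {"multi_stage": {"zero_stage": 2}}, "pretrained": {"path": "/checkpoints/run_0"}}
)
```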
fast_llm/models/custom/config.py (2 changes: 1 addition & 1 deletion)
@@ -55,7 +55,7 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]

 @config_class()
 class PretrainedCustomModelConfig(PretrainedGPTModelConfig):
-    model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig)
+    model: CustomModelConfig = FieldUpdate()


 @config_class()
