Move predict config update from model loading to config validation (#163)
francoishernandez authored Dec 18, 2024
1 parent 828a658 commit 7a0fd50
Showing 3 changed files with 47 additions and 62 deletions.
eole/config/models.py (6 changes: 3 additions & 3 deletions)
@@ -4,7 +4,8 @@
     field_validator,
     model_validator,
     computed_field,
-)  # , TypeAdapter
+    TypeAdapter,
+)

 import eole
 from eole.constants import PositionEncodingType, ActivationFunction, ModelType
@@ -786,5 +787,4 @@ def _validate_transformer(self):
     Field(discriminator="architecture", default_factory=RnnModelConfig),  # noqa: F821
 ]

-# Not used anymore, keeping for reference
-# build_model_config = TypeAdapter(ModelConfig).validate_python
+build_model_config = TypeAdapter(ModelConfig).validate_python
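For context, `TypeAdapter(ModelConfig).validate_python` is Pydantic v2's way of validating a plain dict against an annotated, discriminator-tagged union outside of a `BaseModel` field: the `architecture` value selects the concrete config class. A minimal standalone sketch of that pattern, using illustrative stand-in classes rather than eole's real configs:

# Minimal sketch of the discriminated-union pattern behind build_model_config.
# RnnConfig / TransformerConfig here are illustrative, not eole's actual classes.
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class RnnConfig(BaseModel):
    architecture: Literal["rnn"] = "rnn"
    hidden_size: int = 256


class TransformerConfig(BaseModel):
    architecture: Literal["transformer"] = "transformer"
    heads: int = 8


ExampleModelConfig = Annotated[
    Union[RnnConfig, TransformerConfig],
    Field(discriminator="architecture"),
]

# analogous to build_model_config = TypeAdapter(ModelConfig).validate_python
build_example_config = TypeAdapter(ExampleModelConfig).validate_python

cfg = build_example_config({"architecture": "transformer", "heads": 16})
assert isinstance(cfg, TransformerConfig)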
eole/config/run.py (44 changes: 43 additions & 1 deletion)
@@ -6,7 +6,7 @@
 from eole.config.training import TrainingConfig
 from eole.config.inference import InferenceConfig
 from eole.config.common import MiscConfig, LoggingConfig
-from eole.config.models import ModelConfig
+from eole.config.models import ModelConfig, build_model_config
 from eole.config.data import (
     DataConfig,
     BaseVocabConfig,
@@ -148,6 +148,48 @@ def _update_with_model_config(self):
                 t for t in transforms if transforms_cls[t].type != TransformType.Train
             ]

+        if os.path.exists(config_path):
+            # logic from models.BaseModel.inference_logic
+            model_config = build_model_config(config_dict.get("model", {}))
+            training_config = TrainingConfig(
+                **config_dict.get("training", {}), dummy_load=True
+            )
+            training_config.world_size = self.world_size
+            training_config.gpu_ranks = self.gpu_ranks
+            # retrieve share_vocab from checkpoint config
+            self.__dict__["share_vocab"] = config_dict.get("share_vocab", False)
+            # retrieve precision from checkpoint config if not explicitly set
+            if "compute_dtype" not in self.model_fields_set:
+                self.compute_dtype = training_config.compute_dtype
+            # quant logic, might be better elsewhere
+            if hasattr(
+                training_config, "quant_type"
+            ) and training_config.quant_type in [
+                "awq_gemm",
+                "awq_gemv",
+            ]:
+                if (
+                    hasattr(self, "quant_type")
+                    and self.quant_type != ""
+                    and self.quant_type != training_config.quant_type
+                ):
+                    raise ValueError(
+                        "Model is a awq quantized model, cannot overwrite with another quant method"
+                    )
+                self.update(quant_type=training_config.quant_type)
+            elif self.quant_type == "" and training_config.quant_type != "":
+                self.update(
+                    quant_layers=training_config.quant_layers,
+                    quant_type=training_config.quant_type,
+                )
+
+            model_config._validate_model_config()
+            # training_config._validate_running_config() # not sure it's needed
+
+            self.update(
+                model=model_config,
+            )
+
         if "transforms" not in self.model_fields_set:
             self.transforms = self._all_transform = transforms
         if "transforms_configs" not in self.model_fields_set:
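The net effect in eole/config/run.py is that checkpoint-derived settings (compute dtype, quantization, share_vocab) are now reconciled while the predict config is being validated, rather than later when the model is loaded. A simplified, standalone analogue of that validator pattern is sketched below; `InferConfig` and its fields are made up for illustration and are not eole's actual `PredictConfig`:

# Standalone sketch: fold saved training settings into an inference config
# during Pydantic validation. Class and field names are illustrative only.
import json
import os

from pydantic import BaseModel, model_validator


class InferConfig(BaseModel):
    model_path: str
    compute_dtype: str = "fp16"
    quant_type: str = ""

    @model_validator(mode="after")
    def _update_from_checkpoint(self):
        config_path = os.path.join(self.model_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path) as f:
                saved = json.load(f)
            training = saved.get("training", {})
            # fall back to the checkpoint dtype only if the user did not set one
            if "compute_dtype" not in self.model_fields_set:
                self.compute_dtype = training.get("compute_dtype", self.compute_dtype)
            # an AWQ-quantized checkpoint pins its quantization method
            saved_quant = training.get("quant_type", "")
            if saved_quant.startswith("awq") and self.quant_type not in ("", saved_quant):
                raise ValueError("checkpoint is AWQ-quantized; cannot override quant_type")
            if self.quant_type == "":
                self.quant_type = saved_quant
        return self

Doing this inside a validator means any downstream consumer of the config, model loading included, sees a single, already-reconciled object.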
eole/models/model.py (59 changes: 1 addition & 58 deletions)
@@ -299,49 +299,7 @@ def training_logic(self, running_config, vocabs, checkpoint, device_id):

     @classmethod
     def inference_logic(self, checkpoint, running_config, vocabs, device_id=None):
-        model_config = checkpoint["config"].model
-        # here we need a running config updated in the same way
-        training_config = checkpoint["config"].training
-        # override gpu_ranks/world_size to prevent warnings
-        training_config.update(
-            world_size=running_config.world_size, gpu_ranks=running_config.gpu_ranks
-        )
-        # retrieve share_vocab flag from checkpoint config
-        running_config.share_vocab = checkpoint["config"].share_vocab
-        # retrieve precision from checkpoint config if not explicitly set
-        if "compute_dtype" not in running_config.model_fields_set:
-            running_config.compute_dtype = training_config.compute_dtype
-        # in fine we might have some nested Lora/QuantizeConfig that are updated from checkpoint values  # noqa: E501
-        # should quant type be in model config or running config ?
-        if hasattr(training_config, "quant_type") and training_config.quant_type in [
-            "awq_gemm",
-            "awq_gemv",
-        ]:  # if the loaded model is a awq quantized one, inference config cannot overwrite this
-            if (
-                hasattr(running_config, "quant_type")
-                and running_config.quant_type != ""
-                and running_config.quant_type != training_config.quant_type
-            ):
-                raise ValueError(
-                    "Model is a awq quantized model, cannot overwrite with another quant method"
-                )
-        # below we are updating training_config with opt (inference_config), though we might want to do the opposite  # noqa: E501
-        elif hasattr(
-            running_config, "quant_type"
-        ) and running_config.quant_type not in [
-            "awq_gemm",
-            "awq_gemv",
-        ]:  # we still want to be able to load fp16/32 models
-            # with bnb 4bit to minimize ram footprint
-            # this is probably not useful anymore as running config will already have the info we need, and the opposite case is handled above  # noqa: E501
-            training_config.quant_layers = running_config.quant_layers
-            training_config.quant_type = running_config.quant_type
-            training_config.lora_layers = []
-        else:
-            # new case, we might want to retrieve quant stuff from training_config
-            running_config.quant_layers = training_config.quant_layers
-            running_config.quant_type = training_config.quant_type
-
+        model_config = running_config.model  # loaded in PredictConfig validation
         if (
             running_config.world_size > 1
             and running_config.parallel_mode == "tensor_parallel"
@@ -360,24 +318,9 @@ def inference_logic(self, checkpoint, running_config, vocabs, device_id=None):
         else:
             device = torch.device("cpu")
             offset = 0
-        # not sure about this one either, do we want to retrieve the value from training sometimes?
-        # if hasattr(running_config, "self_attn_type"):
-        #     training_config.self_attn_type = running_config.self_attn_type

-        model_config._validate_model_config()
-        training_config._validate_running_config()  # not sure it's needed
         vocabs = dict_to_vocabs(checkpoint["vocab"])

-        # Avoid functionality on inference
-        # not sure this will be needed anymore here, though we might need to reconcile train/inference config at some point  # noqa: E501
-        training_config.update(
-            update_vocab=False,
-            dropout_steps=[0],
-            dropout=[0.0],
-            attention_dropout=[0.0],
-        )
-        # required to force no dropout at inference with flash
-
         return (
             model_config,
             running_config,
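What survives of `inference_logic` (the context lines above) is mostly per-rank device selection. For reference, a rough standalone sketch of that device/offset pattern under tensor parallel; the function and variable names are illustrative, not eole's exact implementation:

# Rough sketch of per-rank device selection for inference.
# Names are illustrative; this is not eole's exact code.
from typing import Optional

import torch


def select_device(world_size: int, parallel_mode: str, device_id: Optional[int]):
    """Pick the torch device and shard offset for the current rank."""
    if world_size > 1 and parallel_mode == "tensor_parallel":
        # one CUDA device per rank; the rank index doubles as the shard offset
        device = torch.device("cuda", device_id)
        offset = device_id
    elif device_id is not None and device_id >= 0 and torch.cuda.is_available():
        device = torch.device("cuda", device_id)
        offset = 0
    else:
        device = torch.device("cpu")
        offset = 0
    return device, offset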
