diff --git a/eole/config/models.py b/eole/config/models.py
index d2c4450d..8d564a29 100644
--- a/eole/config/models.py
+++ b/eole/config/models.py
@@ -4,7 +4,8 @@
     field_validator,
     model_validator,
     computed_field,
-)  # , TypeAdapter
+    TypeAdapter,
+)
 
 import eole
 from eole.constants import PositionEncodingType, ActivationFunction, ModelType
@@ -786,5 +787,4 @@ def _validate_transformer(self):
     Field(discriminator="architecture", default_factory=RnnModelConfig),  # noqa: F821
 ]
 
-# Not used anymore, keeping for reference
-# build_model_config = TypeAdapter(ModelConfig).validate_python
+build_model_config = TypeAdapter(ModelConfig).validate_python
diff --git a/eole/config/run.py b/eole/config/run.py
index 925fa04e..d76d5e8c 100644
--- a/eole/config/run.py
+++ b/eole/config/run.py
@@ -6,7 +6,7 @@
 from eole.config.training import TrainingConfig
 from eole.config.inference import InferenceConfig
 from eole.config.common import MiscConfig, LoggingConfig
-from eole.config.models import ModelConfig
+from eole.config.models import ModelConfig, build_model_config
 from eole.config.data import (
     DataConfig,
     BaseVocabConfig,
@@ -148,6 +148,48 @@ def _update_with_model_config(self):
                 t for t in transforms if transforms_cls[t].type != TransformType.Train
             ]
 
+        if os.path.exists(config_path):
+            # logic from models.BaseModel.inference_logic
+            model_config = build_model_config(config_dict.get("model", {}))
+            training_config = TrainingConfig(
+                **config_dict.get("training", {}), dummy_load=True
+            )
+            training_config.world_size = self.world_size
+            training_config.gpu_ranks = self.gpu_ranks
+            # retrieve share_vocab from checkpoint config
+            self.__dict__["share_vocab"] = config_dict.get("share_vocab", False)
+            # retrieve precision from checkpoint config if not explicitly set
+            if "compute_dtype" not in self.model_fields_set:
+                self.compute_dtype = training_config.compute_dtype
+            # quant logic, might be better elsewhere
+            if hasattr(
+                training_config, "quant_type"
+            ) and training_config.quant_type in [
+                "awq_gemm",
+                "awq_gemv",
+            ]:
+                if (
+                    hasattr(self, "quant_type")
+                    and self.quant_type != ""
+                    and self.quant_type != training_config.quant_type
+                ):
+                    raise ValueError(
+                        "Model is a awq quantized model, cannot overwrite with another quant method"
+                    )
+                self.update(quant_type=training_config.quant_type)
+            elif self.quant_type == "" and training_config.quant_type != "":
+                self.update(
+                    quant_layers=training_config.quant_layers,
+                    quant_type=training_config.quant_type,
+                )
+
+            model_config._validate_model_config()
+            # training_config._validate_running_config()  # not sure it's needed
+
+            self.update(
+                model=model_config,
+            )
+
         if "transforms" not in self.model_fields_set:
             self.transforms = self._all_transform = transforms
         if "transforms_configs" not in self.model_fields_set:
diff --git a/eole/models/model.py b/eole/models/model.py
index f0198041..7609307f 100644
--- a/eole/models/model.py
+++ b/eole/models/model.py
@@ -299,49 +299,7 @@ def training_logic(self, running_config, vocabs, checkpoint, device_id):
 
     @classmethod
     def inference_logic(self, checkpoint, running_config, vocabs, device_id=None):
-        model_config = checkpoint["config"].model
-        # here we need a running config updated in the same way
-        training_config = checkpoint["config"].training
-        # override gpu_ranks/world_size to prevent warnings
-        training_config.update(
-            world_size=running_config.world_size, gpu_ranks=running_config.gpu_ranks
-        )
-        # retrieve share_vocab flag from checkpoint config
-        running_config.share_vocab = checkpoint["config"].share_vocab
-        # retrieve precision from checkpoint config if not explicitly set
-        if "compute_dtype" not in running_config.model_fields_set:
-            running_config.compute_dtype = training_config.compute_dtype
-        # in fine we might have some nested Lora/QuantizeConfig that are updated from checkpoint values  # noqa: E501
-        # should quant type be in model config or running config ?
-        if hasattr(training_config, "quant_type") and training_config.quant_type in [
-            "awq_gemm",
-            "awq_gemv",
-        ]:  # if the loaded model is a awq quantized one, inference config cannot overwrite this
-            if (
-                hasattr(running_config, "quant_type")
-                and running_config.quant_type != ""
-                and running_config.quant_type != training_config.quant_type
-            ):
-                raise ValueError(
-                    "Model is a awq quantized model, cannot overwrite with another quant method"
-                )
-        # below we are updating training_config with opt (inference_config), though we might want to do the opposite  # noqa: E501
-        elif hasattr(
-            running_config, "quant_type"
-        ) and running_config.quant_type not in [
-            "awq_gemm",
-            "awq_gemv",
-        ]:  # we still want to be able to load fp16/32 models
-            # with bnb 4bit to minimize ram footprint
-            # this is probably not useful anymore as running config will already have the info we need, and the opposite case is handled above  # noqa: E501
-            training_config.quant_layers = running_config.quant_layers
-            training_config.quant_type = running_config.quant_type
-            training_config.lora_layers = []
-        else:
-            # new case, we might want to retrieve quant stuff from training_config
-            running_config.quant_layers = training_config.quant_layers
-            running_config.quant_type = training_config.quant_type
-
+        model_config = running_config.model  # loaded in PredictConfig validation
         if (
             running_config.world_size > 1
             and running_config.parallel_mode == "tensor_parallel"
@@ -360,24 +318,9 @@ def inference_logic(self, checkpoint, running_config, vocabs, device_id=None):
         else:
             device = torch.device("cpu")
             offset = 0
-        # not sure about this one either, do we want to retrieve the value from training sometimes?
-        # if hasattr(running_config, "self_attn_type"):
-        #     training_config.self_attn_type = running_config.self_attn_type
-        model_config._validate_model_config()
-        training_config._validate_running_config()  # not sure it's needed
 
         vocabs = dict_to_vocabs(checkpoint["vocab"])
 
-        # Avoid functionality on inference
-        # not sure this will be needed anymore here, though we might need to reconcile train/inference config at some point  # noqa: E501
-        training_config.update(
-            update_vocab=False,
-            dropout_steps=[0],
-            dropout=[0.0],
-            attention_dropout=[0.0],
-        )
-        # required to force no dropout at inference with flash
-
         return (
             model_config,
             running_config,
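The re-enabled build_model_config is a pydantic TypeAdapter validator over the ModelConfig discriminated union, which the new logic in run.py uses to turn the "model" section of a checkpoint's config.json into a concrete model config during PredictConfig validation. A minimal usage sketch, assuming pydantic v2 and that "rnn" is a registered value of the "architecture" discriminator (hypothetical payload, not taken from this diff):

    from eole.config.models import build_model_config

    # Validate a plain dict (e.g. config_dict.get("model", {}) read from a checkpoint's
    # config.json) into the concrete config class selected by the "architecture" discriminator.
    model_config = build_model_config({"architecture": "rnn"})  # hypothetical minimal payload

    # Same post-validation hook the new run.py code calls before handing the config to inference.
    model_config._validate_model_config()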