From fa4fc08cdadbf22ef9fac7a5dab6b2e83e5bb19c Mon Sep 17 00:00:00 2001
From: Sean Lee
Date: Mon, 25 Mar 2024 09:44:50 +0800
Subject: [PATCH] remove prepare_model_for_8bit_training

---
 angle_emb/angle.py | 16 ++++------------
 1 file changed, 4 insertions(+), 12 deletions(-)

diff --git a/angle_emb/angle.py b/angle_emb/angle.py
index 0ce428a..6764059 100644
--- a/angle_emb/angle.py
+++ b/angle_emb/angle.py
@@ -29,8 +29,7 @@ from transformers.utils import PaddingStrategy
 
 from peft import (
     get_peft_model, LoraConfig, TaskType, PeftModel,
-    prepare_model_for_kbit_training,
-    prepare_model_for_int8_training
+    prepare_model_for_kbit_training
 )
 from peft.tuners.lora import LoraLayer
 
@@ -1108,13 +1107,13 @@ def __init__(self,
                 lora_config['bias'] = "none"
                 lora_config['task_type'] = TaskType.CAUSAL_LM
 
-                if load_kbit == 4:
+                if load_kbit in [4, 8]:
                     model = MODEL_CLASS.from_pretrained(
                         model_name_or_path,
-                        load_in_4bit=True,
                         config=None,
                         quantization_config=BitsAndBytesConfig(
-                            load_in_4bit=True,
+                            load_in_4bit=load_kbit == 4,
+                            load_in_8bit=load_kbit == 8,
                             llm_int8_threshold=6.0,
                             llm_int8_has_fp16_weight=False,
                             bnb_4bit_compute_dtype=torch.float32,
@@ -1149,17 +1148,10 @@ def __init__(self,
                     if train_mode:
                         model = MODEL_CLASS.from_pretrained(
                             model_name_or_path,
-                            load_in_8bit=load_kbit == 8,
                             torch_dtype=torch.float16 if load_kbit == 16 else torch.float32,
                             device_map=device_map,
                             trust_remote_code=True,
                         )
-                        if load_kbit == 8:
-                            model = prepare_model_for_int8_training(model, **kbit_kwargs)
-                            if 'target_modules' not in lora_config or lora_config.get('target_modules', None) is None:
-                                target_modules = find_all_linear_names(model)
-                                lora_config['target_modules'] = target_modules
-                                logger.info(f'lora target modules={target_modules}')
                     if pretrained_lora_path is not None:
                         print(f'Load lora weight from {pretrained_lora_path}')
                         model = PeftModel.from_pretrained(
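
Note: the sketch below illustrates the loading pattern this patch converges on: a single
BitsAndBytesConfig-driven from_pretrained call that covers both 4-bit and 8-bit weights,
with prepare_model_for_kbit_training standing in for the removed
prepare_model_for_int8_training. It is a minimal standalone example, not code from
angle_emb: load_quantized_causal_lm is a hypothetical helper, and it assumes transformers,
peft, and bitsandbytes releases recent enough to expose this API.

# Minimal sketch (not angle_emb code): unified 4-bit/8-bit loading via
# BitsAndBytesConfig plus the single prepare_model_for_kbit_training hook.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training


def load_quantized_causal_lm(model_name_or_path: str, load_kbit: int = 4):
    # One quantization config serves both paths; exactly one of the two
    # flags is True, mirroring `load_in_4bit=load_kbit == 4` in the patch.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=load_kbit == 4,
        load_in_8bit=load_kbit == 8,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=torch.float32,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        quantization_config=quantization_config,
        device_map='auto',
        trust_remote_code=True,
    )
    # prepare_model_for_kbit_training handles both 4-bit and 8-bit models,
    # which is why the separate prepare_model_for_int8_training branch
    # could be deleted.
    return prepare_model_for_kbit_training(model)


# Example usage (placeholder checkpoint name):
# model = load_quantized_causal_lm('your-llm-checkpoint', load_kbit=8)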