diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py
index 2cee40e504..c8fa3b2692 100644
--- a/src/llamafactory/data/processors/pretrain.py
+++ b/src/llamafactory/data/processors/pretrain.py
@@ -33,7 +33,7 @@ def preprocess_pretrain_dataset(
     text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
 
     if not data_args.packing:
-        if data_args.template == "gemma":
+        if getattr(tokenizer, "add_bos_token", False):
             text_examples = [tokenizer.bos_token + example for example in text_examples]
 
         result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
@@ -47,7 +47,7 @@
             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
-        if data_args.template == "gemma":
+        if getattr(tokenizer, "add_bos_token", False):
             for i in range(len(result["input_ids"])):
                 result["input_ids"][i][0] = tokenizer.bos_token_id
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index bf3cf7f570..f5812d5116 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -201,7 +201,7 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index b450e72d38..113ddafa8c 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -86,20 +86,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     except Exception as e:
         raise OSError("Failed to load tokenizer.") from e
 
-    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
-        tokenizer.model_max_length = model_args.model_max_length
-
-    if model_args.new_special_tokens is not None:
-        num_added_tokens = tokenizer.add_special_tokens(
-            dict(additional_special_tokens=model_args.new_special_tokens),
-            replace_additional_special_tokens=False,
-        )
-        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
-        if num_added_tokens > 0 and not model_args.resize_vocab:
-            model_args.resize_vocab = True
-            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
-
-    patch_tokenizer(tokenizer)
+    patch_tokenizer(tokenizer, model_args)
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, config, tokenizer, model_args)
diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py
index 079c7643ea..b1effca176 100644
--- a/src/llamafactory/model/model_utils/rope.py
+++ b/src/llamafactory/model/model_utils/rope.py
@@ -39,6 +39,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
+    rope_kwargs = {}
     if model_args.model_max_length is not None:
         if is_trainable and model_args.rope_scaling == "dynamic":
             logger.warning_rank0(
@@ -50,14 +51,21 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         if current_max_length and model_args.model_max_length > current_max_length:
             logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
             setattr(config, "max_position_embeddings", model_args.model_max_length)
-            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+            rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
         else:
             logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
-            scaling_factor = 1.0
+            rope_kwargs["factor"] = 1.0
+
+        if model_args.rope_scaling == "dynamic":
+            rope_kwargs["original_max_position_embeddings"] = current_max_length
+        elif model_args.rope_scaling == "llama3":
+            rope_kwargs["original_max_position_embeddings"] = current_max_length
+            rope_kwargs["low_freq_factor"] = 1.0
+            rope_kwargs["high_freq_factor"] = 4.0
     else:
-        scaling_factor = 2.0
+        rope_kwargs["factor"] = 2.0
 
-    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+    setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs})
     logger.info_rank0(
-        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
+        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
     )
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index c33527a684..1f628c2f05 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -53,10 +53,23 @@
 logger = logging.get_logger(__name__)
 
 
-def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
+def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
+    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length
+
+    if model_args.new_special_tokens is not None:
+        num_added_tokens = tokenizer.add_special_tokens(
+            dict(additional_special_tokens=model_args.new_special_tokens),
+            replace_additional_special_tokens=False,
+        )
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        if num_added_tokens > 0 and not model_args.resize_vocab:
+            model_args.resize_vocab = True
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
+
 
 def patch_processor(
     processor: "ProcessorMixin",
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index 7b360be689..e9689df282 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -16,12 +16,14 @@
 import os
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 
+from transformers.utils import is_torch_npu_available
+
 from ..chat import ChatModel
 from ..data import Role
 from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
-from .common import QUANTIZATION_BITS, get_save_dir
+from .common import get_save_dir, load_config
 from .locales import ALERTS
@@ -59,6 +61,8 @@ def load_model(self, data) -> Generator[str, None, None]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
+        user_config = load_config()
+
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
@@ -74,26 +78,22 @@ def load_model(self, data) -> Generator[str, None, None]:
             yield error
             return
 
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
-            quantization_bit = int(get("top.quantization_bit"))
-        else:
-            quantization_bit = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_path,
+            cache_dir=user_config.get("cache_dir", None),
             finetuning_type=finetuning_type,
-            quantization_bit=quantization_bit,
-            quantization_method=get("top.quantization_method"),
             template=get("top.template"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
             trust_remote_code=True,
         )
 
+        # checkpoints
         if checkpoint_path:
             if finetuning_type in PEFT_METHODS:  # list
                 args["adapter_name_or_path"] = ",".join(
@@ -102,6 +102,12 @@ def load_model(self, data) -> Generator[str, None, None]:
             else:  # str
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
 
+        # quantization
+        if get("top.quantization_bit") != "none":
+            args["quantization_bit"] = int(get("top.quantization_bit"))
+            args["quantization_method"] = get("top.quantization_method")
+            args["double_quantization"] = not is_torch_npu_available()
+
         super().__init__(args)
         yield ALERTS["info_loaded"][lang]
diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py
index bc59ea6150..64684f048b 100644
--- a/src/llamafactory/webui/common.py
+++ b/src/llamafactory/webui/common.py
@@ -47,8 +47,6 @@
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user_config.yaml"
-QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
-GPTQ_BITS = ["8", "4", "3", "2"]
 
 
 def get_save_dir(*paths: str) -> os.PathLike:
diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py
index 29be2b353a..7f4b46e6b5 100644
--- a/src/llamafactory/webui/components/export.py
+++ b/src/llamafactory/webui/components/export.py
@@ -18,7 +18,7 @@
 from ...extras.misc import torch_gc
 from ...extras.packages import is_gradio_available
 from ...train.tuner import export_model
-from ..common import GPTQ_BITS, get_save_dir
+from ..common import get_save_dir, load_config
 from ..locales import ALERTS
@@ -32,6 +32,9 @@
     from ..engine import Engine
 
 
+GPTQ_BITS = ["8", "4", "3", "2"]
+
+
 def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
     if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
         return gr.Dropdown(value="none", interactive=False)
@@ -54,6 +57,7 @@ def save_model(
     export_dir: str,
     export_hub_model_id: str,
 ) -> Generator[str, None, None]:
+    user_config = load_config()
     error = ""
     if not model_name:
         error = ALERTS["err_no_model"][lang]
@@ -75,6 +79,7 @@ def save_model(
 
     args = dict(
         model_name_or_path=model_path,
+        cache_dir=user_config.get("cache_dir", None),
         finetuning_type=finetuning_type,
         template=template,
         export_dir=export_dir,
diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py
index 528ee908ca..5ef9ee80c8 100644
--- a/src/llamafactory/webui/components/top.py
+++ b/src/llamafactory/webui/components/top.py
@@ -41,11 +41,11 @@ def create_top() -> Dict[str, "Component"]:
         checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
 
     with gr.Row():
-        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
-        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
-        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
-        rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
-        booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
+        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True)
+        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes")
+        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default")
+        rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic", "yarn", "llama3"], value="none")
+        booster = gr.Dropdown(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto")
 
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 154f179418..943b8864e4 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -15,10 +15,10 @@
 LOCALES = {
     "lang": {
         "en": {
-            "label": "Lang",
+            "label": "Language",
         },
         "ru": {
-            "label": "язык",
+            "label": "Язык",
         },
         "zh": {
             "label": "语言",
@@ -30,11 +30,11 @@
     "model_name": {
         "en": {
             "label": "Model name",
-            "info": "Input the name prefix to search for the model.",
+            "info": "Input the initial name to search for the model.",
         },
         "ru": {
             "label": "Название модели",
-            "info": "Введите префикс имени для поиска модели.",
+            "info": "Введите начальное имя для поиска модели.",
         },
         "zh": {
             "label": "模型名称",
@@ -42,7 +42,7 @@
         },
         "ko": {
             "label": "모델 이름",
-            "info": "모델을 검색하기 위해 이름 접두어를 입력하세요.",
+            "info": "모델을 검색할 초기 이름을 입력하세요.",
         },
     },
     "model_path": {
@@ -129,48 +129,50 @@
     },
     "template": {
         "en": {
-            "label": "Prompt template",
-            "info": "The template used in constructing prompts.",
+            "label": "Chat template",
+            "info": "The chat template used in constructing prompts.",
         },
         "ru": {
-            "label": "Шаблон запроса",
-            "info": "Шаблон, используемый при формировании запросов.",
+            "label": "Шаблон чата",
+            "info": "Шаблон чата используемый для составления подсказок.",
         },
         "zh": {
-            "label": "提示模板",
+            "label": "对话模板",
             "info": "构建提示词时使用的模板。",
         },
         "ko": {
-            "label": "프롬프트 템플릿",
-            "info": "프롬프트 구성에 사용될 템플릿.",
+            "label": "채팅 템플릿",
+            "info": "프롬프트 작성에 사용되는 채팅 템플릿.",
         },
     },
     "rope_scaling": {
         "en": {
             "label": "RoPE scaling",
+            "info": "RoPE scaling method to use.",
         },
         "ru": {
             "label": "Масштабирование RoPE",
+            "info": "Метод масштабирования RoPE для использования.",
         },
-        "zh": {
-            "label": "RoPE 插值方法",
-        },
+        "zh": {"label": "RoPE 插值方法", "info": "RoPE 插值时使用的方法。"},
         "ko": {
             "label": "RoPE 스케일링",
+            "info": "사용할 RoPE 스케일링 방법.",
         },
     },
     "booster": {
         "en": {
             "label": "Booster",
+            "info": "Approach used to boost training speed.",
         },
         "ru": {
             "label": "Ускоритель",
+            "info": "Подход, используемый для ускорения обучения.",
         },
-        "zh": {
-            "label": "加速方式",
-        },
+        "zh": {"label": "加速方式", "info": "使用的加速方法。"},
         "ko": {
             "label": "부스터",
+            "info": "훈련 속도를 향상시키기 위해 사용된 접근 방식.",
         },
     },
     "training_stage": {
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index d2195ea4f2..37409fc7c0 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -24,7 +24,7 @@
 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
 from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
@@ -120,7 +120,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
             template=get("top.template"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
@@ -170,7 +170,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
         # quantization
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
+        if get("top.quantization_bit") != "none":
             args["quantization_bit"] = int(get("top.quantization_bit"))
             args["quantization_method"] = get("top.quantization_method")
             args["double_quantization"] = not is_torch_npu_available()
@@ -280,7 +280,7 @@ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             finetuning_type=finetuning_type,
             quantization_method=get("top.quantization_method"),
             template=get("top.template"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("eval.dataset_dir"),
@@ -311,9 +311,10 @@ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
         # quantization
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
+        if get("top.quantization_bit") != "none":
             args["quantization_bit"] = int(get("top.quantization_bit"))
             args["quantization_method"] = get("top.quantization_method")
+            args["double_quantization"] = not is_torch_npu_available()
 
         return args