[model] support yarn (#6693)
hiyouga authored Jan 18, 2025
1 parent 17b4706 commit 1f47b61
Showing 11 changed files with 83 additions and 63 deletions.
4 changes: 2 additions & 2 deletions src/llamafactory/data/processors/pretrain.py
@@ -33,7 +33,7 @@ def preprocess_pretrain_dataset(
     text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
 
     if not data_args.packing:
-        if data_args.template == "gemma":
+        if getattr(tokenizer, "add_bos_token", False):
             text_examples = [tokenizer.bos_token + example for example in text_examples]
 
         result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
@@ -47,7 +47,7 @@ def preprocess_pretrain_dataset(
             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
-        if data_args.template == "gemma":
+        if getattr(tokenizer, "add_bos_token", False):
            for i in range(len(result["input_ids"])):
                result["input_ids"][i][0] = tokenizer.bos_token_id

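For reference, a minimal sketch (not part of the commit; model IDs are placeholders and may require downloads or authentication) of what the new check keys on: any tokenizer that reports add_bos_token=True gets the BOS token prepended, not just the Gemma template.

from transformers import AutoTokenizer

# Llama-2 style tokenizers expose add_bos_token=True, while GPT-2/Qwen-style ones do not,
# so getattr(..., False) covers both cases without hard-coding a template name.
for name in ("meta-llama/Llama-2-7b-hf", "Qwen/Qwen2.5-0.5B"):
    tokenizer = AutoTokenizer.from_pretrained(name)
    print(name, getattr(tokenizer, "add_bos_token", False))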
2 changes: 1 addition & 1 deletion src/llamafactory/hparams/model_args.py
@@ -201,7 +201,7 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
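A hedged sketch of the effect (import path taken from this diff; the model ID is a placeholder): the Literal values become CLI choices through HfArgumentParser, so "yarn" and "llama3" are now accepted for --rope_scaling while any other string is rejected at parse time.

from transformers import HfArgumentParser
from llamafactory.hparams.model_args import ModelArguments

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    ["--model_name_or_path", "meta-llama/Meta-Llama-3-8B", "--rope_scaling", "yarn"]
)
print(model_args.rope_scaling)  # yarn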
15 changes: 1 addition & 14 deletions src/llamafactory/model/loader.py
@@ -86,20 +86,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     except Exception as e:
         raise OSError("Failed to load tokenizer.") from e
 
-    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
-        tokenizer.model_max_length = model_args.model_max_length
-
-    if model_args.new_special_tokens is not None:
-        num_added_tokens = tokenizer.add_special_tokens(
-            dict(additional_special_tokens=model_args.new_special_tokens),
-            replace_additional_special_tokens=False,
-        )
-        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
-        if num_added_tokens > 0 and not model_args.resize_vocab:
-            model_args.resize_vocab = True
-            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
-
-    patch_tokenizer(tokenizer)
+    patch_tokenizer(tokenizer, model_args)
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, config, tokenizer, model_args)
18 changes: 13 additions & 5 deletions src/llamafactory/model/model_utils/rope.py
@@ -39,6 +39,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
+    rope_kwargs = {}
     if model_args.model_max_length is not None:
         if is_trainable and model_args.rope_scaling == "dynamic":
             logger.warning_rank0(
@@ -50,14 +51,21 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         if current_max_length and model_args.model_max_length > current_max_length:
             logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
             setattr(config, "max_position_embeddings", model_args.model_max_length)
-            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+            rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
         else:
             logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
-            scaling_factor = 1.0
+            rope_kwargs["factor"] = 1.0
+
+        if model_args.rope_scaling == "dynamic":
+            rope_kwargs["original_max_position_embeddings"] = current_max_length
+        elif model_args.rope_scaling == "llama3":
+            rope_kwargs["original_max_position_embeddings"] = current_max_length
+            rope_kwargs["low_freq_factor"] = 1.0
+            rope_kwargs["high_freq_factor"] = 4.0
     else:
-        scaling_factor = 2.0
+        rope_kwargs["factor"] = 2.0
 
-    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+    setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs})
     logger.info_rank0(
-        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
+        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
     )
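To make the outcome of this hunk concrete, here is a small standalone sketch (not from the repo) that mirrors the logic above and prints the dict written into config.rope_scaling; the lengths are illustrative.

import math
from typing import Optional

def build_rope_scaling(rope_type: str, model_max_length: Optional[int], current_max_length: int) -> dict:
    # Mirrors configure_rope above: pick a scaling factor, then add the extra
    # keys that dynamic NTK and llama3-style scaling expect.
    rope_kwargs = {}
    if model_max_length is not None:
        if model_max_length > current_max_length:
            rope_kwargs["factor"] = float(math.ceil(model_max_length / current_max_length))
        else:
            rope_kwargs["factor"] = 1.0

        if rope_type == "dynamic":
            rope_kwargs["original_max_position_embeddings"] = current_max_length
        elif rope_type == "llama3":
            rope_kwargs["original_max_position_embeddings"] = current_max_length
            rope_kwargs["low_freq_factor"] = 1.0
            rope_kwargs["high_freq_factor"] = 4.0
    else:
        rope_kwargs["factor"] = 2.0

    return {"rope_type": rope_type, **rope_kwargs}

print(build_rope_scaling("yarn", 65536, 32768))   # {'rope_type': 'yarn', 'factor': 2.0}
print(build_rope_scaling("llama3", 16384, 8192))  # adds the original length and low/high freq factors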
15 changes: 14 additions & 1 deletion src/llamafactory/model/patcher.py
@@ -53,10 +53,23 @@
 logger = logging.get_logger(__name__)
 
 
-def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
+def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
+    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length
+
+    if model_args.new_special_tokens is not None:
+        num_added_tokens = tokenizer.add_special_tokens(
+            dict(additional_special_tokens=model_args.new_special_tokens),
+            replace_additional_special_tokens=False,
+        )
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        if num_added_tokens > 0 and not model_args.resize_vocab:
+            model_args.resize_vocab = True
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
+
 
 def patch_processor(
     processor: "ProcessorMixin",
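A hypothetical usage sketch (import paths follow this repo; the model ID and token strings are made up): tokenizer-side options such as model_max_length and new_special_tokens are now applied inside patch_tokenizer instead of load_tokenizer, so callers pass model_args along.

from transformers import AutoTokenizer
from llamafactory.hparams import ModelArguments
from llamafactory.model.patcher import patch_tokenizer

model_args = ModelArguments(
    model_name_or_path="Qwen/Qwen2.5-0.5B",      # placeholder model ID
    model_max_length=4096,
    new_special_tokens="<custom_a>,<custom_b>",  # comma-separated, as on the CLI
)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
patch_tokenizer(tokenizer, model_args)           # note the new model_args parameter
print(tokenizer.model_max_length, model_args.resize_vocab)  # expected: 4096 True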
24 changes: 15 additions & 9 deletions src/llamafactory/webui/chatter.py
@@ -16,12 +16,14 @@
 import os
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 
+from transformers.utils import is_torch_npu_available
+
 from ..chat import ChatModel
 from ..data import Role
 from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
-from .common import QUANTIZATION_BITS, get_save_dir
+from .common import get_save_dir, load_config
 from .locales import ALERTS
 
 
@@ -59,6 +61,8 @@ def load_model(self, data) -> Generator[str, None, None]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
+        user_config = load_config()
+
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
@@ -74,26 +78,22 @@ def load_model(self, data) -> Generator[str, None, None]:
             yield error
             return
 
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
-            quantization_bit = int(get("top.quantization_bit"))
-        else:
-            quantization_bit = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_path,
+            cache_dir=user_config.get("cache_dir", None),
             finetuning_type=finetuning_type,
-            quantization_bit=quantization_bit,
-            quantization_method=get("top.quantization_method"),
             template=get("top.template"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
             trust_remote_code=True,
         )
 
+        # checkpoints
         if checkpoint_path:
             if finetuning_type in PEFT_METHODS:  # list
                 args["adapter_name_or_path"] = ",".join(
@@ -102,6 +102,12 @@ def load_model(self, data) -> Generator[str, None, None]:
             else:  # str
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
 
+        # quantization
+        if get("top.quantization_bit") != "none":
+            args["quantization_bit"] = int(get("top.quantization_bit"))
+            args["quantization_method"] = get("top.quantization_method")
+            args["double_quantization"] = not is_torch_npu_available()
+
         super().__init__(args)
         yield ALERTS["info_loaded"][lang]

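A small sketch of the new quantization handling (it mirrors the hunk above; the selected value is a placeholder): the bit width is only forwarded when the dropdown is not "none", and double quantization is switched off when running on an NPU.

from transformers.utils import is_torch_npu_available

selected_bit = "4"  # value coming from the "top.quantization_bit" dropdown
args = {}
if selected_bit != "none":
    args["quantization_bit"] = int(selected_bit)
    args["quantization_method"] = "bitsandbytes"
    args["double_quantization"] = not is_torch_npu_available()
print(args)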
2 changes: 0 additions & 2 deletions src/llamafactory/webui/common.py
@@ -47,8 +47,6 @@
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user_config.yaml"
-QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
-GPTQ_BITS = ["8", "4", "3", "2"]
 
 
 def get_save_dir(*paths: str) -> os.PathLike:
7 changes: 6 additions & 1 deletion src/llamafactory/webui/components/export.py
@@ -18,7 +18,7 @@
 from ...extras.misc import torch_gc
 from ...extras.packages import is_gradio_available
 from ...train.tuner import export_model
-from ..common import GPTQ_BITS, get_save_dir
+from ..common import get_save_dir, load_config
 from ..locales import ALERTS
 
@@ -32,6 +32,9 @@
     from ..engine import Engine
 
 
+GPTQ_BITS = ["8", "4", "3", "2"]
+
+
 def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
     if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
         return gr.Dropdown(value="none", interactive=False)
@@ -54,6 +57,7 @@ def save_model(
     export_dir: str,
     export_hub_model_id: str,
 ) -> Generator[str, None, None]:
+    user_config = load_config()
     error = ""
     if not model_name:
         error = ALERTS["err_no_model"][lang]
@@ -75,6 +79,7 @@ def save_model(
 
     args = dict(
         model_name_or_path=model_path,
+        cache_dir=user_config.get("cache_dir", None),
         finetuning_type=finetuning_type,
         template=template,
         export_dir=export_dir,
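Sketch only (import paths follow this diff; importing the webui package assumes gradio is installed): GPTQ_BITS now lives in the export component, and the export tab reads cache_dir from the saved web UI config like the other tabs do.

from llamafactory.webui.common import load_config
from llamafactory.webui.components.export import GPTQ_BITS

user_config = load_config()
print(GPTQ_BITS)                           # ['8', '4', '3', '2']
print(user_config.get("cache_dir", None))  # None until the user sets one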
10 changes: 5 additions & 5 deletions src/llamafactory/webui/components/top.py
@@ -41,11 +41,11 @@ def create_top() -> Dict[str, "Component"]:
         checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
 
     with gr.Row():
-        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
-        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
-        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
-        rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
-        booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
+        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True)
+        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes")
+        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default")
+        rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic", "yarn", "llama3"], value="none")
+        booster = gr.Dropdown(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto")
 
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
38 changes: 20 additions & 18 deletions src/llamafactory/webui/locales.py
@@ -15,10 +15,10 @@
 LOCALES = {
     "lang": {
         "en": {
-            "label": "Lang",
+            "label": "Language",
         },
         "ru": {
-            "label": "язык",
+            "label": "Язык",
         },
         "zh": {
             "label": "语言",
@@ -30,19 +30,19 @@
     "model_name": {
         "en": {
             "label": "Model name",
-            "info": "Input the name prefix to search for the model.",
+            "info": "Input the initial name to search for the model.",
         },
         "ru": {
             "label": "Название модели",
-            "info": "Введите префикс имени для поиска модели.",
+            "info": "Введите начальное имя для поиска модели.",
         },
         "zh": {
             "label": "模型名称",
             "info": "输入首单词以检索模型。",
         },
         "ko": {
             "label": "모델 이름",
-            "info": "모델을 검색하기 위해 이름 접두어를 입력하세요.",
+            "info": "모델을 검색할 초기 이름을 입력하세요.",
         },
     },
     "model_path": {
@@ -129,48 +129,50 @@
     },
     "template": {
         "en": {
-            "label": "Prompt template",
-            "info": "The template used in constructing prompts.",
+            "label": "Chat template",
+            "info": "The chat template used in constructing prompts.",
         },
         "ru": {
-            "label": "Шаблон запроса",
-            "info": "Шаблон, используемый при формировании запросов.",
+            "label": "Шаблон чата",
+            "info": "Шаблон чата используемый для составления подсказок.",
         },
         "zh": {
-            "label": "提示模板",
+            "label": "对话模板",
             "info": "构建提示词时使用的模板。",
         },
         "ko": {
-            "label": "프롬프트 템플릿",
-            "info": "프롬프트 구성에 사용될 템플릿.",
+            "label": "채팅 템플릿",
+            "info": "프롬프트 작성에 사용되는 채팅 템플릿.",
         },
     },
     "rope_scaling": {
         "en": {
             "label": "RoPE scaling",
+            "info": "RoPE scaling method to use.",
         },
         "ru": {
             "label": "Масштабирование RoPE",
+            "info": "Метод масштабирования RoPE для использования.",
         },
-        "zh": {
-            "label": "RoPE 插值方法",
-        },
+        "zh": {"label": "RoPE 插值方法", "info": "RoPE 插值时使用的方法。"},
         "ko": {
             "label": "RoPE 스케일링",
+            "info": "사용할 RoPE 스케일링 방법.",
         },
     },
     "booster": {
         "en": {
             "label": "Booster",
+            "info": "Approach used to boost training speed.",
         },
         "ru": {
             "label": "Ускоритель",
+            "info": "Подход, используемый для ускорения обучения.",
         },
-        "zh": {
-            "label": "加速方式",
-        },
+        "zh": {"label": "加速方式", "info": "使用的加速方法。"},
         "ko": {
             "label": "부스터",
+            "info": "훈련 속도를 향상시키기 위해 사용된 접근 방식.",
         },
     },
     "training_stage": {
11 changes: 6 additions & 5 deletions src/llamafactory/webui/runner.py
@@ -24,7 +24,7 @@
 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
 from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
 
@@ -120,7 +120,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
             template=get("top.template"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
@@ -170,7 +170,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
         # quantization
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
+        if get("top.quantization_bit") != "none":
             args["quantization_bit"] = int(get("top.quantization_bit"))
             args["quantization_method"] = get("top.quantization_method")
             args["double_quantization"] = not is_torch_npu_available()
@@ -280,7 +280,7 @@ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             finetuning_type=finetuning_type,
             quantization_method=get("top.quantization_method"),
             template=get("top.template"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("eval.dataset_dir"),
@@ -311,9 +311,10 @@ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
         # quantization
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
+        if get("top.quantization_bit") != "none":
             args["quantization_bit"] = int(get("top.quantization_bit"))
             args["quantization_method"] = get("top.quantization_method")
+            args["double_quantization"] = not is_torch_npu_available()
 
         return args
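An illustrative helper (not from the repo) showing how the web UI choice is now forwarded: previously only "linear" and "dynamic" were passed through, so the new "yarn" and "llama3" dropdown options would have been silently dropped.

def forward_rope_choice(choice: str):
    # new behavior: anything except "none" reaches ModelArguments.rope_scaling
    return choice if choice != "none" else None

for choice in ("none", "linear", "dynamic", "yarn", "llama3"):
    print(choice, "->", forward_rope_choice(choice))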
