[model] support yarn #6693

Merged · 1 commit · Jan 18, 2025
4 changes: 2 additions & 2 deletions src/llamafactory/data/processors/pretrain.py
@@ -33,7 +33,7 @@ def preprocess_pretrain_dataset(
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

if not data_args.packing:
if data_args.template == "gemma":
if getattr(tokenizer, "add_bos_token", False):
text_examples = [tokenizer.bos_token + example for example in text_examples]

result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
@@ -47,7 +47,7 @@
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
if getattr(tokenizer, "add_bos_token", False):
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id

2 changes: 1 addition & 1 deletion src/llamafactory/hparams/model_args.py
@@ -201,7 +201,7 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
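
With the Literal widened, "yarn" and "llama3" become valid values for the existing rope_scaling hyperparameter. A hypothetical usage sketch (the model name and lengths are assumptions; the import path follows the file path shown above):

    from llamafactory.hparams.model_args import ModelArguments

    model_args = ModelArguments(
        model_name_or_path="meta-llama/Meta-Llama-3-8B",  # placeholder model
        rope_scaling="yarn",         # previously restricted to "linear" / "dynamic"
        model_max_length=32768,      # assumed target context length
    )
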
15 changes: 1 addition & 14 deletions src/llamafactory/model/loader.py
@@ -86,20 +86,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
except Exception as e:
raise OSError("Failed to load tokenizer.") from e

if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
tokenizer.model_max_length = model_args.model_max_length

if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")

patch_tokenizer(tokenizer)
patch_tokenizer(tokenizer, model_args)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)
18 changes: 13 additions & 5 deletions src/llamafactory/model/model_utils/rope.py
@@ -39,6 +39,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
logger.warning_rank0("Current model does not support RoPE scaling.")
return

rope_kwargs = {}
if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == "dynamic":
logger.warning_rank0(
@@ -50,14 +51,21 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
if current_max_length and model_args.model_max_length > current_max_length:
logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
rope_kwargs["factor"] = 1.0

if model_args.rope_scaling == "dynamic":
rope_kwargs["original_max_position_embeddings"] = current_max_length
elif model_args.rope_scaling == "llama3":
rope_kwargs["original_max_position_embeddings"] = current_max_length
rope_kwargs["low_freq_factor"] = 1.0
rope_kwargs["high_freq_factor"] = 4.0
else:
scaling_factor = 2.0
rope_kwargs["factor"] = 2.0

setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs})
logger.info_rank0(
f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
)
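
For reference, a worked example of what configure_rope now writes into config.rope_scaling (the lengths below are assumptions, not values from the PR): with max_position_embeddings=4096 and model_max_length=16384 the ceiled factor is 4.0; "yarn" carries only the factor, while "llama3" additionally carries original_max_position_embeddings, low_freq_factor, and high_freq_factor. A small standalone sketch:

    import math

    # Assumed values for illustration only.
    current_max_length = 4096      # config.max_position_embeddings
    model_max_length = 16384       # model_args.model_max_length
    rope_scaling = "yarn"          # or "linear" / "dynamic" / "llama3"

    rope_kwargs = {"factor": float(math.ceil(model_max_length / current_max_length))}  # 4.0
    if rope_scaling == "dynamic":
        rope_kwargs["original_max_position_embeddings"] = current_max_length
    elif rope_scaling == "llama3":
        rope_kwargs.update(
            original_max_position_embeddings=current_max_length,
            low_freq_factor=1.0,
            high_freq_factor=4.0,
        )

    print({"rope_type": rope_scaling, **rope_kwargs})
    # -> {'rope_type': 'yarn', 'factor': 4.0}

The rope_type key matches the dict format that recent transformers releases expect for YaRN and Llama 3 RoPE scaling, which is why the old "type"/"factor" pair is replaced here.
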
15 changes: 14 additions & 1 deletion src/llamafactory/model/patcher.py
@@ -53,10 +53,23 @@
logger = logging.get_logger(__name__)


def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
tokenizer.model_max_length = model_args.model_max_length

if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")


def patch_processor(
processor: "ProcessorMixin",
24 changes: 15 additions & 9 deletions src/llamafactory/webui/chatter.py
@@ -16,12 +16,14 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple

from transformers.utils import is_torch_npu_available

from ..chat import ChatModel
from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import QUANTIZATION_BITS, get_save_dir
from .common import get_save_dir, load_config
from .locales import ALERTS


@@ -59,6 +61,8 @@ def load_model(self, data) -> Generator[str, None, None]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
user_config = load_config()

error = ""
if self.loaded:
error = ALERTS["err_exists"][lang]
@@ -74,26 +78,22 @@ def load_model(self, data) -> Generator[str, None, None]:
yield error
return

if get("top.quantization_bit") in QUANTIZATION_BITS:
quantization_bit = int(get("top.quantization_bit"))
else:
quantization_bit = None

yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_path,
cache_dir=user_config.get("cache_dir", None),
finetuning_type=finetuning_type,
quantization_bit=quantization_bit,
quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),
trust_remote_code=True,
)

# checkpoints
if checkpoint_path:
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
@@ -102,6 +102,12 @@
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)

# quantization
if get("top.quantization_bit") != "none":
args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method")
args["double_quantization"] = not is_torch_npu_available()

super().__init__(args)
yield ALERTS["info_loaded"][lang]

2 changes: 0 additions & 2 deletions src/llamafactory/webui/common.py
@@ -47,8 +47,6 @@
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
GPTQ_BITS = ["8", "4", "3", "2"]


def get_save_dir(*paths: str) -> os.PathLike:
7 changes: 6 additions & 1 deletion src/llamafactory/webui/components/export.py
@@ -18,7 +18,7 @@
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
from ..common import GPTQ_BITS, get_save_dir
from ..common import get_save_dir, load_config
from ..locales import ALERTS


@@ -32,6 +32,9 @@
from ..engine import Engine


GPTQ_BITS = ["8", "4", "3", "2"]


def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False)
@@ -54,6 +57,7 @@ def save_model(
export_dir: str,
export_hub_model_id: str,
) -> Generator[str, None, None]:
user_config = load_config()
error = ""
if not model_name:
error = ALERTS["err_no_model"][lang]
@@ -75,6 +79,7 @@

args = dict(
model_name_or_path=model_path,
cache_dir=user_config.get("cache_dir", None),
finetuning_type=finetuning_type,
template=template,
export_dir=export_dir,
10 changes: 5 additions & 5 deletions src/llamafactory/webui/components/top.py
@@ -41,11 +41,11 @@ def create_top() -> Dict[str, "Component"]:
checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)

with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True)
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes")
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default")
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic", "yarn", "llama3"], value="none")
booster = gr.Dropdown(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto")

model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
38 changes: 20 additions & 18 deletions src/llamafactory/webui/locales.py
@@ -15,10 +15,10 @@
LOCALES = {
"lang": {
"en": {
"label": "Lang",
"label": "Language",
},
"ru": {
"label": "язык",
"label": "Язык",
},
"zh": {
"label": "语言",
@@ -30,19 +30,19 @@
"model_name": {
"en": {
"label": "Model name",
"info": "Input the name prefix to search for the model.",
"info": "Input the initial name to search for the model.",
},
"ru": {
"label": "Название модели",
"info": "Введите префикс имени для поиска модели.",
"info": "Введите начальное имя для поиска модели.",
},
"zh": {
"label": "模型名称",
"info": "输入首单词以检索模型。",
},
"ko": {
"label": "모델 이름",
"info": "모델을 검색하기 위해 이름 접두어를 입력하세요.",
"info": "모델을 검색할 초기 이름을 입력하세요.",
},
},
"model_path": {
@@ -129,48 +129,50 @@
},
"template": {
"en": {
"label": "Prompt template",
"info": "The template used in constructing prompts.",
"label": "Chat template",
"info": "The chat template used in constructing prompts.",
},
"ru": {
"label": "Шаблон запроса",
"info": "Шаблон, используемый при формировании запросов.",
"label": "Шаблон чата",
"info": "Шаблон чата используемый для составления подсказок.",
},
"zh": {
"label": "提示模板",
"label": "对话模板",
"info": "构建提示词时使用的模板。",
},
"ko": {
"label": "프롬프트 템플릿",
"info": "프롬프트 구성에 사용될 템플릿.",
"label": "채팅 템플릿",
"info": "프롬프트 작성에 사용되는 채팅 템플릿.",
},
},
"rope_scaling": {
"en": {
"label": "RoPE scaling",
"info": "RoPE scaling method to use.",
},
"ru": {
"label": "Масштабирование RoPE",
"info": "Метод масштабирования RoPE для использования.",
},
"zh": {
"label": "RoPE 插值方法",
},
"zh": {"label": "RoPE 插值方法", "info": "RoPE 插值时使用的方法。"},
"ko": {
"label": "RoPE 스케일링",
"info": "사용할 RoPE 스케일링 방법.",
},
},
"booster": {
"en": {
"label": "Booster",
"info": "Approach used to boost training speed.",
},
"ru": {
"label": "Ускоритель",
"info": "Подход, используемый для ускорения обучения.",
},
"zh": {
"label": "加速方式",
},
"zh": {"label": "加速方式", "info": "使用的加速方法。"},
"ko": {
"label": "부스터",
"info": "훈련 속도를 향상시키기 위해 사용된 접근 방식.",
},
},
"training_stage": {
11 changes: 6 additions & 5 deletions src/llamafactory/webui/runner.py
@@ -24,7 +24,7 @@
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
from .locales import ALERTS, LOCALES
from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd

@@ -120,7 +120,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
preprocessing_num_workers=16,
finetuning_type=finetuning_type,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
@@ -170,7 +170,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))

# quantization
if get("top.quantization_bit") in QUANTIZATION_BITS:
if get("top.quantization_bit") != "none":
args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method")
args["double_quantization"] = not is_torch_npu_available()
@@ -280,7 +280,7 @@ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
finetuning_type=finetuning_type,
quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
dataset_dir=get("eval.dataset_dir"),
@@ -311,9 +311,10 @@ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))

# quantization
if get("top.quantization_bit") in QUANTIZATION_BITS:
if get("top.quantization_bit") != "none":
args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method")
args["double_quantization"] = not is_torch_npu_available()

return args
