From 45e68b9f092879dda55023ebbcd8cf4660e3045a Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Fri, 31 Jan 2025 00:09:21 +0800 Subject: [PATCH] [webui] improve webui & reasoning mode (#6778) --- README.md | 9 +- README_zh.md | 9 +- src/llamafactory/data/template.py | 25 +- src/llamafactory/extras/constants.py | 21 +- src/llamafactory/webui/chatter.py | 44 ++- src/llamafactory/webui/common.py | 207 ++++++++++--- src/llamafactory/webui/components/chatbot.py | 23 +- src/llamafactory/webui/components/data.py | 6 + src/llamafactory/webui/components/eval.py | 3 +- src/llamafactory/webui/components/top.py | 4 +- src/llamafactory/webui/components/train.py | 4 +- src/llamafactory/webui/control.py | 201 ++++++++++++ src/llamafactory/webui/css.py | 23 ++ src/llamafactory/webui/engine.py | 18 +- src/llamafactory/webui/locales.py | 14 + src/llamafactory/webui/manager.py | 4 + src/llamafactory/webui/runner.py | 60 +++- src/llamafactory/webui/utils.py | 304 ------------------- 18 files changed, 570 insertions(+), 409 deletions(-) create mode 100644 src/llamafactory/webui/control.py delete mode 100644 src/llamafactory/webui/utils.py diff --git a/README.md b/README.md index 0acf68c5f8..adc31754e0 100644 --- a/README.md +++ b/README.md @@ -216,16 +216,15 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | -| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/685B | deepseek3 | -| [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 | +| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 | +| [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | | [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | -| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | -| [InternLM3](https://huggingface.co/internlm) | 8B | intern3 | +| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 | @@ -830,7 +829,7 @@ If you have a project that should be incorporated, please contact via email or c This repository is licensed under the [Apache-2.0 License](LICENSE). 
-Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / 
[Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## Citation diff --git a/README_zh.md b/README_zh.md index a20c6bce5a..61fac00daa 100644 --- a/README_zh.md +++ b/README_zh.md @@ -218,16 +218,15 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | -| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/685B | deepseek3 | -| [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 | +| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 | +| [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | | [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | -| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | -| [InternLM3](https://huggingface.co/internlm) | 8B | intern3 | +| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 | @@ -832,7 +831,7 @@ swanlab_run_name: test_run # 可选 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 -使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / 
[Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## 引用 diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 4d7f5eb919..5b775db7c4 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -719,6 +719,13 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: format_assistant=StringFormatter(slots=["{{content}}\n"]), format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is 
developed by Shanghai AI Laboratory " + "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language " + "chosen by the user such as English and 中文." + ), stop_words=[""], ) @@ -729,17 +736,13 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), - stop_words=["<|im_end|>"], -) - - -# copied from intern2 template -_register_template( - name="intern3", - format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), - format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory " + "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language " + "chosen by the user such as English and 中文." + ), stop_words=["<|im_end|>"], ) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index e98aadbd28..c0903db164 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -105,7 +105,7 @@ def register_model_group( ) -> None: for name, path in models.items(): SUPPORTED_MODELS[name] = path - if template is not None and (any(suffix in name for suffix in ("-Chat", "-Instruct")) or vision): + if template is not None and (any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or vision): DEFAULT_TEMPLATE[name] = template if vision: VISION_MODELS.add(name) @@ -485,11 +485,11 @@ def register_model_group( DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210", }, - "DeepSeek-V3-685B-Base": { + "DeepSeek-V3-671B-Base": { DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-Base", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-Base", }, - "DeepSeek-V3-685B-Chat": { + "DeepSeek-V3-671B-Chat": { DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3", }, @@ -517,11 +517,11 @@ def register_model_group( DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", }, - "DeepSeek-R1-671B-Zero": { + "DeepSeek-R1-671B-Chat-Zero": { DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Zero", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Zero", - }, - "DeepSeek-R1-671B": { + }, + "DeepSeek-R1-671B-Chat": { DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1", }, @@ -845,20 +845,15 @@ def register_model_group( DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat", DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat", }, - }, - template="intern2", -) - -register_model_group( - models={ "InternLM3-8B-Chat": { DownloadSource.DEFAULT: "internlm/internlm3-8b-instruct", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm3-8b-instruct", }, }, - template="intern3", + template="intern2", ) + 
register_model_group( models={ "Jamba-v0.1": { diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index 7abdf8b53f..3c598cf98b 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -36,6 +36,30 @@ import gradio as gr +def _format_response(text: str, lang: str) -> str: + r""" + Post-processes the response text. + + Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py + """ + if "" not in text: + return text + + text = text.replace("", "") + result = text.split("", maxsplit=1) + if len(result) == 1: + summary = ALERTS["info_thinking"][lang] + thought, answer = text, "" + else: + summary = ALERTS["info_thought"][lang] + thought, answer = result + + return ( + f"
<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
+        f"<div class='thinking-container'>\n{thought}\n</div>\n</details>
{answer}" + ) + + class WebChatModel(ChatModel): def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None: self.manager = manager @@ -124,19 +148,26 @@ def unload_model(self, data) -> Generator[str, None, None]: torch_gc() yield ALERTS["info_unloaded"][lang] + @staticmethod def append( - self, chatbot: List[Dict[str, str]], messages: List[Dict[str, str]], role: str, query: str, ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]: + r""" + Adds the user input to chatbot. + + Inputs: infer.chatbot, infer.messages, infer.role, infer.query + Output: infer.chatbot, infer.messages + """ return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], "" def stream( self, chatbot: List[Dict[str, str]], messages: List[Dict[str, str]], + lang: str, system: str, tools: str, image: Optional[Any], @@ -145,6 +176,12 @@ def stream( top_p: float, temperature: float, ) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]: + r""" + Generates output text in stream. + + Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ... + Output: infer.chatbot, infer.messages + """ chatbot.append({"role": "assistant", "content": ""}) response = "" for new_text in self.stream_chat( @@ -157,7 +194,6 @@ def stream( top_p=top_p, temperature=temperature, ): - new_text = '' if any(t in new_text for t in ('', '')) else new_text response += new_text if tools: result = self.engine.template.extract_tool(response) @@ -166,12 +202,12 @@ def stream( if isinstance(result, list): tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result] - tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False) + tool_calls = json.dumps(tool_calls, ensure_ascii=False) output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}] bot_text = "```json\n" + tool_calls + "\n```" else: output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}] - bot_text = result + bot_text = _format_response(result, lang) chatbot[-1] = {"role": "assistant", "content": bot_text} yield chatbot, output_messages diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index 64684f048b..e8f1e097c6 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -14,34 +14,28 @@ import json import os +import signal from collections import defaultdict -from typing import Any, Dict, Optional, Tuple +from datetime import datetime +from typing import Any, Dict, Optional, Union +from psutil import Process from yaml import safe_dump, safe_load from ..extras import logging from ..extras.constants import ( - CHECKPOINT_NAMES, DATA_CONFIG, DEFAULT_TEMPLATE, - PEFT_METHODS, - STAGES_USE_PAIR_DATA, SUPPORTED_MODELS, - TRAINING_STAGES, + TRAINING_ARGS, VISION_MODELS, DownloadSource, ) from ..extras.misc import use_modelscope, use_openmind -from ..extras.packages import is_gradio_available - - -if is_gradio_available(): - import gradio as gr logger = logging.get_logger(__name__) - DEFAULT_CACHE_DIR = "cache" DEFAULT_CONFIG_DIR = "config" DEFAULT_DATA_DIR = "data" @@ -49,6 +43,21 @@ USER_CONFIG = "user_config.yaml" +def abort_process(pid: int) -> None: + r""" + Aborts the processes recursively in a bottom-up way. 
+ """ + try: + children = Process(pid).children() + if children: + for child in children: + abort_process(child.pid) + + os.kill(pid, signal.SIGABRT) + except Exception: + pass + + def get_save_dir(*paths: str) -> os.PathLike: r""" Gets the path to saved model checkpoints. @@ -61,19 +70,19 @@ def get_save_dir(*paths: str) -> os.PathLike: return os.path.join(DEFAULT_SAVE_DIR, *paths) -def get_config_path() -> os.PathLike: +def _get_config_path() -> os.PathLike: r""" Gets the path to user config. """ return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) -def load_config() -> Dict[str, Any]: +def load_config() -> Dict[str, Union[str, Dict[str, Any]]]: r""" Loads user config if exists. """ try: - with open(get_config_path(), encoding="utf-8") as f: + with open(_get_config_path(), encoding="utf-8") as f: return safe_load(f) except Exception: return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} @@ -92,7 +101,7 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona if model_name and model_path: user_config["path_dict"][model_name] = model_path - with open(get_config_path(), "w", encoding="utf-8") as f: + with open(_get_config_path(), "w", encoding="utf-8") as f: safe_dump(user_config, f) @@ -120,20 +129,9 @@ def get_model_path(model_name: str) -> str: return model_path -def get_model_info(model_name: str) -> Tuple[str, str]: - r""" - Gets the necessary information of this model. - - Returns: - model_path (str) - template (str) - """ - return get_model_path(model_name), get_template(model_name) - - def get_template(model_name: str) -> str: r""" - Gets the template name if the model is a chat model. + Gets the template name if the model is a chat/distill/instruct model. """ return DEFAULT_TEMPLATE.get(model_name, "default") @@ -145,24 +143,11 @@ def get_visual(model_name: str) -> bool: return model_name in VISION_MODELS -def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown": +def get_time() -> str: r""" - Lists all available checkpoints. + Gets current date and time. """ - checkpoints = [] - if model_name: - save_dir = get_save_dir(model_name, finetuning_type) - if save_dir and os.path.isdir(save_dir): - for checkpoint in os.listdir(save_dir): - if os.path.isdir(os.path.join(save_dir, checkpoint)) and any( - os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES - ): - checkpoints.append(checkpoint) - - if finetuning_type in PEFT_METHODS: - return gr.Dropdown(value=[], choices=checkpoints, multiselect=True) - else: - return gr.Dropdown(value=None, choices=checkpoints, multiselect=False) + return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S") def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: @@ -181,11 +166,135 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: return {} -def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": +def load_args(config_path: str) -> Optional[Dict[str, Any]]: + r""" + Loads the training configuration from config path. + """ + try: + with open(config_path, encoding="utf-8") as f: + return safe_load(f) + except Exception: + return None + + +def save_args(config_path: str, config_dict: Dict[str, Any]) -> None: r""" - Lists all available datasets in the dataset dir for the training stage. + Saves the training configuration to config path. 
""" - dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) - ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA - datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] - return gr.Dropdown(choices=datasets) + with open(config_path, "w", encoding="utf-8") as f: + safe_dump(config_dict, f) + + +def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]: + r""" + Removes args with NoneType or False or empty string value. + """ + no_skip_keys = ["packing"] + return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")} + + +def gen_cmd(args: Dict[str, Any]) -> str: + r""" + Generates CLI commands for previewing. + """ + cmd_lines = ["llamafactory-cli train "] + for k, v in _clean_cmd(args).items(): + if isinstance(v, dict): + cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ") + elif isinstance(v, list): + cmd_lines.append(f" --{k} {' '.join(map(str, v))} ") + else: + cmd_lines.append(f" --{k} {str(v)} ") + + if os.name == "nt": + cmd_text = "`\n".join(cmd_lines) + else: + cmd_text = "\\\n".join(cmd_lines) + + cmd_text = f"```bash\n{cmd_text}\n```" + return cmd_text + + +def save_cmd(args: Dict[str, Any]) -> str: + r""" + Saves CLI commands to launch training. + """ + output_dir = args["output_dir"] + os.makedirs(output_dir, exist_ok=True) + with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f: + safe_dump(_clean_cmd(args), f) + + return os.path.join(output_dir, TRAINING_ARGS) + + +def load_eval_results(path: os.PathLike) -> str: + r""" + Gets scores after evaluation. + """ + with open(path, encoding="utf-8") as f: + result = json.dumps(json.load(f), indent=4) + + return f"```json\n{result}\n```\n" + + +def create_ds_config() -> None: + r""" + Creates deepspeed config in the current directory. 
+ """ + os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) + ds_config = { + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": True, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1, + }, + "bf16": {"enabled": "auto"}, + } + offload_config = { + "device": "cpu", + "pin_memory": True, + } + ds_config["zero_optimization"] = { + "stage": 2, + "allgather_partitions": True, + "allgather_bucket_size": 5e8, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 5e8, + "contiguous_gradients": True, + "round_robin_gradients": True, + } + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"]["offload_optimizer"] = offload_config + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"] = { + "stage": 3, + "overlap_comm": True, + "contiguous_gradients": True, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": True, + } + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"]["offload_optimizer"] = offload_config + ds_config["zero_optimization"]["offload_param"] = offload_config + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py index 53e41b93cb..840c190d92 100644 --- a/src/llamafactory/webui/components/chatbot.py +++ b/src/llamafactory/webui/components/chatbot.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from typing import TYPE_CHECKING, Dict, Tuple from ...data import Role from ...extras.packages import is_gradio_available -from ..utils import check_json_schema +from ..locales import ALERTS if is_gradio_available(): @@ -29,9 +30,27 @@ from ..engine import Engine +def check_json_schema(text: str, lang: str) -> None: + r""" + Checks if the json schema is valid. 
+ """ + try: + tools = json.loads(text) + if tools: + assert isinstance(tools, list) + for tool in tools: + if "name" not in tool: + raise NotImplementedError("Name not found.") + except NotImplementedError: + gr.Warning(ALERTS["err_tool_name"][lang]) + except Exception: + gr.Warning(ALERTS["err_json_schema"][lang]) + + def create_chat_box( engine: "Engine", visible: bool = False ) -> Tuple["Component", "Component", Dict[str, "Component"]]: + lang = engine.manager.get_elem_by_id("top.lang") with gr.Column(visible=visible) as chat_box: chatbot = gr.Chatbot(type="messages", show_copy_button=True) messages = gr.State([]) @@ -67,7 +86,7 @@ def create_chat_box( [chatbot, messages, query], ).then( engine.chatter.stream, - [chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature], + [chatbot, messages, lang, system, tools, image, video, max_new_tokens, top_p, temperature], [chatbot, messages], ) clear_btn.click(lambda: ([], []), outputs=[chatbot, messages]) diff --git a/src/llamafactory/webui/components/data.py b/src/llamafactory/webui/components/data.py index cc0428fec6..e62e1823df 100644 --- a/src/llamafactory/webui/components/data.py +++ b/src/llamafactory/webui/components/data.py @@ -40,6 +40,9 @@ def next_page(page_index: int, total_num: int) -> int: def can_preview(dataset_dir: str, dataset: list) -> "gr.Button": + r""" + Checks if the dataset is a local dataset. + """ try: with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: dataset_info = json.load(f) @@ -67,6 +70,9 @@ def _load_data_file(file_path: str) -> List[Any]: def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]: + r""" + Gets the preview samples from the dataset. + """ with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: dataset_info = json.load(f) diff --git a/src/llamafactory/webui/components/eval.py b/src/llamafactory/webui/components/eval.py index 8d26d23203..39a12026b2 100644 --- a/src/llamafactory/webui/components/eval.py +++ b/src/llamafactory/webui/components/eval.py @@ -15,7 +15,8 @@ from typing import TYPE_CHECKING, Dict from ...extras.packages import is_gradio_available -from ..common import DEFAULT_DATA_DIR, list_datasets +from ..common import DEFAULT_DATA_DIR +from ..control import list_datasets from .data import create_preview_box diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py index b7b70ebfe6..467e31125e 100644 --- a/src/llamafactory/webui/components/top.py +++ b/src/llamafactory/webui/components/top.py @@ -17,8 +17,8 @@ from ...data import TEMPLATES from ...extras.constants import METHODS, SUPPORTED_MODELS from ...extras.packages import is_gradio_available -from ..common import get_model_info, list_checkpoints, save_config -from ..utils import can_quantize, can_quantize_to +from ..common import save_config +from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints if is_gradio_available(): diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index b4c0bb2aa8..28aa3a8c02 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -19,8 +19,8 @@ from ...extras.constants import TRAINING_STAGES from ...extras.misc import get_device_count from ...extras.packages import is_gradio_available -from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets -from ..utils import change_stage, list_config_paths, list_output_dirs +from 
..common import DEFAULT_DATA_DIR +from ..control import change_stage, list_checkpoints, list_config_paths, list_datasets, list_output_dirs from .data import create_preview_box diff --git a/src/llamafactory/webui/control.py b/src/llamafactory/webui/control.py new file mode 100644 index 0000000000..b8087af67a --- /dev/null +++ b/src/llamafactory/webui/control.py @@ -0,0 +1,201 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from typing import Any, Dict, List, Optional, Tuple + +from transformers.trainer_utils import get_last_checkpoint + +from ..extras.constants import ( + CHECKPOINT_NAMES, + PEFT_METHODS, + RUNNING_LOG, + STAGES_USE_PAIR_DATA, + TRAINER_LOG, + TRAINING_STAGES, +) +from ..extras.packages import is_gradio_available, is_matplotlib_available +from ..extras.ploting import gen_loss_plot +from ..model import QuantizationMethod +from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info + + +if is_gradio_available(): + import gradio as gr + + +def can_quantize(finetuning_type: str) -> "gr.Dropdown": + r""" + Judges if the quantization is available in this finetuning type. + + Inputs: top.finetuning_type + Outputs: top.quantization_bit + """ + if finetuning_type not in PEFT_METHODS: + return gr.Dropdown(value="none", interactive=False) + else: + return gr.Dropdown(interactive=True) + + +def can_quantize_to(quantization_method: str) -> "gr.Dropdown": + r""" + Gets the available quantization bits. + + Inputs: top.quantization_method + Outputs: top.quantization_bit + """ + if quantization_method == QuantizationMethod.BITS_AND_BYTES.value: + available_bits = ["none", "8", "4"] + elif quantization_method == QuantizationMethod.HQQ.value: + available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"] + elif quantization_method == QuantizationMethod.EETQ.value: + available_bits = ["none", "8"] + + return gr.Dropdown(choices=available_bits) + + +def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]: + r""" + Modifys states after changing the training stage. + + Inputs: train.training_stage + Outputs: train.dataset, train.packing + """ + return [], TRAINING_STAGES[training_stage] == "pt" + + +def get_model_info(model_name: str) -> Tuple[str, str]: + r""" + Gets the necessary information of this model. + + Inputs: top.model_name + Outputs: top.model_path, top.template + """ + return get_model_path(model_name), get_template(model_name) + + +def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]: + r""" + Gets training infomation for monitor. 
+ + If do_train is True: + Inputs: train.output_path + Outputs: train.output_box, train.progress_bar, train.loss_viewer + If do_train is False: + Inputs: eval.output_path + Outputs: eval.output_box, eval.progress_bar, None + """ + running_log = "" + running_progress = gr.Slider(visible=False) + running_loss = None + + running_log_path = os.path.join(output_path, RUNNING_LOG) + if os.path.isfile(running_log_path): + with open(running_log_path, encoding="utf-8") as f: + running_log = f.read() + + trainer_log_path = os.path.join(output_path, TRAINER_LOG) + if os.path.isfile(trainer_log_path): + trainer_log: List[Dict[str, Any]] = [] + with open(trainer_log_path, encoding="utf-8") as f: + for line in f: + trainer_log.append(json.loads(line)) + + if len(trainer_log) != 0: + latest_log = trainer_log[-1] + percentage = latest_log["percentage"] + label = "Running {:d}/{:d}: {} < {}".format( + latest_log["current_steps"], + latest_log["total_steps"], + latest_log["elapsed_time"], + latest_log["remaining_time"], + ) + running_progress = gr.Slider(label=label, value=percentage, visible=True) + + if do_train and is_matplotlib_available(): + running_loss = gr.Plot(gen_loss_plot(trainer_log)) + + return running_log, running_progress, running_loss + + +def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown": + r""" + Lists all available checkpoints. + + Inputs: top.model_name, top.finetuning_type + Outputs: top.checkpoint_path + """ + checkpoints = [] + if model_name: + save_dir = get_save_dir(model_name, finetuning_type) + if save_dir and os.path.isdir(save_dir): + for checkpoint in os.listdir(save_dir): + if os.path.isdir(os.path.join(save_dir, checkpoint)) and any( + os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES + ): + checkpoints.append(checkpoint) + + if finetuning_type in PEFT_METHODS: + return gr.Dropdown(value=[], choices=checkpoints, multiselect=True) + else: + return gr.Dropdown(value=None, choices=checkpoints, multiselect=False) + + +def list_config_paths(current_time: str) -> "gr.Dropdown": + r""" + Lists all the saved configuration files. + + Inputs: train.current_time + Outputs: train.config_path + """ + config_files = [f"{current_time}.yaml"] + if os.path.isdir(DEFAULT_CONFIG_DIR): + for file_name in os.listdir(DEFAULT_CONFIG_DIR): + if file_name.endswith(".yaml") and file_name not in config_files: + config_files.append(file_name) + + return gr.Dropdown(choices=config_files) + + +def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": + r""" + Lists all available datasets in the dataset dir for the training stage. + + Inputs: *.dataset_dir, *.training_stage + Outputs: *.dataset + """ + dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) + ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA + datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] + return gr.Dropdown(choices=datasets) + + +def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown": + r""" + Lists all the directories that can resume from. 
+ + Inputs: top.model_name, top.finetuning_type, train.current_time + Outputs: train.output_dir + """ + output_dirs = [f"train_{current_time}"] + if model_name: + save_dir = get_save_dir(model_name, finetuning_type) + if save_dir and os.path.isdir(save_dir): + for folder in os.listdir(save_dir): + output_dir = os.path.join(save_dir, folder) + if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None: + output_dirs.append(folder) + + return gr.Dropdown(choices=output_dirs) diff --git a/src/llamafactory/webui/css.py b/src/llamafactory/webui/css.py index 539821195f..c4445e8eb9 100644 --- a/src/llamafactory/webui/css.py +++ b/src/llamafactory/webui/css.py @@ -20,6 +20,29 @@ border-radius: 100vh !important; } +.thinking-summary { + padding: 8px !important; +} + +.thinking-summary span { + border: 1px solid #e0e0e0 !important; + border-radius: 4px !important; + padding: 4px !important; + cursor: pointer !important; + font-size: 14px !important; + background: #333333 !important; +} + +.thinking-container { + border-left: 2px solid #a6a6a6 !important; + padding-left: 10px !important; + margin: 4px 0 !important; +} + +.thinking-container p { + color: #a6a6a6 !important; +} + .modal-box { position: fixed !important; top: 50%; diff --git a/src/llamafactory/webui/engine.py b/src/llamafactory/webui/engine.py index deb9dee12a..3b18eeb9ce 100644 --- a/src/llamafactory/webui/engine.py +++ b/src/llamafactory/webui/engine.py @@ -15,11 +15,10 @@ from typing import TYPE_CHECKING, Any, Dict from .chatter import WebChatModel -from .common import load_config +from .common import create_ds_config, get_time, load_config from .locales import LOCALES from .manager import Manager from .runner import Runner -from .utils import create_ds_config, get_time if TYPE_CHECKING: @@ -27,6 +26,10 @@ class Engine: + r""" + A general engine to control the behaviors of Web UI. + """ + def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None: self.demo_mode = demo_mode self.pure_chat = pure_chat @@ -38,7 +41,7 @@ def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None: def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]: r""" - Gets the dict to update the components. + Updates gradio components according to the (elem_id, properties) mapping. """ output_dict: Dict["Component", "Component"] = {} for elem_id, elem_attr in input_dict.items(): @@ -48,9 +51,11 @@ def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Comp return output_dict def resume(self): - user_config = load_config() if not self.demo_mode else {} + r""" + Gets the initial value of gradio components and restores training status if necessary. + """ + user_config = load_config() if not self.demo_mode else {} # do not use config in demo mode lang = user_config.get("lang", None) or "en" - init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}} if not self.pure_chat: @@ -74,6 +79,9 @@ def resume(self): yield self._update_component({"eval.resume_btn": {"value": True}}) def change_lang(self, lang: str): + r""" + Updates the displayed language of gradio components. 
+ """ return { elem: elem.__class__(**LOCALES[elem_name][lang]) for elem_name, elem in self.manager.get_elem_iter() diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index b747d43162..f7846c531c 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -2786,6 +2786,20 @@ "ko": "모델이 언로드되었습니다.", "ja": "モデルがアンロードされました。", }, + "info_thinking": { + "en": "🌀 Thinking...", + "ru": "🌀 Думаю...", + "zh": "🌀 思考中...", + "ko": "🌀 생각 중...", + "ja": "🌀 考えています...", + }, + "info_thought": { + "en": "✅ Thought", + "ru": "✅ Думать закончено", + "zh": "✅ 思考完成", + "ko": "✅ 생각이 완료되었습니다", + "ja": "✅ 思考完了", + }, "info_exporting": { "en": "Exporting model...", "ru": "Экспорт модели...", diff --git a/src/llamafactory/webui/manager.py b/src/llamafactory/webui/manager.py index c34d8e5475..18332ac0f9 100644 --- a/src/llamafactory/webui/manager.py +++ b/src/llamafactory/webui/manager.py @@ -20,6 +20,10 @@ class Manager: + r""" + A class to manage all the gradio components in Web UI. + """ + def __init__(self) -> None: self._id_to_elem: Dict[str, "Component"] = {} self._elem_to_id: Dict["Component", str] = {} diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 37409fc7c0..9716d917b8 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -24,9 +24,20 @@ from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46 -from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config +from .common import ( + DEFAULT_CACHE_DIR, + DEFAULT_CONFIG_DIR, + abort_process, + gen_cmd, + get_save_dir, + load_args, + load_config, + load_eval_results, + save_args, + save_cmd, +) +from .control import get_trainer_info from .locales import ALERTS, LOCALES -from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd if is_gradio_available(): @@ -40,6 +51,10 @@ class Runner: + r""" + A class to manage the running status of the trainers. + """ + def __init__(self, manager: "Manager", demo_mode: bool = False) -> None: self.manager = manager self.demo_mode = demo_mode @@ -57,6 +72,9 @@ def set_abort(self) -> None: abort_process(self.trainer.pid) def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str: + r""" + Validates the configuration. + """ get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path") dataset = get("train.dataset") if do_train else get("eval.dataset") @@ -98,6 +116,9 @@ def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview return "" def _finalize(self, lang: str, finish_info: str) -> str: + r""" + Cleans the cached memory and resets the runner. + """ finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info gr.Info(finish_info) self.trainer = None @@ -108,6 +129,9 @@ def _finalize(self, lang: str, finish_info: str) -> str: return finish_info def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: + r""" + Builds and validates the training arguments. 
+ """ get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() @@ -268,6 +292,9 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: return args def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: + r""" + Builds and validates the evaluation arguments. + """ get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() @@ -319,6 +346,9 @@ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: return args def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]: + r""" + Previews the training commands. + """ output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) error = self._initialize(data, do_train, from_preview=True) if error: @@ -329,6 +359,9 @@ def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Di yield {output_box: gen_cmd(args)} def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]: + r""" + Starts the training process. + """ output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) error = self._initialize(data, do_train, from_preview=False) if error: @@ -339,7 +372,7 @@ def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dic args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) os.makedirs(args["output_dir"], exist_ok=True) - save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data)) + save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._build_config_dict(data)) env = deepcopy(os.environ) env["LLAMABOARD_ENABLED"] = "1" @@ -350,7 +383,10 @@ def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dic self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env) yield from self.monitor() - def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]: + def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]: + r""" + Builds a dictionary containing the current training configuration. + """ config_dict = {} skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"] for elem, value in data.items(): @@ -373,6 +409,9 @@ def run_eval(self, data): yield from self._launch(data, do_train=False) def monitor(self): + r""" + Monitors the training progress and logs. + """ self.aborted = False self.running = True @@ -416,7 +455,7 @@ def monitor(self): finish_info = ALERTS["err_failed"][lang] else: if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray(): - finish_info = get_eval_results(os.path.join(output_path, "all_results.json")) + finish_info = load_eval_results(os.path.join(output_path, "all_results.json")) else: finish_info = ALERTS["err_failed"][lang] @@ -427,6 +466,9 @@ def monitor(self): yield return_dict def save_args(self, data): + r""" + Saves the training configuration to config path. 
+ """ output_box = self.manager.get_elem_by_id("train.output_box") error = self._initialize(data, do_train=True, from_preview=True) if error: @@ -438,10 +480,13 @@ def save_args(self, data): os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True) save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path) - save_args(save_path, self._form_config_dict(data)) + save_args(save_path, self._build_config_dict(data)) return {output_box: ALERTS["info_config_saved"][lang] + save_path} def load_args(self, lang: str, config_path: str): + r""" + Loads the training configuration from config path. + """ output_box = self.manager.get_elem_by_id("train.output_box") config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path)) if config_dict is None: @@ -455,6 +500,9 @@ def load_args(self, lang: str, config_path: str): return output_dict def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str): + r""" + Restore the training status if output_dir exists. + """ output_box = self.manager.get_elem_by_id("train.output_box") output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]} if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)): diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py deleted file mode 100644 index e7b6aa0199..0000000000 --- a/src/llamafactory/webui/utils.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2024 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -import signal -from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple - -import psutil -from transformers.trainer_utils import get_last_checkpoint -from yaml import safe_dump, safe_load - -from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES -from ..extras.packages import is_gradio_available, is_matplotlib_available -from ..extras.ploting import gen_loss_plot -from ..model import QuantizationMethod -from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir -from .locales import ALERTS - - -if is_gradio_available(): - import gradio as gr - - -def abort_process(pid: int) -> None: - r""" - Aborts the processes recursively in a bottom-up way. - """ - try: - children = psutil.Process(pid).children() - if children: - for child in children: - abort_process(child.pid) - - os.kill(pid, signal.SIGABRT) - except Exception: - pass - - -def can_quantize(finetuning_type: str) -> "gr.Dropdown": - r""" - Judges if the quantization is available in this finetuning type. - """ - if finetuning_type not in PEFT_METHODS: - return gr.Dropdown(value="none", interactive=False) - else: - return gr.Dropdown(interactive=True) - - -def can_quantize_to(quantization_method: str) -> "gr.Dropdown": - r""" - Returns the available quantization bits. 
- """ - if quantization_method == QuantizationMethod.BITS_AND_BYTES.value: - available_bits = ["none", "8", "4"] - elif quantization_method == QuantizationMethod.HQQ.value: - available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"] - elif quantization_method == QuantizationMethod.EETQ.value: - available_bits = ["none", "8"] - - return gr.Dropdown(choices=available_bits) - - -def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]: - r""" - Modifys states after changing the training stage. - """ - return [], TRAINING_STAGES[training_stage] == "pt" - - -def check_json_schema(text: str, lang: str) -> None: - r""" - Checks if the json schema is valid. - """ - try: - tools = json.loads(text) - if tools: - assert isinstance(tools, list) - for tool in tools: - if "name" not in tool: - raise NotImplementedError("Name not found.") - except NotImplementedError: - gr.Warning(ALERTS["err_tool_name"][lang]) - except Exception: - gr.Warning(ALERTS["err_json_schema"][lang]) - - -def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]: - r""" - Removes args with NoneType or False or empty string value. - """ - no_skip_keys = ["packing"] - return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")} - - -def gen_cmd(args: Dict[str, Any]) -> str: - r""" - Generates arguments for previewing. - """ - cmd_lines = ["llamafactory-cli train "] - for k, v in clean_cmd(args).items(): - if isinstance(v, dict): - cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ") - elif isinstance(v, list): - cmd_lines.append(f" --{k} {' '.join(map(str, v))} ") - else: - cmd_lines.append(f" --{k} {str(v)} ") - - if os.name == "nt": - cmd_text = "`\n".join(cmd_lines) - else: - cmd_text = "\\\n".join(cmd_lines) - - cmd_text = f"```bash\n{cmd_text}\n```" - return cmd_text - - -def save_cmd(args: Dict[str, Any]) -> str: - r""" - Saves arguments to launch training. - """ - output_dir = args["output_dir"] - os.makedirs(output_dir, exist_ok=True) - - with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f: - safe_dump(clean_cmd(args), f) - - return os.path.join(output_dir, TRAINING_ARGS) - - -def get_eval_results(path: os.PathLike) -> str: - r""" - Gets scores after evaluation. - """ - with open(path, encoding="utf-8") as f: - result = json.dumps(json.load(f), indent=4) - return f"```json\n{result}\n```\n" - - -def get_time() -> str: - r""" - Gets current date and time. - """ - return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S") - - -def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]: - r""" - Gets training infomation for monitor. 
- """ - running_log = "" - running_progress = gr.Slider(visible=False) - running_loss = None - - running_log_path = os.path.join(output_path, RUNNING_LOG) - if os.path.isfile(running_log_path): - with open(running_log_path, encoding="utf-8") as f: - running_log = f.read() - - trainer_log_path = os.path.join(output_path, TRAINER_LOG) - if os.path.isfile(trainer_log_path): - trainer_log: List[Dict[str, Any]] = [] - with open(trainer_log_path, encoding="utf-8") as f: - for line in f: - trainer_log.append(json.loads(line)) - - if len(trainer_log) != 0: - latest_log = trainer_log[-1] - percentage = latest_log["percentage"] - label = "Running {:d}/{:d}: {} < {}".format( - latest_log["current_steps"], - latest_log["total_steps"], - latest_log["elapsed_time"], - latest_log["remaining_time"], - ) - running_progress = gr.Slider(label=label, value=percentage, visible=True) - - if do_train and is_matplotlib_available(): - running_loss = gr.Plot(gen_loss_plot(trainer_log)) - - return running_log, running_progress, running_loss - - -def load_args(config_path: str) -> Optional[Dict[str, Any]]: - r""" - Loads saved arguments. - """ - try: - with open(config_path, encoding="utf-8") as f: - return safe_load(f) - except Exception: - return None - - -def save_args(config_path: str, config_dict: Dict[str, Any]): - r""" - Saves arguments. - """ - with open(config_path, "w", encoding="utf-8") as f: - safe_dump(config_dict, f) - - -def list_config_paths(current_time: str) -> "gr.Dropdown": - r""" - Lists all the saved configuration files. - """ - config_files = [f"{current_time}.yaml"] - if os.path.isdir(DEFAULT_CONFIG_DIR): - for file_name in os.listdir(DEFAULT_CONFIG_DIR): - if file_name.endswith(".yaml") and file_name not in config_files: - config_files.append(file_name) - - return gr.Dropdown(choices=config_files) - - -def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown": - r""" - Lists all the directories that can resume from. - """ - output_dirs = [f"train_{current_time}"] - if model_name: - save_dir = get_save_dir(model_name, finetuning_type) - if save_dir and os.path.isdir(save_dir): - for folder in os.listdir(save_dir): - output_dir = os.path.join(save_dir, folder) - if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None: - output_dirs.append(folder) - - return gr.Dropdown(choices=output_dirs) - - -def create_ds_config() -> None: - r""" - Creates deepspeed config. 
- """ - os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) - ds_config = { - "train_batch_size": "auto", - "train_micro_batch_size_per_gpu": "auto", - "gradient_accumulation_steps": "auto", - "gradient_clipping": "auto", - "zero_allow_untested_optimizer": True, - "fp16": { - "enabled": "auto", - "loss_scale": 0, - "loss_scale_window": 1000, - "initial_scale_power": 16, - "hysteresis": 2, - "min_loss_scale": 1, - }, - "bf16": {"enabled": "auto"}, - } - offload_config = { - "device": "cpu", - "pin_memory": True, - } - ds_config["zero_optimization"] = { - "stage": 2, - "allgather_partitions": True, - "allgather_bucket_size": 5e8, - "overlap_comm": True, - "reduce_scatter": True, - "reduce_bucket_size": 5e8, - "contiguous_gradients": True, - "round_robin_gradients": True, - } - with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f: - json.dump(ds_config, f, indent=2) - - ds_config["zero_optimization"]["offload_optimizer"] = offload_config - with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f: - json.dump(ds_config, f, indent=2) - - ds_config["zero_optimization"] = { - "stage": 3, - "overlap_comm": True, - "contiguous_gradients": True, - "sub_group_size": 1e9, - "reduce_bucket_size": "auto", - "stage3_prefetch_bucket_size": "auto", - "stage3_param_persistence_threshold": "auto", - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_gather_16bit_weights_on_model_save": True, - } - with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f: - json.dump(ds_config, f, indent=2) - - ds_config["zero_optimization"]["offload_optimizer"] = offload_config - ds_config["zero_optimization"]["offload_param"] = offload_config - with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f: - json.dump(ds_config, f, indent=2)
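
For reference, the reasoning-mode rendering introduced by this patch can be exercised on its own. The sketch below is a minimal, self-contained rendition of the `_format_response` helper added to `src/llamafactory/webui/chatter.py`; it assumes DeepSeek-R1-style `<think>...</think>` markers and replaces the localized `ALERTS` lookups with hard-coded English strings, so it is an illustration rather than a verbatim copy of the patched code.

```python
# Illustrative sketch (not part of this patch): a standalone rendition of the
# reasoning-mode post-processing added to src/llamafactory/webui/chatter.py.
# Assumes DeepSeek-R1-style <think>...</think> markers; the localized ALERTS
# messages are replaced with hard-coded English strings.


def format_response(text: str) -> str:
    """Wrap the model's thought in a collapsible block styled by the new CSS classes."""
    if "<think>" not in text:
        return text  # ordinary responses pass through unchanged

    text = text.replace("<think>", "")
    parts = text.split("</think>", maxsplit=1)
    if len(parts) == 1:  # closing tag not generated yet: still thinking
        summary, thought, answer = "🌀 Thinking...", text, ""
    else:  # thought finished: show the summary and the final answer
        summary, (thought, answer) = "✅ Thought", parts

    return (
        f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
        f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
    )


if __name__ == "__main__":
    print(format_response("<think>2 + 2 equals 4.</think>The answer is 4."))
```

The emitted `<details>` markup is what the new `.thinking-summary` and `.thinking-container` rules added to `src/llamafactory/webui/css.py` style in the chatbot.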