Skip to content

Commit

Permalink
fix config, #1191
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Oct 15, 2023
1 parent 0d63584 commit a6a04be
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 57 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fire
jieba
rouge-chinese
nltk
gradio>=3.36.0
gradio==3.38.0
uvicorn
pydantic==1.10.11
fastapi==0.95.1
Expand Down
2 changes: 1 addition & 1 deletion src/llmtuner/tuner/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_train_args(

# postprocess model_args
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
model_args.model_max_length = data_args.cutoff_len

Expand Down
23 changes: 11 additions & 12 deletions src/llmtuner/webui/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import json
import gradio as gr
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional
from transformers.utils import (
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
Expand All @@ -27,7 +27,6 @@
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
]
CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]]


def get_save_dir(*args) -> os.PathLike:
Expand All @@ -38,28 +37,28 @@ def get_config_path() -> os.PathLike:
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


def load_config() -> Dict[str, Any]:
    """Load the persisted web-UI user config from the cache directory.

    Returns:
        The parsed JSON config dict, or a default config (no language, no
        last model, empty path map, no cache dir) when the file is missing
        or cannot be read/parsed.
    """
    try:
        with open(get_config_path(), "r", encoding="utf-8") as f:
            return json.load(f)
    # Deliberate best-effort fallback, but narrowed from a bare `except:`
    # so SystemExit/KeyboardInterrupt are not swallowed.
    except Exception:
        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}


def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
    """Persist the user's web-UI settings to the JSON config file.

    Merges the given values into the existing on-disk config rather than
    overwriting it wholesale.

    Args:
        lang: UI language code; if falsy, the previously saved value is kept.
        model_name: Optional model identifier to record as the last-used model.
        model_path: Optional path to associate with ``model_name``.
    """
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
    user_config = load_config()
    user_config["lang"] = lang or user_config["lang"]
    if model_name:
        user_config["last_model"] = model_name
        user_config["path_dict"][model_name] = model_path
    with open(get_config_path(), "w", encoding="utf-8") as f:
        json.dump(user_config, f, indent=2, ensure_ascii=False)


def get_model_path(model_name: str) -> str:
    """Return the local path configured for ``model_name``.

    Prefers the path the user saved via the web UI; falls back to the entry
    in ``SUPPORTED_MODELS``, or an empty string when the model is unknown.
    """
    user_config = load_config()
    return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")


def get_module(model_name: str) -> str:
Expand Down
5 changes: 1 addition & 4 deletions src/llmtuner/webui/components/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
unload_btn = gr.Button()

info_box = gr.Textbox(show_label=False, interactive=False)

elem_dict.update(dict(
info_box=info_box, load_btn=load_btn, unload_btn=unload_btn
))
elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))

chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
Expand Down
12 changes: 5 additions & 7 deletions src/llmtuner/webui/components/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
from llmtuner.webui.utils import can_quantize

if TYPE_CHECKING:
Expand All @@ -12,7 +12,6 @@

def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
config = gr.State(value=load_config())

with gr.Row():
lang = gr.Dropdown(choices=["en", "zh"], scale=1)
Expand All @@ -39,25 +38,24 @@ def create_top() -> Dict[str, "Component"]:
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
).then(
get_model_path, [config, model_name], [model_path], queue=False
get_model_path, [model_name], [model_path], queue=False
).then(
get_template, [model_name], [template], queue=False
) # do not save config since the below line will save

model_path.change(save_config, inputs=[config, lang, model_name, model_path])
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)

finetuning_type.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
).then(
can_quantize, [finetuning_type], [quantization_bit]
can_quantize, [finetuning_type], [quantization_bit], queue=False
)

refresh_btn.click(
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
)

return dict(
config=config,
lang=lang,
model_name=model_name,
model_path=model_path,
Expand Down
13 changes: 7 additions & 6 deletions src/llmtuner/webui/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, Generator, Optional

from llmtuner.webui.chatter import WebChatModel
from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
Expand All @@ -21,8 +21,9 @@ def __init__(self, pure_chat: Optional[bool] = False) -> None:
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}

def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
lang = config.get("lang", None) or "en"
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
user_config = load_config()
lang = user_config.get("lang", None) or "en"

resume_dict = {
"top.lang": {"value": lang},
Expand All @@ -33,9 +34,9 @@ def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, An
resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}

if config.get("last_model", None):
resume_dict["top.model_name"] = {"value": config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
if user_config.get("last_model", None):
resume_dict["top.model_name"] = {"value": user_config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}

yield self._form_dict(resume_dict)

Expand Down
31 changes: 9 additions & 22 deletions src/llmtuner/webui/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
create_export_tab,
create_chat_box
)
from llmtuner.webui.common import load_config, save_config
from llmtuner.webui.common import save_config
from llmtuner.webui.css import CSS
from llmtuner.webui.engine import Engine


require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")


def create_ui() -> gr.Blocks:
Expand All @@ -23,9 +23,6 @@ def create_ui() -> gr.Blocks:
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
config = engine.manager.get_elem("top.config")
model_name = engine.manager.get_elem("top.model_name")
model_path = engine.manager.get_elem("top.model_path")

with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine)
Expand All @@ -39,13 +36,9 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine)

demo.load(engine.resume, [config], engine.manager.list_elems())

lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang, model_name, model_path]
)
demo.load(engine.resume, outputs=engine.manager.list_elems())
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
lang.input(save_config, inputs=[lang], queue=False)

return demo

def create_web_demo() -> gr.Blocks:
    """Build the pure-chat Gradio demo (language selector + chat box only).

    Unlike the full tuner UI, no gr.State is used for the config: the user
    config is read from disk on demand (engine.resume / save_config), which
    keeps it consistent across browser sessions.
    """
    engine = Engine(pure_chat=True)

    with gr.Blocks(title="Web Demo", css=CSS) as demo:
        lang = gr.Dropdown(choices=["en", "zh"])
        engine.manager.all_elems["top"] = dict(lang=lang)

        chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
        engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)

        # Restore saved settings on page load.
        demo.load(engine.resume, outputs=engine.manager.list_elems())
        # `change` fires on any value update (incl. programmatic) to relabel
        # the UI; `input` fires only on user action, so only that persists.
        lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
        lang.input(save_config, inputs=[lang], queue=False)

    return demo

Expand Down
1 change: 0 additions & 1 deletion src/llmtuner/webui/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def get_elem(self, name: str) -> "Component":

def get_base_elems(self):
return {
self.all_elems["top"]["config"],
self.all_elems["top"]["lang"],
self.all_elems["top"]["model_name"],
self.all_elems["top"]["model_path"],
Expand Down
8 changes: 5 additions & 3 deletions src/llmtuner/webui/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_module, get_save_dir
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar

Expand Down Expand Up @@ -74,6 +74,7 @@ def _finalize(self, lang: str, finish_info: str) -> str:

def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
get = lambda name: data[self.manager.get_elem(name)]
user_config = load_config()

if get("top.checkpoints"):
checkpoint_dir = ",".join([
Expand All @@ -89,7 +90,7 @@ def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str,
model_name_or_path=get("top.model_path"),
do_train=True,
overwrite_cache=False,
cache_dir=get("top.config").get("cache_dir", None),
cache_dir=user_config.get("cache_dir", None),
checkpoint_dir=checkpoint_dir,
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
Expand Down Expand Up @@ -142,6 +143,7 @@ def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str,

def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
get = lambda name: data[self.manager.get_elem(name)]
user_config = load_config()

if get("top.checkpoints"):
checkpoint_dir = ",".join([
Expand All @@ -160,7 +162,7 @@ def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, L
do_eval=True,
overwrite_cache=False,
predict_with_generate=True,
cache_dir=get("top.config").get("cache_dir", None),
cache_dir=user_config.get("cache_dir", None),
checkpoint_dir=checkpoint_dir,
finetuning_type=get("top.finetuning_type"),
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
Expand Down

0 comments on commit a6a04be

Please sign in to comment.