Solve the compatibility problem of SeparatorStyle.CHATML type messages field #3278

Open — wants to merge 38 commits into base: main

Commits (38)
7b80225
Add Llama3 adaptor
Michaelvll Apr 18, 2024
9607517
Solve the compatibility problem of SeparatorStyle.CHATML type messes …
icowan Apr 24, 2024
0e8e22f
Merge pull request #1 from icowan/chatml-sep
icowan Apr 24, 2024
90e4c45
Support the openbuddy-llama3 template
Apr 25, 2024
7135cce
Merge pull request #2 from lm-sys/main
icowan May 3, 2024
cb00614
Merge pull request #3 from lm-sys/main
icowan May 6, 2024
79d957c
Merge remote-tracking branch 'origin/main'
icowan May 6, 2024
aa3b16e
Adjust the llama3 prompt
icowan May 6, 2024
a0b39eb
Adjust the llama3 prompt
icowan May 11, 2024
fd9ac3d
fix error and support Phi-3 models #3318
icowan May 14, 2024
b69e909
Merge the latest code
icowan May 15, 2024
b918043
Merge the latest code
icowan May 15, 2024
3c16819
Merge branch 'main' of https://github.com/icowan/FastChat
icowan May 15, 2024
4857a2b
Merge the latest code
icowan May 15, 2024
fae41e4
Merge pull request #6 from lm-sys/main
icowan May 27, 2024
cc7ff18
Merge branch 'main' into add-llama3-adaptor
icowan May 27, 2024
ec0eaf6
Merge remote-tracking branch 'origin/main'
icowan May 28, 2024
0590b4c
Merge the latest code
icowan May 28, 2024
8af0b97
Add TensorRT inference support
icowan May 29, 2024
2aded78
Merge pull request #7 from lm-sys/main
icowan Jun 7, 2024
c0a22b1
Merge pull request #8 from lm-sys/main
icowan Jun 21, 2024
a707950
Merged some PRs
icowan Jun 21, 2024
25e85e3
Merge pull request #9 from lm-sys/main
icowan Jul 7, 2024
0eb1cda
Add a worker info endpoint
icowan Jul 8, 2024
140067d
Add a worker info endpoint
icowan Jul 8, 2024
dcf070b
Support loading LoRA models in vLLM
icowan Jul 22, 2024
15c04a2
Attempt to add rate limiting
icowan Jul 24, 2024
dcd4e6e
Fix an issue with the vLLM adapter
icowan Jul 29, 2024
260eff8
Fix an issue with the vLLM adapter
icowan Jul 30, 2024
85ba0f3
Merge pull request #10 from lm-sys/main
icowan Jul 31, 2024
423cb07
Merge branch 'main' of https://github.com/icowan/FastChat
icowan Jul 31, 2024
0967c1a
Merge pull request #11 from lm-sys/main
icowan Aug 14, 2024
1e5fc17
Merge pull request #12 from lm-sys/main
icowan Aug 19, 2024
f531d50
Merge remote-tracking branch 'origin/main'
icowan Aug 19, 2024
24287bc
Fix an issue with the vLLM adapter
icowan Aug 19, 2024
526ed36
Merge pull request #13 from lm-sys/main
icowan Sep 25, 2024
d37b1ee
skip_special_tokens
icowan Sep 25, 2024
8959971
Temporary commit
icowan Sep 30, 2024
20 changes: 20 additions & 0 deletions Dockerfile
@@ -0,0 +1,20 @@
FROM python:3.10.14-alpine

LABEL maintainer="[email protected]"

RUN apk add gcc python3-dev musl-dev linux-headers

RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

RUN pip3 install --no-cache-dir aiohttp fastapi httpx \
markdown2[all] nh3 numpy prompt_toolkit>=3.0.0 \
pydantic psutil requests rich>=10.0.0 \
shortuuid tiktoken uvicorn

WORKDIR /app

COPY . /app/
RUN pip3 install -e .
RUN pip3 install pydantic

CMD ["python3", "-m", "fastchat.serve.controller", "--host", "0.0.0.0"]
130 changes: 114 additions & 16 deletions fastchat/conversation.py
@@ -39,6 +39,7 @@ class SeparatorStyle(IntEnum):
GEMMA = auto()
CLLM = auto()
DEFAULT = auto()
OPENBUDDY_LLAMA3 = auto()


IMAGE_PLACEHOLDER_STR = "$$<image>$$"
@@ -134,9 +135,9 @@ def get_prompt(self) -> str:
for i, (role, message) in enumerate(self.messages):
if message:
ret += (
role
+ ": "
+ message.replace("\r\n", "\n").replace("\n\n", "\n")
role
+ ": "
+ message.replace("\r\n", "\n").replace("\n\n", "\n")
)
ret += "\n\n"
else:
@@ -193,12 +194,16 @@ def get_prompt(self) -> str:
ret = "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
if type(message) is tuple:
message, images = message
message = IMAGE_PLACEHOLDER_STR * len(images) + message
ret += role + "\n" + message + self.sep + "\n"
if isinstance(message, tuple):
message, images = message if len(message) > 1 else (message[0], [])
images = images if images is not None else []
message = (IMAGE_PLACEHOLDER_STR * len(images) if images else "") + (
message if message is not None else "")
else:
message = message if message is not None else ""
ret += f"{role}\n{message}{self.sep}\n"
else:
ret += role + "\n"
ret += f"{role}\n"
return ret
elif self.sep_style == SeparatorStyle.CHATGLM3:
ret = ""
@@ -321,16 +326,24 @@ def get_prompt(self) -> str:
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.OPENBUDDY_LLAMA3:
ret = system_prompt + "\n"
for role, message in self.messages:
if message:
ret += f"<|role|>{role}<|says|>{message}<|end|>\n"
else:
ret += f"<|role|>{role}<|says|>\n"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")

def get_images(self):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
for image in msg[1]:
images.append(image.base64_str)
images.append(image)

return images

@@ -361,7 +374,7 @@ def to_gradio_chatbot(self):
from fastchat.serve.vision.image import ImageFormat

ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
msg, images = msg
@@ -415,14 +428,76 @@ def to_openai_vision_api_messages(self):
)
return ret

def to_openai_image_format(self, image_urls):
import base64

openai_images = []
for image_url in image_urls:
if image_url.startswith("http://") or image_url.startswith(
"https://"
): # input is a url
openai_images.append(image_url)
elif image_url.lower().endswith(
("png", "jpg", "jpeg", "webp", "gif")
): # input is a local image
img_b64_str = self.convert_image_to_base64(image_url)
filetype = image_url.split(".")[-1].lower()
openai_images.append(f"data:image/{filetype};base64,{img_b64_str}")
else:
try:
assert (
base64.b64encode(base64.b64decode(image_url))
== image_url.encode()
), "The image data is not a valid base64 encoded string"
openai_images.append(f"data:image/jpeg;base64,{image_url}")
except:
raise ValueError(
f"This file is not valid or not currently supported by the OpenAI API: {image_url}"
)
return openai_images
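
For reference (a sketch, not part of the diff; values are illustrative and this assumes the branch is installed), the helper accepts http(s) URLs, local image paths, and bare base64 strings:

from fastchat.conversation import get_conv_template

conv = get_conv_template("one_shot")  # the method is defined on Conversation, so any registered template works
print(conv.to_openai_image_format([
    "https://example.com/cat.png",  # http(s) URLs pass through unchanged
    "aGVsbG8=",                     # bare base64 is wrapped as a data:image/jpeg;base64 URI
]))
# A path such as "photo.jpg" would be read from disk and base64-encoded;
# it is omitted here so the sketch runs without a local file.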

def to_openai_vision_api_messages(self):
"""Convert the conversation to OpenAI vision api completion format"""
ret = [
{
"role": "system",
"content": [{"type": "text", "text": self.system_message}],
}
]
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
content_list = [{"type": "text", "text": msg[0]}]

image_urls = self.to_openai_image_format(msg[1])
for image_url in image_urls:
content_list.append(
{"type": "image_url", "image_url": {"url": image_url}}
)

ret.append({"role": "user", "content": content_list})
else:
ret.append(
{"role": "user", "content": [{"type": "text", "text": msg}]}
)
else:
if msg is not None:
ret.append(
{
"role": "assistant",
"content": [{"type": "text", "text": msg}],
}
)
return ret

def to_openai_api_messages(self):
"""Convert the conversation to OpenAI chat completion format."""
if self.system_message == "":
ret = []
else:
ret = [{"role": "system", "content": self.system_message}]

for i, (_, msg) in enumerate(self.messages[self.offset :]):
for i, (_, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
ret.append({"role": "user", "content": msg})
else:
@@ -667,7 +742,7 @@ def register_conv_template(template: Conversation, override: bool = False):
"""Register a new conversation template."""
if not override:
assert (
template.name not in conv_templates
template.name not in conv_templates
), f"{template.name} has been registered."

conv_templates[template.name] = template
@@ -694,7 +769,7 @@ def get_conv_template(name: str) -> Conversation:
Conversation(
name="one_shot",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
(
@@ -1015,6 +1090,27 @@ def get_conv_template(name: str) -> Conversation:
)
)

# Buddy default template
register_conv_template(
Conversation(
name="openbuddy-llama3",
system_message="""<|role|>system<|says|>You(assistant) are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human(user).
Always answer as helpfully and logically as possible, while being safe. Your answers should not include any harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
You cannot access the internet, but you have vast knowledge, cutoff: 2023-04.
You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), not related to GPT or OpenAI.<|end|>
<|role|>user<|says|>History input 1<|end|>
<|role|>assistant<|says|>History output 1<|end|>
<|role|>user<|says|>History input 2<|end|>
<|role|>assistant<|says|>History output 2<|end|>
<|role|>user<|says|>Current input<|end|>
<|role|>assistant<|says|>
""",
roles=("user", "assistant"),
sep_style=SeparatorStyle.OPENBUDDY_LLAMA3,
sep="\n",
)
)
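
As a rough illustration (not part of this diff, and assuming this branch is installed), the new template renders a turn like so:

from fastchat.conversation import get_conv_template

conv = get_conv_template("openbuddy-llama3")
conv.append_message(conv.roles[0], "How far away is the Moon?")
conv.append_message(conv.roles[1], None)
# Per the OPENBUDDY_LLAMA3 branch of get_prompt(): the system prompt comes first,
# each non-empty message becomes "<|role|>{role}<|says|>{message}<|end|>\n",
# and the empty assistant slot becomes "<|role|>assistant<|says|>\n".
print(conv.get_prompt())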

# Phoenix default template
register_conv_template(
Conversation(
@@ -1437,7 +1533,8 @@ def get_conv_template(name: str) -> Conversation:
sep_style=SeparatorStyle.RWKV,
sep="\n",
sep2="<|endoftext|>",
stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text
stop_str="\nUser",
# use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text
stop_token_ids=[
0,
1,
@@ -1878,7 +1975,8 @@ def get_conv_template(name: str) -> Conversation:
sep_style=SeparatorStyle.FALCON_CHAT,
sep="\n",
sep2="<|endoftext|>",
stop_str="\nUser:", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text
stop_str="\nUser:",
# use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text
)
)

14 changes: 14 additions & 0 deletions fastchat/model/model_adapter.py
@@ -1570,6 +1570,20 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("llama-2")

class Llama3Adapter(BaseModelAdapter):
"""The model adapter for Llama-3 (e.g., meta-llama/Meta-Llama-3-8B-Instruct)"""

def match(self, model_path: str):
return "meta-llama-3" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("llama-2")
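
As a quick check (a sketch, not part of the diff; FastChat would only use the adapter if it is also registered via register_model_adapter):

from fastchat.model.model_adapter import Llama3Adapter

adapter = Llama3Adapter()
print(adapter.match("meta-llama/Meta-Llama-3-8B-Instruct"))  # True: "meta-llama-3" appears in the lowercased path
print(adapter.match("lmsys/vicuna-7b-v1.5"))  # False: a more generic adapter would be tried instead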

class Llama3Adapter(BaseModelAdapter):
"""The model adapter for Llama-3 (e.g., meta-llama/Meta-Llama-3-8B-Instruct)"""
2 changes: 2 additions & 0 deletions fastchat/protocol/openai_api_protocol.py
@@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
seed: Optional[int] = None


class ChatMessage(BaseModel):
@@ -166,6 +167,7 @@ class CompletionRequest(BaseModel):
user: Optional[str] = None
use_beam_search: Optional[bool] = False
best_of: Optional[int] = None
seed: Optional[int] = None


class CompletionResponseChoice(BaseModel):
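
Note (a sketch, not part of the diff): how a client could pass the new seed field through FastChat's OpenAI-compatible server; the address and model name below are assumptions, and whether the seed is honored depends on the backend worker:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # default openai_api_server address (assumed)
    json={
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
        "seed": 42,  # now accepted by ChatCompletionRequest; CompletionRequest gains the same field
    },
    timeout=60,
)
print(resp.json())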
29 changes: 22 additions & 7 deletions fastchat/serve/controller.py
@@ -27,7 +27,6 @@
)
from fastchat.utils import build_logger


logger = build_logger("controller", "controller.log")


@@ -73,11 +72,11 @@ def __init__(self, dispatch_method: str):
self.heart_beat_thread.start()

def register_worker(
self,
worker_name: str,
check_heart_beat: bool,
worker_status: dict,
multimodal: bool,
self,
worker_name: str,
check_heart_beat: bool,
worker_status: dict,
multimodal: bool,
):
if worker_name not in self.worker_info:
logger.info(f"Register a new worker: {worker_name}")
@@ -123,7 +122,7 @@ def refresh_all_workers(self):

for w_name, w_info in old_info.items():
if not self.register_worker(
w_name, w_info.check_heart_beat, None, w_info.multimodal
w_name, w_info.check_heart_beat, None, w_info.multimodal
):
logger.info(f"Remove stale worker: {w_name}")

@@ -263,6 +262,17 @@ def worker_api_get_status(self):
"queue_length": queue_length,
}

def worker_get_info(self):
worker_info = self.worker_info
for w_name in worker_info:
worker_status = self.get_worker_status(w_name)
if worker_status is not None:
worker_info[w_name].model_names = worker_status["model_names"]
worker_info[w_name].speed = worker_status["speed"]
worker_info[w_name].queue_length = worker_status["queue_length"]

return worker_info

def worker_api_generate_stream(self, params):
worker_addr = self.get_worker_address(params["model"])
if not worker_addr:
@@ -350,6 +360,11 @@ async def worker_api_get_status(request: Request):
return "success"


@app.get("/worker_get_info")
async def worker_api_get_info(request: Request):
return controller.worker_get_info()
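
For illustration (a sketch, not part of the diff; the controller address below is FastChat's default and is an assumption), a client could query the new route like this:

import requests

info = requests.get("http://localhost:21001/worker_get_info", timeout=10).json()
for worker_name, worker in info.items():
    # Each entry mirrors WorkerInfo, refreshed with the latest per-worker status.
    print(worker_name, worker.get("model_names"), worker.get("queue_length"))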


def create_controller():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")