feat: openai adapter support config file & model mapping (#475)
* support config file & model mapping

* fix merge
ZingLix authored Apr 26, 2024
1 parent 0797adf commit 49dbfe6
Showing 5 changed files with 150 additions and 39 deletions.
29 changes: 29 additions & 0 deletions docs/cli.md
@@ -428,4 +428,33 @@ export OPENAI_BASE_URL='http://127.0.0.1:8001/v1' # address of the mocked OpenAI API
 * `--port / -p`: Port to bind [default: 8001]
 * `--detach / -d`: Run the server in the background
 * `--log-file`: Log file path; by default logs are not written to a file
+* `--ignore-system / --no-ignore-system`: Whether to ignore the `system` field in messages; ignored by default
+* `--config-file / -c`: Config file path; see the Config File section below for the format
 * `--help`: Show help information
+
+#### Config File
+
+The config file is in YAML format and can carry more complex settings such as model mappings. The format and the default values are as follows:
+
+> ⚠️ Command-line arguments override the ones in the config file
+```yaml
+openai_adapter:
+  # Host to bind
+  host: 0.0.0.0
+  # Port to listen on
+  port: 8001
+  # Whether to run in the background
+  detach: false
+  # Log file path; if unset, logs are not written to a file
+  log_file: null
+  # Model mapping, from OpenAI model names to Qianfan models
+  # If every mapping fails to match, the original name is used
+  # Both the keys and the values support regular expressions
+  model_mapping:
+    gpt-3.5.*: ERNIE-3.5-8K
+    gpt-4.*: ERNIE-4.0-8K
+    text-embedding.*: Embedding-V1
+    # Regex substitution is supported; the line below is only an illustration
+    # gpt-3.5(.*): ERNIE-3.5\1
+```
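As an illustration of how these mappings resolve, here is a minimal sketch mirroring the adapter's `re.subn`-based lookup (the model names are the defaults above; the `convert` helper is hypothetical, not part of the SDK):

```python
import re

# Default mapping from the config above; patterns are tried in order.
model_mapping = {
    r"gpt-3.5.*": "ERNIE-3.5-8K",
    r"gpt-4.*": "ERNIE-4.0-8K",
    r"text-embedding.*": "Embedding-V1",
}

def convert(model: str) -> str:
    for pattern, target in model_mapping.items():
        new_model, n = re.subn(pattern, target, model)
        if n != 0:  # the pattern matched at least once
            return new_model
    return model  # every mapping failed: keep the original name

print(convert("gpt-3.5-turbo"))    # ERNIE-3.5-8K
print(convert("my-custom-model"))  # my-custom-model (unchanged)

# Regex substitution with a backreference, as in the commented-out example:
print(re.subn(r"gpt-3.5(.*)", r"ERNIE-3.5\1", "gpt-3.5-turbo")[0])  # ERNIE-3.5-turbo
```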
75 changes: 69 additions & 6 deletions python/qianfan/common/client/main.py
@@ -28,7 +28,12 @@
 from qianfan.common.client.plugin import plugin_entry
 from qianfan.common.client.trainer import trainer_app
 from qianfan.common.client.txt2img import txt2img_entry
-from qianfan.common.client.utils import credential_required, print_error_msg
+from qianfan.common.client.utils import (
+    credential_required,
+    print_error_msg,
+    print_info_msg,
+)
+from qianfan.config import encoding
 from qianfan.utils.utils import check_dependency
 
 app = typer.Typer(
@@ -52,23 +57,81 @@
 @app.command(name="openai")
 @credential_required
 def openai(
-    host: str = typer.Option("0.0.0.0", "--host", "-h", help="Host to bind."),
-    port: int = typer.Option(8001, "--port", "-p", help="Port of the server."),
-    detach: bool = typer.Option(
-        False,
+    host: Optional[str] = typer.Option(
+        None,
+        "--host",
+        "-h",
+        help="Host to bind. [dim]\[default: 0.0.0.0][/]",
+        show_default=False,
+    ),
+    port: Optional[int] = typer.Option(
+        None,
+        "--port",
+        "-p",
+        help="Port of the server. [dim]\[default: 8001][/]",
+        show_default=False,
+    ),
+    detach: Optional[bool] = typer.Option(
+        None,
         "--detach",
         "-d",
         help="Run the server in background.",
     ),
+    ignore_system: Optional[bool] = typer.Option(
+        None,
+        help="Ignore system messages in input. [dim]\[default: True][/]",
+        show_default=False,
+    ),
     log_file: Optional[str] = typer.Option(None, help="Log file path."),
+    config_file: Optional[str] = typer.Option(
+        None, help="Config file path.", show_default=False
+    ),
 ) -> None:
     """
     Create an openai wrapper server.
     """
     check_dependency("openai", ["fastapi", "uvicorn"])
     from qianfan.common.client.openai_adapter import entry as openai_entry
 
-    openai_entry(host=host, port=port, detach=detach, log_file=log_file)
+    default_config = {
+        "host": "0.0.0.0",
+        "port": 8001,
+        "detach": False,
+        "ignore_system": True,
+        "log_file": None,
+        "model_mapping": None,
+    }
+    adapter_config = {}
+    if config_file is not None:
+        import yaml
+
+        with open(config_file, "r", encoding=encoding()) as f:
+            config = yaml.safe_load(f)
+        adapter_config = config.get("openai_adapter")
+        if adapter_config is None:
+            raise ValueError("Config file should contain a key named `openai_adapter`.")
+    if ignore_system is None and adapter_config.get("ignore_system") is None:
+        print_info_msg(
+            "`--no-ignore-system` is not set. System messages will be ignored by"
+            " default since most system messages written for OpenAI are not"
+            " suitable for ERNIE models."
+        )
+        print()
+
+    merged_config = {**default_config, **adapter_config}
+
+    openai_entry(
+        host=host if host is not None else merged_config["host"],
+        port=port if port is not None else merged_config["port"],
+        detach=detach if detach is not None else merged_config["detach"],
+        log_file=log_file if log_file is not None else merged_config["log_file"],
+        ignore_system=(
+            ignore_system
+            if ignore_system is not None
+            else merged_config["ignore_system"]
+        ),
+        model_mapping=merged_config["model_mapping"],
+    )
 
 
 @app.command(name="proxy")
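The precedence implemented above is: explicit CLI flag first, then the config file's `openai_adapter` section, then the built-in defaults. A small sketch of that merge pattern (the values are hypothetical):

```python
default_config = {"host": "0.0.0.0", "port": 8001}
adapter_config = {"port": 9000}  # hypothetical value read from the YAML file
cli_port = None                  # the user did not pass --port

# Dict unpacking: keys listed later win, so the config file overrides defaults.
merged_config = {**default_config, **adapter_config}

# A CLI flag, when given (i.e. not None), overrides the merged value.
port = cli_port if cli_port is not None else merged_config["port"]
print(port)  # 9000
```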
18 changes: 15 additions & 3 deletions python/qianfan/common/client/openai_adapter.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import json
-from typing import Any, AsyncIterator, Optional
+from typing import Any, AsyncIterator, Dict, Optional
 
 from fastapi import FastAPI, Request, Response
 from fastapi.responses import JSONResponse
@@ -23,7 +23,14 @@
 from qianfan.utils.utils import get_ip_address
 
 
-def entry(host: str, port: int, detach: bool, log_file: Optional[str]) -> None:
+def entry(
+    host: str,
+    port: int,
+    detach: bool,
+    log_file: Optional[str],
+    ignore_system: bool,
+    model_mapping: Dict[str, str],
+) -> None:
     import rich
     import uvicorn
     import uvicorn.config
@@ -50,7 +57,8 @@ def entry(host: str, port: int, detach: bool, log_file: Optional[str]) -> None:
     display_host = host
     if display_host == "0.0.0.0":
         display_host = get_ip_address()
-    messages.append(f"- http://{display_host}:{port}")
+    if display_host != "127.0.0.1":
+        messages.append(f"- http://{display_host}:{port}")
 
     messages.append("\nRemember to set the environment variables:")
     messages.append(f"""```shell
@@ -67,6 +75,10 @@ def start_server() -> None:
 
     adapter = OpenAIApdater()
 
+    if model_mapping is not None:
+        adapter._model_mapping = model_mapping
+    adapter._ignore_system = ignore_system
+
     async def stream(resp: AsyncIterator[Any]) -> AsyncIterator[str]:
         """
         Convert an async iterator to a stream.
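Putting the pieces together, the CLI ultimately calls `entry` with the resolved values; with no flags and no config file, that call would look roughly like this (a sketch using the documented defaults; passing `model_mapping=None` keeps the adapter's built-in pattern table):

```python
from qianfan.common.client.openai_adapter import entry

# Documented defaults from docs/cli.md; None for model_mapping keeps
# the OpenAIApdater's built-in mapping.
entry(
    host="0.0.0.0",
    port=8001,
    detach=False,
    log_file=None,
    ignore_system=True,
    model_mapping=None,
)
```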
65 changes: 36 additions & 29 deletions python/qianfan/extensions/openai/adapter.py
@@ -16,6 +16,7 @@
 
 import asyncio
 import json
+import re
 from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, Union
 
 import qianfan
@@ -28,19 +29,6 @@
 QianfanResponse = Dict[str, Any]
 
 
-def _convert_model(model: str) -> str:
-    """
-    Convert OpenAI model name to Qianfan model name.
-    """
-    if model.lower().startswith("gpt-3.5"):
-        return "ERNIE-3.5-8K"
-    elif model.lower().startswith("gpt-4"):
-        return "ERNIE-4.0-8K"
-    elif model.lower().startswith("text-embedding"):
-        return "Embedding-V1"
-    return model
-
-
 def merge_async_iters(*aiters: AsyncIterator[_T]) -> AsyncIterator[_T]:
     """
     Merge multiple async iterators into one.
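The body of `merge_async_iters` is collapsed in this view. A common way to implement such a merge is to pump every iterator into a shared `asyncio.Queue` (a sketch under that assumption, not necessarily the commit's actual code):

```python
import asyncio
from typing import AsyncIterator, TypeVar

_T = TypeVar("_T")

async def merge_sketch(*aiters: AsyncIterator[_T]) -> AsyncIterator[_T]:
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()  # marks the exhaustion of one iterator

    async def feed(ait: AsyncIterator[_T]) -> None:
        async for item in ait:
            await queue.put(item)
        await queue.put(sentinel)

    tasks = [asyncio.create_task(feed(ait)) for ait in aiters]
    remaining = len(tasks)
    while remaining:  # yield items as they arrive, from whichever iterator
        item = await queue.get()
        if item is sentinel:
            remaining -= 1
        else:
            yield item
```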
@@ -100,14 +88,35 @@ class OpenAIApdater(object):
     This value is used to split OpenAI requests into multiple Qianfan requests.
     """
 
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        ignore_system: bool = True,
+        model_mapping: Dict[str, str] = {
+            r"gpt-3.5.*": "ERNIE-3.5-8K",
+            r"gpt-4.*": "ERNIE-4.0-8K",
+            r"text-embedding.*": "Embedding-V1",
+        },
+    ) -> None:
         self._chat_client = qianfan.ChatCompletion()
         self._comp_client = qianfan.Completion()
         self._embed_client = qianfan.Embedding()
 
-    @classmethod
+        self._ignore_system = ignore_system
+        self._model_mapping = model_mapping
+
+    def _convert_model(self, model: str) -> str:
+        """
+        Convert OpenAI model name to Qianfan model name.
+        """
+        for pattern, qianfan_model in self._model_mapping.items():
+            new_model, n = re.subn(pattern, qianfan_model, model)
+            if n != 0:
+                return new_model
+
+        return model
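Because `_convert_model` returns on the first pattern that matches, dict insertion order matters: a broad pattern can shadow a more specific one. A small illustration (the mappings and the `ERNIE-4.0-Turbo` target are hypothetical):

```python
import re

def first_match(mapping: dict, model: str) -> str:
    # Same loop shape as _convert_model above.
    for pattern, target in mapping.items():
        new_model, n = re.subn(pattern, target, model)
        if n != 0:
            return new_model
    return model

# The broad pattern comes first and shadows the specific one.
shadowed = {r"gpt-4.*": "ERNIE-4.0-8K", r"gpt-4-32k.*": "ERNIE-4.0-Turbo"}
print(first_match(shadowed, "gpt-4-32k"))  # ERNIE-4.0-8K

# Listing the specific pattern first gives the intended result.
ordered = {r"gpt-4-32k.*": "ERNIE-4.0-Turbo", r"gpt-4.*": "ERNIE-4.0-8K"}
print(first_match(ordered, "gpt-4-32k"))   # ERNIE-4.0-Turbo
```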

     def openai_base_request_to_qianfan(
-        cls, openai_request: OpenAIRequest
+        self, openai_request: OpenAIRequest
     ) -> QianfanRequest:
         """
         Convert general arguments in OpenAI request to Qianfan request.
Expand All @@ -129,7 +138,7 @@ def add_if_exist(openai_key: str, qianfan_key: Optional[str] = None) -> None:
add_if_exist("tool_choice")

model = openai_request["model"]
qianfan_request["model"] = _convert_model(model)
qianfan_request["model"] = self._convert_model(model)

if "presence_penalty" in openai_request:
penalty = openai_request["presence_penalty"]
@@ -166,17 +175,17 @@ def add_if_exist(openai_key: str, qianfan_key: Optional[str] = None) -> None:
             qianfan_request["response_format"] = response_format
         return qianfan_request
 
-    @classmethod
     def openai_chat_request_to_qianfan(
-        cls, openai_request: OpenAIRequest
+        self, openai_request: OpenAIRequest
     ) -> QianfanRequest:
         """
         Convert chat request in OpenAI to Qianfan request.
         """
-        qianfan_request = cls.openai_base_request_to_qianfan(openai_request)
+        qianfan_request = self.openai_base_request_to_qianfan(openai_request)
         messages = openai_request["messages"]
         if messages[0]["role"] == "system":
-            qianfan_request["system"] = messages[0]["content"]
+            if not self._ignore_system:
+                qianfan_request["system"] = messages[0]["content"]
             messages = messages[1:]
 
         for item in messages:
@@ -194,31 +203,29 @@ def openai_chat_request_to_qianfan(
         qianfan_request["messages"] = messages
         return qianfan_request

-    @classmethod
     def openai_completion_request_to_qianfan(
-        cls, openai_request: OpenAIRequest
+        self, openai_request: OpenAIRequest
     ) -> QianfanRequest:
         """
         Convert completion request in OpenAI to Qianfan request.
         """
-        qianfan_request = cls.openai_base_request_to_qianfan(openai_request)
+        qianfan_request = self.openai_base_request_to_qianfan(openai_request)
         prompt = openai_request["prompt"]
         if isinstance(prompt, list):
             prompt = "".join(prompt)
         qianfan_request["prompt"] = prompt
         return qianfan_request
 
-    @classmethod
     def convert_openai_embedding_request(
-        cls, openai_request: OpenAIRequest
+        self, openai_request: OpenAIRequest
     ) -> List[QianfanRequest]:
         """
         Converts embedding request in OpenAI to multiple Qianfan requests.
         Since Qianfan has limits on the count of texts in one request, we need to
         split the OpenAI request to multiple Qianfan requests.
         """
-        qianfan_request = cls.openai_base_request_to_qianfan(openai_request)
+        qianfan_request = self.openai_base_request_to_qianfan(openai_request)
         input = openai_request["input"]
         if isinstance(input, str):
             input = [input]
@@ -227,11 +234,11 @@ def convert_openai_embedding_request(
         while i < len(input):
             request_list.append(
                 {
-                    "texts": input[i : min(i + cls.EmbeddingBatchSize, len(input))],
+                    "texts": input[i : min(i + self.EmbeddingBatchSize, len(input))],
                     **qianfan_request,
                 },
             )
-            i += cls.EmbeddingBatchSize
+            i += self.EmbeddingBatchSize
 
         return request_list

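The loop above slices the input into chunks of `EmbeddingBatchSize` texts, one Qianfan request per chunk. For example, with a hypothetical batch size of 16, 40 input texts produce three requests (a sketch of just the slicing):

```python
BATCH_SIZE = 16  # hypothetical stand-in for OpenAIApdater.EmbeddingBatchSize
texts = [f"text-{i}" for i in range(40)]

batches = []
i = 0
while i < len(texts):
    # min() clamps the final, possibly short, chunk.
    batches.append(texts[i : min(i + BATCH_SIZE, len(texts))])
    i += BATCH_SIZE

print([len(batch) for batch in batches])  # [16, 16, 8]
```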
2 changes: 1 addition & 1 deletion python/qianfan/resources/requestor/openapi_requestor.py
@@ -418,7 +418,7 @@ async def async_llm(
         """
         llm related api request
         """
-        log_info(f"async requesting llm api endpoint: {endpoint}")
+        log_debug(f"async requesting llm api endpoint: {endpoint}")
 
         @self._async_retry_if_token_expired
         async def _helper() -> Union[QfResponse, AsyncIterator[QfResponse]]:
