
Commit

Merge pull request #469 from TylunasLi/develop
Load Llama3 and Qwen2 HF models directly; use ChatTemplate in apiserver, webui, and benchmark
ztxz16 authored Jun 20, 2024
2 parents 47739b7 + 6b05906 commit b5a3902
Showing 9 changed files with 115 additions and 35 deletions.
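Across these files the common change is that prompts are no longer assembled with MakeInput/MakeHistory; each front end keeps a fastllm::ChatMessages list and renders it with ApplyChatTemplate, and HF-format model directories can be loaded directly. A minimal sketch of the new call pattern follows; the model directory, system prompt, and single-turn flow are illustrative assumptions, not part of this commit.

#include "model.h"
#include <string>

int main() {
    // Load an HF-format model directory directly (path and dtype here are hypothetical).
    auto model = fastllm::CreateLLMModelFromHF("Qwen2-7B-Instruct/", fastllm::DataType::FLOAT16, -1);
    fastllm::ChatMessages messages = {{"system", "You are a helpful assistant."},
                                      {"user", "Hello!"}};
    std::string prompt = model->ApplyChatTemplate(messages);   // render the model's chat template
    fastllm::GenerationConfig config;
    std::string reply = model->Response(prompt, [](int index, const char *content) {}, config);
    messages.push_back({"assistant", reply});                  // keep the multi-turn history
    return 0;
}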
69 changes: 57 additions & 12 deletions example/Win32Demo/Win32Demo.cpp
@@ -10,6 +10,16 @@
#include "model.h"
#include <shellapi.h>

std::map <std::string, fastllm::DataType> dataTypeDict = {
{"float32", fastllm::DataType::FLOAT32},
{"half", fastllm::DataType::FLOAT16},
{"float16", fastllm::DataType::FLOAT16},
{"int8", fastllm::DataType::INT8},
{"int4", fastllm::DataType::INT4_NOZERO},
{"int4z", fastllm::DataType::INT4},
{"int4g", fastllm::DataType::INT4_GROUP}
};

enum RUN_TYPE {
RUN_TYPE_CONSOLE = 0,
RUN_TYPE_WEBUI = 1,
@@ -19,16 +29,20 @@ static int modeltype = 0;
static RUN_TYPE runType = RUN_TYPE_CONSOLE;
static std::unique_ptr<fastllm::basellm> model;
static fastllm::GenerationConfig* generationConfig;
static int sRound = 0;
static std::string modelType;
static std::string history;
static fastllm::ChatMessages* messages;
static std::string currentContent = "";


struct RunConfig {
std::string path = "chatglm-6b-int4.bin"; // 模型文件路径
std::string systemPrompt = "";
std::set <std::string> eosToken;
int threads = 4; // number of threads to use
bool lowMemMode = false; // whether to use low-memory mode
fastllm::DataType dtype = fastllm::DataType::FLOAT16;
fastllm::DataType kvtype = fastllm::DataType::FLOAT32;
int groupCnt = -1;
bool webuiType = false; // false: run in the console, true: run the webui
};

@@ -38,6 +52,10 @@ void Usage() {
std::cout << "<-p|--path> <args>: 模型文件的路径" << std::endl;
std::cout << "<-t|--threads> <args>: 使用的线程数量" << std::endl;
std::cout << "<-l|--low>: 使用低内存模式" << std::endl;
std::cout << "<--system> <args>: 设置系统提示词(system prompt)" << std::endl;
std::cout << "<--eos_token> <args>:: 设置eos token" << std::endl;
std::cout << "<--dtype> <args>: 设置权重类型(读取hf文件时生效)" << std::endl;
std::cout << "<--kvtype> <args>: 设置推理使用的数据类型(float32/float16)" << std::endl;
std::cout << "<--top_p> <args>: 采样参数top_p" << std::endl;
std::cout << "<--top_k> <args>: 采样参数top_k" << std::endl;
std::cout << "<--temperature> <args>: 采样参数温度,越高结果越不固定" << std::endl;
@@ -70,6 +88,24 @@ void ParseArgs(int argc, char **argv, RunConfig &config, fastllm::GenerationConf
generationConfig.temperature = atof(sargv[++i].c_str());
} else if (sargv[i] == "--repeat_penalty") {
generationConfig.repeat_penalty = atof(sargv[++i].c_str());
} else if (sargv[i] == "--system") {
config.systemPrompt = sargv[++i];
} else if (sargv[i] == "--eos_token") {
config.eosToken.insert(sargv[++i]);
} else if (sargv[i] == "--dtype") {
std::string dtypeStr = sargv[++i];
if (dtypeStr.size() > 5 && dtypeStr.substr(0, 5) == "int4g") {
config.groupCnt = atoi(dtypeStr.substr(5).c_str());
dtypeStr = dtypeStr.substr(0, 5);
}
fastllm::AssertInFastLLM(dataTypeDict.find(dtypeStr) != dataTypeDict.end(),
"Unsupport data type: " + dtypeStr);
config.dtype = dataTypeDict[dtypeStr];
} else if (sargv[i] == "--kvtype") {
std::string atypeStr = sargv[++i];
fastllm::AssertInFastLLM(dataTypeDict.find(atypeStr) != dataTypeDict.end(),
"Unsupport act type: " + atypeStr);
config.kvtype = dataTypeDict[atypeStr];
} else if (sargv[i] == "-w" || sargv[i] == "--webui") {
config.webuiType = true;
} else {
@@ -83,12 +119,22 @@ int initLLMConf(RunConfig config) {
fastllm::PrintInstructionInfo();
fastllm::SetThreads(config.threads);
fastllm::SetLowMemMode(config.lowMemMode);
std::ifstream f(config.path.c_str());
if (!f.good()) {
if (!fastllm::FileExists(config.path)) {
printf("模型文件 %s 不存在!\n", config.path.c_str());
exit(0);
}
model = fastllm::CreateLLMModelFromFile(config.path);
bool isHFDir = fastllm::FileExists(config.path + "/config.json") || fastllm::FileExists(config.path + "config.json");
model = isHFDir ? fastllm::CreateLLMModelFromHF(config.path, config.dtype, config.groupCnt) : fastllm::CreateLLMModelFromFile(config.path);
if (config.kvtype != fastllm::DataType::FLOAT32) {
model->SetDataType(config.kvtype);
}
model->SetSaveHistoryChat(true);
for (auto &it : config.eosToken) {
generationConfig->stop_token_ids.insert(model->weight.tokenizer.GetTokenId(it));
}
std::string systemConfig = config.systemPrompt;
messages = new fastllm::ChatMessages({{"system", systemConfig}});

modelType = model->model_type;
runType = config.webuiType ? RUN_TYPE_WEBUI : RUN_TYPE_CONSOLE;
return 0;
@@ -101,7 +147,8 @@ int chatllm(const char* prompt, int type) {
if (runType == RUN_TYPE_CONSOLE) {
input = Gb2utf(input);
}
std::string strInput = model->MakeInput(history, sRound, input);
messages->push_back(std::make_pair("user", input));
std::string strInput = model->ApplyChatTemplate(*messages);
ret = model->Response(strInput, [](int index, const char* content) {
if (runType == RUN_TYPE_WEBUI) {
if (index > -1) {
Expand All @@ -115,7 +162,7 @@ int chatllm(const char* prompt, int type) {
printf("%s: ", modelType.c_str());
// printf("%s", result.c_str());
}
if (*content > 0 && *content < 127) {
if (*content > 0 && *content < 127 || (strlen(content) % 3 == 0 && (*content > -32 && *content < -16))) {
std::string result = utf2Gb(currentContent.c_str());
currentContent = "";
printf("%s", result.c_str());
Expand All @@ -134,8 +181,7 @@ int chatllm(const char* prompt, int type) {
}

}, *generationConfig);
history = model->MakeHistory(history, sRound, input, ret);
sRound++;
messages->push_back(std::make_pair("assistant", ret));
return ret.length();
}

@@ -145,9 +191,8 @@ void runConslusion() {
printf("用户: ");
std::string input;
std::getline(std::cin, input);
if (input == "reset") {
history = "";
sRound = 0;
if (input == "reset" || input.empty()) {
messages->erase(std::next(messages->begin()), messages->end());
continue;
}
if (input == "stop") {
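The --dtype option added above (and mirrored in webui) accepts a group-size suffix such as int4g128 for grouped int4 quantization. Below is a self-contained sketch of just that parsing step, with the fastllm enum replaced by plain strings so it compiles on its own and a made-up command-line value:

#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

int main() {
    std::map<std::string, std::string> dataTypeDict = {
        {"float16", "FLOAT16"}, {"int8", "INT8"},
        {"int4", "INT4_NOZERO"}, {"int4g", "INT4_GROUP"}};
    std::string dtypeStr = "int4g128";                 // hypothetical --dtype value
    int groupCnt = -1;
    if (dtypeStr.size() > 5 && dtypeStr.substr(0, 5) == "int4g") {
        groupCnt = atoi(dtypeStr.substr(5).c_str());   // "128" -> 128
        dtypeStr = dtypeStr.substr(0, 5);              // keep the "int4g" key
    }
    std::cout << dtypeStr << " -> " << dataTypeDict[dtypeStr]
              << ", groupCnt = " << groupCnt << "\n";  // int4g -> INT4_GROUP, groupCnt = 128
}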
6 changes: 4 additions & 2 deletions example/apiserver/apiserver.cpp
@@ -310,15 +310,17 @@ struct WorkQueue {
}
}
if (node->error != "") {
printf("error body = %s, prompt = %s, error = %s\n", node->request.body.c_str(), node->config["prompt"].string_value().c_str(), node->error.c_str());
printf("error body = %s, prompt = %s, error = %s\n", node->request.body.c_str(), node->config["prompt"].string_value().c_str(), node->error.c_str());
message += node->error;
int ret = write(node->client, message.c_str(), message.length()); // write the error back to the client
close(node->client);
return;
}

std::string output = "";
auto prompt = model->MakeInput("", 0, node->config["prompt"].string_value());
fastllm::ChatMessages messages;
messages.push_back({"user", node->config["prompt"].string_value()});
auto prompt = model->ApplyChatTemplate(messages);
auto inputs = model->weight.tokenizer.Encode(prompt);
std::vector<int> tokens;
for (int i = 0; i < inputs.Count(0); i++) {
13 changes: 10 additions & 3 deletions example/benchmark/benchmark.cpp
@@ -15,7 +15,12 @@ const char* GBK_LOCALE_NAME = ".936";
std::string utf8_to_gbk(const std::string& str)
{
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
std::wstring tmp_wstr = conv.from_bytes(str);
std::wstring tmp_wstr;
try {
tmp_wstr = conv.from_bytes(str);
} catch (const std::range_error& e) {
return str;
}
std::wstring_convert<std::codecvt_byname<wchar_t, char, mbstate_t>> convert(new std::codecvt_byname<wchar_t, char, mbstate_t>(GBK_LOCALE_NAME));
return convert.to_bytes(tmp_wstr);
}
@@ -134,7 +139,7 @@ int main(int argc, char **argv) {
}
}
if (inputs.empty()) {
inputs.push_back("Hello");
inputs.push_back("Hello!");
}
if (config.batch <= 0) {
config.batch = inputs.size();
Expand All @@ -148,7 +153,9 @@ int main(int argc, char **argv) {

int promptTokenNum = 0;
for (int i = 0; i < inputs.size(); i++) {
inputs[i] = model->MakeInput("", 0, inputs[i]);
fastllm::ChatMessages messages;
messages.push_back({"user", inputs[i]});
inputs[i] = model->ApplyChatTemplate(messages);
promptTokenNum += model->weight.tokenizer.Encode(inputs[i]).Count(0);
}

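The benchmark change wraps utf8_to_gbk in a try/catch because std::wstring_convert::from_bytes throws std::range_error when the conversion fails (for example on bytes that are not valid UTF-8), and the new code falls back to returning the input unchanged. A standalone reproduction of that failure mode, using a deliberately invalid byte:

#include <codecvt>
#include <iostream>
#include <locale>
#include <stdexcept>
#include <string>

int main() {
    std::string bad = "\x80 not valid utf-8";   // 0x80 is never a valid UTF-8 lead byte
    std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
    try {
        std::wstring w = conv.from_bytes(bad);
        std::cout << "converted " << w.size() << " wide characters\n";
    } catch (const std::range_error &e) {
        std::cout << "conversion failed, keep the original string (" << e.what() << ")\n";
    }
}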
17 changes: 11 additions & 6 deletions example/webui/webui.cpp
@@ -25,6 +25,7 @@ std::map <std::string, fastllm::DataType> dataTypeDict = {
struct WebConfig {
std::string path = "chatglm-6b-int4.bin"; // 模型文件路径
std::string webPath = "web"; // 网页文件路径
std::string systemPrompt = "You are a helpful assistant.";
int threads = 4; // number of threads to use
bool lowMemMode = false; // whether to use low-memory mode
int port = 8081; // port number
@@ -39,6 +40,8 @@ void Usage() {
std::cout << "<--dtype> <args>: 设置权重类型(读取hf文件时生效)" << std::endl;
std::cout << "<-t|--threads> <args>: 使用的线程数量" << std::endl;
std::cout << "<-l|--low>: 使用低内存模式" << std::endl;
std::cout << "<--system> <args>: 设置系统提示词(system prompt)" << std::endl;
std::cout << "<--dtype> <args>: 设置权重类型(读取hf文件时生效)" << std::endl;
std::cout << "<-w|--web> <args>: 网页文件的路径" << std::endl;
std::cout << "<--port> <args>: 网页端口号" << std::endl;
}
Expand All @@ -58,6 +61,8 @@ void ParseArgs(int argc, char **argv, WebConfig &config) {
config.threads = atoi(sargv[++i].c_str());
} else if (sargv[i] == "-l" || sargv[i] == "--low") {
config.lowMemMode = true;
} else if (sargv[i] == "--system") {
config.systemPrompt = sargv[++i];
} else if (sargv[i] == "--dtype") {
std::string dtypeStr = sargv[++i];
if (dtypeStr.size() > 5 && dtypeStr.substr(0, 5) == "int4g") {
@@ -79,10 +84,9 @@
}

struct ChatSession {
std::string history = "";
fastllm::ChatMessages messages;
std::string input = "";
std::string output = "";
int round = 0;
int status = 0; // 0: idle  1: result ready  2: already written back
};

@@ -106,12 +110,12 @@ int main(int argc, char** argv) {
httplib::Server svr;
auto chat = [&](ChatSession *session, const std::string input) {
if (input == "reset" || input == "stop") {
session->history = "";
session->round = 0;
session->messages.erase(std::next(session->messages.begin()), session->messages.end());
session->output = "<eop>\n";
session->status = 2;
} else {
auto prompt = model->MakeInput(session->history, session->round, input);
session->messages.push_back(std::make_pair("user", input));
auto prompt = model->ApplyChatTemplate(session->messages);
auto inputs = model->weight.tokenizer.Encode(prompt);

std::vector<int> tokens;
Expand All @@ -134,7 +138,7 @@ int main(int argc, char** argv) {
break;
}
}
session->history = model->MakeHistory(session->history, session->round++, input, session->output);
session->messages.push_back(std::make_pair("assistant", session->output));
session->output += "<eop>\n";
session->status = 2;
}
Expand All @@ -145,6 +149,7 @@ int main(int argc, char** argv) {
locker.lock();
if (sessions.find(uuid) == sessions.end()) {
sessions[uuid] = new ChatSession();
sessions[uuid]->messages.push_back({"system", config.systemPrompt});
}
auto *session = sessions[uuid];
locker.unlock();
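In the webui each ChatSession now owns its message list, seeded with the configured system prompt, and "reset"/"stop" erase everything after the first element so the system message survives a reset. A small standalone illustration of that erase pattern, with a plain vector of role/content pairs standing in for fastllm::ChatMessages:

#include <iostream>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

int main() {
    std::vector<std::pair<std::string, std::string>> messages = {
        {"system", "You are a helpful assistant."},
        {"user", "Hello!"},
        {"assistant", "Hi, how can I help?"}};
    messages.erase(std::next(messages.begin()), messages.end());   // "reset"
    std::cout << messages.size() << " message(s) left, role = "
              << messages.front().first << "\n";                   // 1 message(s) left, role = system
}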
2 changes: 1 addition & 1 deletion include/template.h
@@ -94,7 +94,7 @@ namespace fastllm {
{"or", JinjaToken::JinjaToKenType::JinjaTokenOr},
};

// 一个Jinja块
// 一个Jinja块
struct JinjaBlock {
enum JinjaBlockType {
JinjaBlockOriginal = 0, JinjaBlockEmpty, JinjaBlockVar, JinjaBlockFor,
3 changes: 2 additions & 1 deletion main.cpp
@@ -31,6 +31,7 @@ void Usage() {
std::cout << "<--system> <args>: 设置系统提示词(system prompt)" << std::endl;
std::cout << "<--eos_token> <args>: 设置eos token" << std::endl;
std::cout << "<--dtype> <args>: 设置权重类型(读取hf文件时生效)" << std::endl;
std::cout << "<--atype> <args>: 设置推理使用的数据类型(float32/float16)" << std::endl;
std::cout << "<--top_p> <args>: 采样参数top_p" << std::endl;
std::cout << "<--top_k> <args>: 采样参数top_k" << std::endl;
std::cout << "<--temperature> <args>: 采样参数温度,越高结果越不固定" << std::endl;
@@ -104,7 +105,7 @@ int main(int argc, char **argv) {
if (config.atype != fastllm::DataType::FLOAT32) {
model->SetDataType(config.atype);
}
model->SetSaveHistoryChat(true);
model->SetSaveHistoryChat(true);

for (auto &it : config.eosToken) {
generationConfig.stop_token_ids.insert(model->weight.tokenizer.GetTokenId(it));
35 changes: 25 additions & 10 deletions src/model.cpp
@@ -402,13 +402,37 @@ namespace fastllm {
for (auto &it : config.object_items()) {
model->weight.AddDict(it.first, it.second.dump().c_str());
}
// set eos_token_id
if (config["eos_token_id"].is_array()) {
for (auto &it : config["eos_token_id"].array_items()) {
model->eos_token_ids.insert(it.int_value());
}
} else {
model->eos_token_id = config["eos_token_id"].int_value();
}

std::string generatetionConfigFile = path + "generation_config.json";
if (FileExists(generatetionConfigFile)) {
auto generation_config = json11::Json::parse(ReadAllFile(generatetionConfigFile), error);
for (auto &it : generation_config.object_items()) {
if ("eos_token_id" == it.first && it.second.type() == json11::Json::ARRAY)
continue;
model->weight.AddDict(it.first, it.second.dump().c_str());
}
// update eos_token_id
if (generation_config["eos_token_id"].is_array()) {
for (auto &it : generation_config["eos_token_id"].array_items()) {
model->eos_token_ids.insert(it.int_value());
}
}
}

// 3. load the tokenizer
std::string tokenizerConfigFile = path + "tokenizer_config.json";
auto tokenizerConfig = json11::Json::parse(ReadAllFile(tokenizerConfigFile), error);
model->weight.tokenizer.SetTokenizerConfig(tokenizerConfig);
std::string tokenizerClass = tokenizerConfig["tokenizer_class"].string_value();
if (tokenizerClass == "PreTrainedTokenizerFast") {
if (tokenizerClass == "PreTrainedTokenizerFast" || tokenizerClass == "Qwen2Tokenizer") {
// PreTrainedTokenizerFast
std::string tokenizerFile = path + "tokenizer.json";
auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error);
Expand Down Expand Up @@ -445,15 +469,6 @@ namespace fastllm {
((ChatGLMModel*)model)->bos_token_id = model->weight.tokenizer.GetTokenId("<sop>");
((ChatGLMModel*)model)->tokenizerClass = tokenizerClass;

// set eos_token_id
if (config["eos_token_id"].is_array()) {
for (auto &it : config["eos_token_id"].array_items()) {
model->eos_token_ids.insert(it.int_value());
}
} else {
model->eos_token_id = config["eos_token_id"].int_value();
}

// ChatGLM builds its prompt by concatenating tokens, so the separator token IDs must be specified explicitly
model->pre_prompt = "";
model->user_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|user|>")) + ">\n");
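The model.cpp changes read eos_token_id from config.json and then let generation_config.json update it, accepting either a single integer or an array (the Llama3/Qwen2 style). A standalone sketch of that branch using json11, the JSON library this file already uses; the header name and the token ids in the sample JSON are assumptions for illustration:

#include <iostream>
#include <set>
#include <string>
#include "json11.hpp"

int main() {
    std::string error;
    std::set<int> eosTokenIds;
    auto config = json11::Json::parse(R"({"eos_token_id": [128001, 128009]})", error);
    if (config["eos_token_id"].is_array()) {
        for (auto &it : config["eos_token_id"].array_items()) {
            eosTokenIds.insert(it.int_value());
        }
    } else {
        eosTokenIds.insert(config["eos_token_id"].int_value());
    }
    for (int id : eosTokenIds) {
        std::cout << "stop on token " << id << "\n";
    }
}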
1 change: 1 addition & 0 deletions src/models/basellm.cpp
@@ -985,6 +985,7 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
{"content", message.second}
});
}
ret["add_generation_prompt"] = fastllm::JinjaVar{1};
return ret;
}

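basellm.cpp now always sets add_generation_prompt = 1 in the Jinja variables passed to the chat template. What that flag changes, illustrated with a ChatML-style layout such as Qwen2 uses; the exact markers depend on each model's own template, so treat the strings here as an assumption:

#include <iostream>
#include <string>

int main() {
    std::string rendered =
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\nHello!<|im_end|>\n";
    bool addGenerationPrompt = true;            // what the Jinja variable controls
    if (addGenerationPrompt) {
        rendered += "<|im_start|>assistant\n";  // open an assistant turn so generation starts there
    }
    std::cout << rendered;
}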
4 changes: 4 additions & 0 deletions src/models/llama.cpp
@@ -925,6 +925,10 @@ namespace fastllm {
pastKeyValues.push_back(std::make_pair(Data(DataType::FLOAT32),
Data(DataType::FLOAT32)));
}
if (this->weight.weight.find("lm_head.weight") == this->weight.weight.end()) {
this->weight["lm_head.weight"] = Data();
this->weight["lm_head.weight"].CopyFrom(this->weight["model.embed_tokens.weight"]);
}
Forward(inputIds, attentionMask, positionIds, pastKeyValues);
printf("finish.\n");
}
