[MNN::Bugfix] Some bugfix sync.
wangzhaode committed Dec 12, 2023
1 parent 1ea55f4 commit 72fa060
Showing 38 changed files with 1,268 additions and 431 deletions.
64 changes: 64 additions & 0 deletions llm/cli_demo.cpp
@@ -0,0 +1,64 @@
//
// cli_demo.cpp
//
// Created by MNN on 2023/03/24.
// ZhaodeWang
//

#include "llm.hpp"
#include <fstream>
#include <stdlib.h>

void benchmark(Llm* llm, std::string prompt_file) {
    std::cout << "prompt file is " << prompt_file << std::endl;
    std::ifstream prompt_fs(prompt_file);
    std::vector<std::string> prompts;
    std::string prompt;
    while (std::getline(prompt_fs, prompt)) {
        // prompts starting with '#' are ignored
        if (prompt.substr(0, 1) == "#") {
            continue;
        }
        prompts.push_back(prompt);
    }
    int prompt_len = 0;
    int decode_len = 0;
    int64_t prefill_time = 0;
    int64_t decode_time = 0;
    // llm->warmup();
    for (int i = 0; i < prompts.size(); i++) {
        llm->response(prompts[i]);
        prompt_len += llm->prompt_len_;
        decode_len += llm->gen_seq_len_;
        prefill_time += llm->prefill_us_;
        decode_time += llm->decode_us_;
        llm->reset();
    }
    float prefill_s = prefill_time / 1e6;
    float decode_s = decode_time / 1e6;
    printf("\n#################################\n");
    printf("prompt tokens num = %d\n", prompt_len);
    printf("decode tokens num = %d\n", decode_len);
    printf("prefill time = %.2f s\n", prefill_s);
    printf(" decode time = %.2f s\n", decode_s);
    printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s);
    printf(" decode speed = %.2f tok/s\n", decode_len / decode_s);
    printf("##################################\n");
}

int main(int argc, const char* argv[]) {
    if (argc < 2) {
        std::cout << "Usage: " << argv[0] << " model_dir <prompt.txt>" << std::endl;
        return 0;
    }
    std::string model_dir = argv[1];
    std::cout << "model path is " << model_dir << std::endl;
    std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
    llm->load(model_dir);
    if (argc < 3) {
        llm->chat();
        return 0; // no prompt file was given: interactive chat only, don't read argv[2]
    }
    std::string prompt_file = argv[2];
    benchmark(llm.get(), prompt_file);
    return 0;
}
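
For reference, the benchmark() function above reads one prompt per line from the file passed as the second argument and skips any line whose first character is '#'. A hypothetical prompt.txt could therefore look like the lines below; the contents are illustrative only and are not part of this commit:

# benchmark prompts; lines beginning with '#' are skipped
Hello, who are you?
Write a short poem about the sea.
Summarize the plot of Hamlet in two sentences.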
39 changes: 30 additions & 9 deletions llm/include/llm.hpp
@@ -23,22 +23,37 @@

using namespace MNN;
using namespace Express;
class Tokenizer;

class MNN_PUBLIC Llm {
public:
    Llm() {
        // default tokenizer is sentencepiece
        tokenizer_.reset(new Sentencepiece);
    }
    static Llm* createLLM(const std::string& path);
    VARP gen_embedding(const std::vector<int>& input_ids);
    virtual ~Llm() = default;
    static Llm* createLLM(const std::string& path, std::string model_type = "auto");
    VARP disk_embedding(const std::vector<int>& input_ids);
    void load(const std::string& model_dir);
    int forward(const std::vector<int>& input_ids);
    std::vector<int> tokenizer_encode(const std::string& input_str);
    std::string decode(int id);
    std::string response(const std::string& input_str, std::ostream* os = &std::cout);
    void chat();
    void warmup();
    std::string response(const std::string& input_str, std::ostream* os = &std::cout, const char* end_with = nullptr);
    float load_progress() { return load_progress_; }
    void reset();
    void print_speed();
public:
    std::vector<int> history_;
    // forward info
    int max_seq_len_ = 1024;
    int prompt_len_ = 0;
    int gen_seq_len_ = 0;
    int all_seq_len_ = 0;
    // time
    int64_t prefill_us_ = 0;
    int64_t decode_us_ = 0;
private:
    virtual std::vector<int> tokenizer(const std::string& query) = 0;
    virtual VARP gen_attention_mask(int seq_len) = 0;
@@ -52,9 +67,6 @@ class MNN_PUBLIC Llm {
    std::vector<int> key_value_shape_ = {};
    std::string model_name_ = "";
    // gen info
    int gen_seq_len_ = 0;
    int all_seq_len_ = 0;
    int max_seq_len_ = 256;
    float load_progress_ = 0.f;
    // tokenizer
    std::unique_ptr<Tokenizer> tokenizer_;
@@ -65,9 +77,6 @@ class MNN_PUBLIC Llm {
    std::vector<VARP> past_key_values_;
    // model dir
    std::string model_dir_;
    // tokenizer
    std::vector<std::string> word_decoder_;
    std::unordered_map<std::string, int> word_encoder_;
};

// some llm models
@@ -107,6 +116,7 @@ class Qwen_7b : public Llm {
model_name_ = "Qwen_7b";
layer_nums_ = 32;
key_value_shape_ = {2, 1, 0, 32, 128};
hidden_size_ = 4096;
tokenizer_.reset(new Tiktoken);
}
private:
@@ -116,6 +126,17 @@ class Qwen_7b : public Llm {
    virtual bool is_stop(int token_id) override;
};

class Qwen_1_8b : public Qwen_7b {
public:
    Qwen_1_8b() {
        model_name_ = "Qwen_1.8b";
        layer_nums_ = 24;
        key_value_shape_ = {2, 1, 0, 16, 128};
        hidden_size_ = 2048;
        tokenizer_.reset(new Tiktoken);
    }
};

class Llama2_7b : public Llm {
public:
    Llama2_7b() {
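
For context, the new Qwen_1_8b class above only changes configuration (24 layers, a smaller key_value_shape_, hidden_size_ 2048) and inherits Qwen_7b's attention-mask, position-ids, and stop-token overrides. Combined with the new model_type parameter on createLLM (defaulting to "auto"), a call site might look roughly like the sketch below; the model directory and the "qwen-1.8b" type string are assumptions, and the actual dispatch logic lives in llm.cpp, which is not part of this diff:

#include "llm.hpp"
#include <memory>

int main() {
    // "auto" presumably infers the variant from the model directory; passing an
    // explicit type string here is a guess at the intended override, not a documented value.
    std::unique_ptr<Llm> llm(Llm::createLLM("./qwen-1.8b-mnn", "qwen-1.8b"));
    llm->load("./qwen-1.8b-mnn");
    llm->response("Hello");
    return 0;
}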
1 change: 1 addition & 0 deletions llm/include/tokenizer.hpp
@@ -18,6 +18,7 @@
class Tokenizer {
public:
    Tokenizer() = default;
    virtual ~Tokenizer() = default;
    virtual bool load(const std::string& filename) = 0;
    virtual std::vector<int> encode(const std::string& str) = 0;
    virtual std::string decode(int id) = 0;
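
The single addition above gives Tokenizer a virtual destructor. This matters because Llm owns its tokenizer through std::unique_ptr<Tokenizer> while constructing derived types such as Sentencepiece and Tiktoken; destroying a derived object through a base-class pointer without a virtual destructor is undefined behavior. A minimal stand-alone illustration of the pattern, using placeholder type names rather than the ones in this repository:

#include <memory>

struct Base {
    virtual ~Base() = default;  // without this, deleting a Derived through Base* is undefined behavior
};

struct Derived : Base {
    // e.g. owns buffers that must be released in ~Derived()
};

int main() {
    std::unique_ptr<Base> p(new Derived());  // destroyed correctly through the Base* it holds
    return 0;
}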
50 changes: 46 additions & 4 deletions llm/llm_demo.cpp
@@ -6,17 +6,59 @@
//

#include "llm.hpp"
#include <iostream>
#include <fstream>
#include <stdlib.h>

void benchmark(Llm* llm, std::string prompt_file) {
    std::cout << "prompt file is " << prompt_file << std::endl;
    std::ifstream prompt_fs(prompt_file);
    std::vector<std::string> prompts;
    std::string prompt;
    while (std::getline(prompt_fs, prompt)) {
        // prompts starting with '#' are ignored
        if (prompt.substr(0, 1) == "#") {
            continue;
        }
        prompts.push_back(prompt);
    }
    int prompt_len = 0;
    int decode_len = 0;
    int64_t prefill_time = 0;
    int64_t decode_time = 0;
    // llm->warmup();
    for (int i = 0; i < prompts.size(); i++) {
        llm->response(prompts[i]);
        prompt_len += llm->prompt_len_;
        decode_len += llm->gen_seq_len_;
        prefill_time += llm->prefill_us_;
        decode_time += llm->decode_us_;
        llm->reset();
    }
    float prefill_s = prefill_time / 1e6;
    float decode_s = decode_time / 1e6;
    printf("\n#################################\n");
    printf("prompt tokens num = %d\n", prompt_len);
    printf("decode tokens num = %d\n", decode_len);
    printf("prefill time = %.2f s\n", prefill_s);
    printf(" decode time = %.2f s\n", decode_s);
    printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s);
    printf(" decode speed = %.2f tok/s\n", decode_len / decode_s);
    printf("##################################\n");
}

int main(int argc, const char* argv[]) {
    if (argc < 2) {
        std::cout << "Usage: ./llm_demo.out <model_path>" << std::endl;
        std::cout << "Usage: " << argv[0] << " model_dir <prompt.txt>" << std::endl;
        return 0;
    }
    std::string model_dir = argv[1];
    std::cout << "model path is " << model_dir << std::endl;
    std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
    llm->load(model_dir);
    llm->response("你好");
    if (argc < 3) {
        llm->chat();
        return 0; // no prompt file was given: interactive chat only, don't read argv[2]
    }
    std::string prompt_file = argv[2];
    benchmark(llm.get(), prompt_file);
    return 0;
}
