add about 10 simple samplers
huangzhengxiang committed Sep 5, 2024
1 parent ab9b6ac commit 3bfc3b6
Showing 7 changed files with 736 additions and 122 deletions.
3 changes: 3 additions & 0 deletions express/NeuralNetWorkOp.cpp
@@ -476,6 +476,9 @@ VARP _Softmax(VARP logits, int axis) {
softmax->main.AsAxis()->axis = axis;
return (Variable::create(Expr::create(softmax.get(), {logits})));
}
VARP _TempratureSoftmax(VARP logits, float temperature, int axis) {
return _Softmax(logits * _Scalar<float>(1.0f / temperature), axis);
}
/*Computes softplus: log(exp(features) + 1).
Args:
features: A variable. Must be Halide_Type_Float.
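The new _TempratureSoftmax helper simply scales the logits by 1/temperature before the existing _Softmax. For intuition, a minimal standalone sketch of the same math in plain C++ (independent of MNN; lower temperature sharpens the distribution, higher temperature flattens it):

#include <algorithm>
#include <cmath>
#include <vector>

// softmax(logits / T), with the usual max-subtraction for numerical stability
std::vector<float> temperatureSoftmax(std::vector<float> logits, float T) {
    float maxLogit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.f;
    for (float& v : logits) {
        v = std::exp((v - maxLogit) / T);
        sum += v;
    }
    for (float& v : logits) v /= sum;
    return logits;
}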
1 change: 1 addition & 0 deletions include/MNN/expr/NeuralNetWorkOp.hpp
@@ -58,6 +58,7 @@ MNN_PUBLIC VARP _Relu(VARP x, float slope = 0.0f);
MNN_PUBLIC VARP _Relu6(VARP x, float minValue = 0.0f, float maxValue = 6.0f);
MNN_PUBLIC VARP _PRelu(VARP x, std::vector<float> &&slopes);
MNN_PUBLIC VARP _Softmax(VARP logits, int axis = -1);
MNN_PUBLIC VARP _TempratureSoftmax(VARP logits, float temperature, int axis = -1);
MNN_PUBLIC VARP _Softplus(VARP features);
MNN_PUBLIC VARP _Softsign(VARP features);
MNN_PUBLIC std::vector<VARP> _Split(VARP value, INTS size_splits, int axis = 0);
18 changes: 11 additions & 7 deletions transformers/llm/engine/include/llm/llm.hpp
@@ -22,9 +22,11 @@
#include <MNN/expr/Module.hpp>
#include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp>
#include "sampler/sampler.hpp"

namespace MNN {
namespace Transformer {
class Sampler;
class Tokenizer;
class Pipeline;
class LlmConfig;
@@ -53,20 +55,21 @@ class MNN_PUBLIC Llm {
Llm(std::shared_ptr<LlmConfig> config) : config_(config) {}
virtual ~Llm();
static Llm* createLLM(const std::string& config_path);
void chat();
void chat(std::ostream* time_log=nullptr);
void reset();
void trace(bool start);
virtual void load();
MNN::Express::VARP forward(const std::vector<int>& input_ids);
int sample(MNN::Express::VARP logits, const std::vector<int>& pre_ids);
MNN::Express::VARP forward(const std::vector<int>& input_ids, bool prefill=true);
std::string decode(int id);
bool is_stop(int token_id);
std::string apply_prompt_template(const std::string& user_content) const;
std::string apply_chat_template(const std::vector<PromptItem>& chat_prompts) const;
std::string response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr);
std::string response(const std::vector<PromptItem>& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr);
std::string response(std::vector<PromptItem>& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr);
void generate_init();
std::string generate(const std::vector<int>& input_ids, std::ostream* os, const char* end_with);
std::vector<int> generate(const std::vector<int>& input_ids, int max_new_tokens = -1);
void print_speed();
void print_speed(std::ostream* os);
// config function
std::string dump_config();
bool set_config(const std::string& content);
@@ -85,9 +88,11 @@ class MNN_PUBLIC Llm {
// time
int64_t prefill_us_ = 0;
int64_t decode_us_ = 0;
TimePerformance time_perf_;
bool is_single_ = true;
bool attention_fused_ = true;
protected:
std::shared_ptr<Sampler> sampler_;
std::shared_ptr<LlmConfig> config_;
std::shared_ptr<Tokenizer> tokenizer_;
std::vector<int> key_value_shape_ = {};
@@ -97,9 +102,8 @@ class MNN_PUBLIC Llm {
std::vector<std::shared_ptr<MNN::Express::Module>> modules_;
std::vector<std::shared_ptr<MNN::Express::Module>> prefill_modules_, decode_modules_, current_modules_;
const MNN::Express::Module* base_module_ = nullptr;
void initSampler();
void init_runtime();
std::string decode(int id);
bool is_stop(int token_id);
virtual std::vector<int> tokenizer(const std::string& query);
virtual MNN::Express::VARP embedding(const std::vector<int>& input_ids);
virtual MNN::Express::VARP gen_attention_mask(int seq_len);
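Callers see only small API shifts: chat() now accepts an optional stream for timing logs, sampling moves behind the new Sampler member, and decode()/is_stop() become public. A minimal usage sketch against the signatures above (the config path and log file name are placeholders):

#include <fstream>
#include <iostream>
#include <memory>
#include "llm/llm.hpp"

using MNN::Transformer::Llm;

int main() {
    // placeholder path: point this at your model's config.json
    std::unique_ptr<Llm> llm(Llm::createLLM("model/config.json"));
    llm->load();
    // one-shot response; the configured sampler drives decoding
    llm->response("Hello!", &std::cout);
    // interactive chat, streaming per-round timings to a log file
    std::ofstream time_log("chat_time.log");
    llm->chat(&time_log);
    return 0;
}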
133 changes: 133 additions & 0 deletions transformers/llm/engine/include/sampler/sampler.hpp
@@ -0,0 +1,133 @@
#ifndef SAMPLER_hpp
#define SAMPLER_hpp

#include <vector>
#include <memory>
#include <string>
#include <fstream>
#include <sstream>
#include <iostream>
#include <streambuf>
#include <functional>
#include <unordered_map>
#include <utility>

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp>

namespace MNN {
namespace Transformer {

#define MICRO_TO_MILLI 1e-3f
#define MILLI_TO_MICRO 1000
#define MICRO_TO_SEC 1e-6f
#define SEC_TO_MICRO 1000000

#define MEGA_TO_GIGA (1/1024.f)
#define GIGA_TO_MEGA 1024.f
#define KILLO_TO_GIGA (1/1024.f/1024.f)
#define GIGA_TO_KILLO (1024.f*1024.f)
#define KILLO_TO_MEGA (1/1024.f)
#define MEGA_TO_KILLO 1024.f

struct PrefillTimePerformance {
size_t prefill_prev_token_ = 0;
size_t prefill_token_ = 0;
size_t prefill_us_ = 0;
};

struct DecodeTimePerformance {
size_t decode_prev_token_ = 0;
size_t decode_us_ = 0;
};

struct TimePerformance {
std::vector<PrefillTimePerformance> prefill_record_;
std::vector<DecodeTimePerformance> decode_record_;
};
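
// For reference: each record above pairs a token count with elapsed
// microseconds, so throughput falls out directly. A hypothetical helper
// (not part of this commit; assumes one decode record per generated token):
inline void printThroughput(const TimePerformance& perf) {
    for (const auto& p : perf.prefill_record_) {
        float sec = p.prefill_us_ * MICRO_TO_SEC;
        if (sec > 0.f) {
            std::cout << "prefill: " << p.prefill_token_ << " tok, "
                      << p.prefill_token_ / sec << " tok/s" << std::endl;
        }
    }
    size_t decode_us = 0;
    for (const auto& d : perf.decode_record_) decode_us += d.decode_us_;
    if (decode_us > 0) {
        std::cout << "decode: "
                  << perf.decode_record_.size() / (decode_us * MICRO_TO_SEC)
                  << " tok/s" << std::endl;
    }
}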

void mergePerformance(struct TimePerformance* dst, struct TimePerformance* src);
void clearPerformance(struct TimePerformance* perf);

class Llm;

class MNN_PUBLIC Sampler {
protected:
Llm* mLlm;
std::vector<std::vector<int>> mCandidates;
std::vector<int> mCommonPrefix;
int mMaxNewTokens;
int getGenLength(int candidate, int output_len) const {
return mCandidates[candidate].size() - (mCommonPrefix.size() - output_len);
}
public:
virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) = 0;
// prepare for another round of sampling
// in the future, each sampler will reset only its own state.
virtual void reset() {}
};
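
// Hypothetical illustration (not in this commit): the smallest possible
// subclass, a greedy sampler. It leans on the public Llm::forward/decode/
// is_stop declared in llm.hpp, so it belongs in a source file where Llm is
// fully defined; std::max_element needs <algorithm>.
class GreedySampler : public Sampler {
public:
    GreedySampler(Llm* llm, int maxNewTokens) {
        mLlm = llm;
        mMaxNewTokens = maxNewTokens;
    }
    virtual std::string sample(const std::vector<int>& input_ids,
                               std::ostream* os = &std::cout,
                               const char* end_with = nullptr,
                               struct TimePerformance* time_perf = nullptr) override {
        std::string output;
        std::vector<int> ids = input_ids;
        for (int n = 0; n < mMaxNewTokens; ++n) {
            auto logits = mLlm->forward(ids, /*prefill=*/n == 0);
            auto ptr = logits->readMap<float>();
            int size = logits->getInfo()->size;
            int token = (int)(std::max_element(ptr, ptr + size) - ptr); // argmax
            if (mLlm->is_stop(token)) break;
            std::string piece = mLlm->decode(token);
            output += piece;
            if (os) (*os) << piece << std::flush;
            ids = {token}; // feed back one token per decode step
        }
        if (os && end_with) (*os) << end_with;
        return output;
    }
};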

class MNN_PUBLIC LocalSampler : public Sampler {
public:
struct LocalSamplerConfig {
std::string type = "temperature";
float temperature = 0.8;
int topK = 40;
float topP = 0.9;
float minP = 0.05;
float tfsZ = 1.0;
float typical = 0.95;
float penalty = 1.1;
int ngram = 8;
float ngram_factor = 1.0; // penalize a repeated ngram by multiplying in ngram_factor.
float max_penalty = 10.;
};
private:
struct LocalSamplerConfig mConfig;
int randomSelect(float* probs, size_t size);
int argmax(MNN::Express::VARP logits);
int temperature(MNN::Express::VARP logits, float temperature = 1.0);
struct IndexProb {
int index;
float prob;
};
struct IndexProbCmpLess{
bool operator()(IndexProb a, IndexProb b) {
return a.prob < b.prob;
}
};
struct IndexProbCmpGreater{
bool operator()(IndexProb a, IndexProb b) {
return a.prob > b.prob;
}
};
int reSoftmaxSelect(std::vector<int> index, std::vector<float> scores, float temperature);
void topK(MNN::Express::VARP logits, int K, std::vector<int>& topKindex, std::vector<float>& topKprob);
int topK(MNN::Express::VARP logits, int K = 40, float temperature = 1.0);
void topP(MNN::Express::VARP logits, float p, float temperature, std::vector<int>& topPindex, std::vector<float>& topPprob);
int topP(MNN::Express::VARP logits, float p = 0.9, float temperature = 1.0);
void minP(MNN::Express::VARP logits, float p, float temperature, std::vector<int>& minPindex, std::vector<float>& minPprob);
int minP(MNN::Express::VARP logits, float p = 0.1, float temperature = 1.0);
void tfs(MNN::Express::VARP logits, float z, float temperature, std::vector<int>& index, std::vector<float>& tfsprob);
int tfs(MNN::Express::VARP logits, float z = 1.0, float temperature = 1.0);
void typical(MNN::Express::VARP logits, float p, float temperature, std::vector<int>& index, std::vector<float>& minPprob);
int typical(MNN::Express::VARP logits, float p = 1.0, float temperature = 1.0);
void penalty(MNN::Express::VARP logits, float penalty = 1.0, bool penalizeNgram = false, int ngram = 8, float ngram_factor = 1.0);
int penalty(MNN::Express::VARP logits, float penalty = 1.0, int ngram = 8, float ngram_factor = 1.0, float temperature = 1.0);
// int mixed(MNN::Express::VARP logits);
std::string handleToken(int token, std::ostream* os = &std::cout, const char* end_with = nullptr);
public:
LocalSampler(Llm* llm, int max_new_tokens, struct LocalSamplerConfig config);
int algorithm(MNN::Express::VARP logits);
virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override;
virtual void reset() override;
};
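
// For intuition, a standalone sketch of the top-k path declared above:
// keep the K largest logits, re-apply temperature softmax over the
// survivors, then draw from the renormalized distribution. Not the
// commit's implementation; needs <algorithm>, <cmath>, and <random>.
inline int topKSketch(const std::vector<float>& logits, int K, float temperature) {
    std::vector<std::pair<float, int>> cand(logits.size());
    for (int i = 0; i < (int)logits.size(); ++i) cand[i] = { logits[i], i };
    K = std::min<int>(K, (int)cand.size());
    std::partial_sort(cand.begin(), cand.begin() + K, cand.end(),
                      std::greater<std::pair<float, int>>());
    std::vector<float> probs(K);
    float sum = 0.f;
    for (int i = 0; i < K; ++i) { // max-subtracted for numerical stability
        probs[i] = std::exp((cand[i].first - cand[0].first) / temperature);
        sum += probs[i];
    }
    std::mt19937 rng(std::random_device{}());
    float r = std::uniform_real_distribution<float>(0.f, sum)(rng);
    float acc = 0.f;
    for (int i = 0; i < K; ++i) {
        acc += probs[i];
        if (r <= acc) return cand[i].second;
    }
    return cand[K - 1].second;
}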


} // Transformer
} // MNN


#endif // SAMPLER_hpp