Add a function to determine whether an attentionMask needs to be generated
黄宇扬 committed Jun 15, 2024
1 parent 9afe60c commit 1399b63
Showing 6 changed files with 40 additions and 8 deletions.
3 changes: 3 additions & 0 deletions include/models/basellm.h
@@ -125,6 +125,9 @@ namespace fastllm {
                              const LastTokensManager &lastTokens = LastTokensManager(),
                              std::vector <std::vector <float>*> *logits = nullptr);
 
+        // Whether an AttentionMask needs to be generated
+        virtual bool NeedAttentionMask(int qlen, int klen);
+
         // Generate LLM inference inputs from the input tokens
         virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
                                    const std::map <std::string, int> &params,
3 changes: 3 additions & 0 deletions include/models/llama.h
@@ -56,6 +56,9 @@ namespace fastllm {
                               const std::vector <GenerationConfig> &generationConfigs,
                               const LastTokensManager &lastTokens = LastTokensManager(),
                               std::vector <std::vector <float>*> *logits = nullptr);
 
+        // Whether an AttentionMask needs to be generated
+        virtual bool NeedAttentionMask(int qlen, int klen);
+
         // Generate LLM inference inputs from the input tokens
         virtual void FillLLMInputsBatch(std::vector <std::vector <float> > &inputTokens,
5 changes: 5 additions & 0 deletions src/devices/cpu/cpudevice.cpp
@@ -275,13 +275,18 @@ namespace fastllm {
         void Run() {
             float *qk = new float[k1];
             float *temp = new float[k1];
+            int base = k1 - q1;
             for (int i = 0; i < q1; i++) {
                 float maxValue = -10000, sum = 0.0;
                 for (int j = 0; j < k1; j++) {
                     if (maskd && maskd[i * k1 + j] > 0.99) {
                         qk[j] = -10000;
                         continue;
                     }
+                    if (!maskd && (base + i) < j) {
+                        qk[j] = -10000;
+                        continue;
+                    }
                     float now = 0.0f;
                     int l = 0;
 #ifdef __aarch64__
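The added branch above gives the CPU attention kernel an implicit causal mask for the case where no attentionMask data is passed in: with q1 new query rows appended after k1 - q1 cached keys, base = k1 - q1 and row i may only attend to key positions j <= base + i. A minimal standalone sketch (hypothetical sizes, not fastllm code) that prints the positions this rule blocks:

    #include <cstdio>

    int main() {
        int q1 = 3, k1 = 8;      // e.g. 3 new tokens attending over 8 total key positions
        int base = k1 - q1;      // 5 previously cached positions
        for (int i = 0; i < q1; i++) {
            for (int j = 0; j < k1; j++) {
                // '.' = attended, 'x' = blocked by the (base + i) < j rule
                putchar((base + i) < j ? 'x' : '.');
            }
            putchar('\n');
        }
        return 0;
    }

Each query row attends to every cached key plus its own and earlier new positions, which is exactly the pattern an explicit causal mask would encode with 1s in the blocked cells.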
26 changes: 19 additions & 7 deletions src/models/basellm.cpp
@@ -890,6 +890,10 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
         }
     }
 
+    bool basellm::NeedAttentionMask(int qlen, int klen) {
+        return true;
+    }
+
     // Generate LLM inference inputs from the input tokens
     void basellm::FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
                                 const std::map <std::string, int> &params,
@@ -903,18 +907,25 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
 
         if (inputTokens[0].size() > 1) {
             int seqLen = inputTokens[0].size();
-            std::vector <float> vmask = std::vector <float> (seqLen * promptLen, 0);
             std::vector <float> vpids = std::vector <float> (seqLen, 0);
             for (int i = 0; i < seqLen; i++) {
                 vpids[i] = promptLen - seqLen + i;
-                for (int j = i + 1; j < seqLen; j++) {
-                    vmask[i * promptLen + (promptLen - seqLen + j)] = 1;
-                }
             }
             inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, inputTokens[0]));
-            attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, promptLen}, vmask));
             positionIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, vpids));
+
+            if (NeedAttentionMask(seqLen, promptLen)) {
+                std::vector <float> vmask = std::vector <float> (seqLen * promptLen, 0);
+                for (int i = 0; i < seqLen; i++) {
+                    vpids[i] = promptLen - seqLen + i;
+                    for (int j = i + 1; j < seqLen; j++) {
+                        vmask[i * promptLen + (promptLen - seqLen + j)] = 1;
+                    }
+                }
+                attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, promptLen}, vmask));
+            } else {
+                attentionMask = Data();
+            }
         } else {
             inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0]));
             attentionMask = Data();
@@ -956,7 +967,8 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
         if (dataType == DataType::FLOAT32) {
 
         } else if (dataType == DataType::FLOAT16) {
-            AssertInFastLLM(this->model_type == "chatglm" || this->model_type == "llama",
+            AssertInFastLLM(this->model_type == "chatglm" || this->model_type == "llama" ||
+                            this->model_type == "qwen",
                             this->model_type + " doesn't support float16");
         } else {
             ErrorInFastLLM("SetDataType Error: datatype should be float32 or float16");
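For reference, a minimal standalone sketch (hypothetical sizes, not fastllm code) of the {seqLen, promptLen} mask that the gated branch above still builds when NeedAttentionMask() returns true; a value of 1 marks a blocked position, matching the maskd[...] > 0.99 check in cpudevice.cpp:

    #include <cstdio>
    #include <vector>

    int main() {
        int seqLen = 3, promptLen = 5;   // 3 new tokens, 5 total positions including the cached prefix
        std::vector<float> vmask(seqLen * promptLen, 0);
        for (int i = 0; i < seqLen; i++) {
            for (int j = i + 1; j < seqLen; j++) {
                vmask[i * promptLen + (promptLen - seqLen + j)] = 1;   // same indexing as FillLLMInputs
            }
        }
        for (int i = 0; i < seqLen; i++) {
            for (int j = 0; j < promptLen; j++) {
                putchar(vmask[i * promptLen + j] > 0.99 ? 'x' : '.');  // 'x' = blocked
            }
            putchar('\n');
        }
        return 0;
    }

When NeedAttentionMask() returns false, the else branch leaves attentionMask as an empty Data(), the kernel's implicit causal rule reproduces the same pattern, and the seqLen * promptLen mask never has to be allocated or filled.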
2 changes: 1 addition & 1 deletion src/models/chatglm.cpp
@@ -831,7 +831,7 @@ namespace fastllm {
             }
         }
 
-        if (seqLen <= 4096) {
+        if (seqLen <= 1024) {
             std::vector<float> vmask = std::vector<float>(seqLen * seqLen, 0);
             for (int i = 0; i < seqLen - 1; i++) {
                 vmask[i * seqLen + seqLen - 1] = 1;
9 changes: 9 additions & 0 deletions src/models/llama.cpp
@@ -823,6 +823,15 @@ namespace fastllm {
         return lastRet;
     }
 
+    bool LlamaModel::NeedAttentionMask(int qlen, int klen) {
+        return false;
+        if (this->weight.dicts["use_alibi"] != "1" &&
+            ((qlen == 1) || (qlen >= 1024))) {
+            return false;
+        }
+        return true;
+    }
+
     void LlamaModel::FillLLMInputsBatch(std::vector<std::vector<float>> &inputTokens,
                                         const std::vector<std::map<std::string, int>> &params,
                                         fastllm::Data &inputIds, fastllm::Data &attentionMask,
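The override above plugs into the virtual hook declared in basellm.h, so the base-class FillLLMInputs decides per model whether to build the explicit mask. A minimal sketch (hypothetical classes, not fastllm code) of that dispatch pattern:

    #include <cstdio>

    struct BaseModel {
        virtual ~BaseModel() {}
        // default mirrors basellm::NeedAttentionMask: always build the mask
        virtual bool NeedAttentionMask(int qlen, int klen) { return true; }
        void FillInputs(int qlen, int klen) {            // stand-in for FillLLMInputs
            if (NeedAttentionMask(qlen, klen)) {
                printf("building %d x %d attention mask\n", qlen, klen);
            } else {
                printf("leaving attentionMask empty (implicit causal mask)\n");
            }
        }
    };

    struct CausalOnlyModel : BaseModel {
        // opts out, as LlamaModel::NeedAttentionMask does in this commit
        bool NeedAttentionMask(int qlen, int klen) override { return false; }
    };

    int main() {
        BaseModel base;
        CausalOnlyModel llamaLike;
        base.FillInputs(3, 5);        // takes the mask-building path
        llamaLike.FillInputs(3, 5);   // takes the empty-mask path
        return 0;
    }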
