Optimize the MiniCPM operators, for roughly a 0.4% speedup
cgli authored and TylunasLi committed Mar 19, 2024
1 parent 1eb74ed commit 7373729
24 changes: 12 additions & 12 deletions src/models/minicpm.cpp
@@ -180,17 +180,17 @@ namespace fastllm {
attenOutput.Reshape({bsz, seqlen, -1});

Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);
- Mul(attenLastOutput, this->attention_scale, attenLastOutput);
- AddTo(hiddenStates, attenLastOutput);
+ // Mul(attenLastOutput, this->attention_scale, attenLastOutput);
+ AddTo(hiddenStates, attenLastOutput, this->attention_scale);
// 2. mlp
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
Silu(w1, w1);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
- Mul(w2, this->attention_scale, w2);
- AddTo(hiddenStates, w2);
+ // Mul(w2, this->attention_scale, w2);
+ AddTo(hiddenStates, w2, this->attention_scale);
}
Data logits, topk;
Data tempHiddenStates;
@@ -338,17 +338,17 @@ namespace fastllm {
PermuteSelf(attenOutput, {1, 0, 2});

Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);
- Mul(attenLastOutput, this->attention_scale, attenLastOutput);
- AddTo(hiddenStates, attenLastOutput);
+ // Mul(attenLastOutput, this->attention_scale, attenLastOutput);
+ AddTo(hiddenStates, attenLastOutput, this->attention_scale);
// 2. mlp
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
Silu(w1, w1);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
- Mul(w2, this->attention_scale, w2);
- AddTo(hiddenStates, w2);
+ // Mul(w2, this->attention_scale, w2);
+ AddTo(hiddenStates, w2, this->attention_scale);
}

Data logits, topk;
@@ -525,17 +525,17 @@ namespace fastllm {
}

Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);
- Mul(attenLastOutput, this->attention_scale, attenLastOutput);
- AddTo(hiddenStates, attenLastOutput);
+ // Mul(attenLastOutput, this->attention_scale, attenLastOutput);
+ AddTo(hiddenStates, attenLastOutput, this->attention_scale);
// 2. mlp
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
Silu(w1, w1);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
- Mul(w2, this->attention_scale, w2);
- AddTo(hiddenStates, w2);
+ // Mul(w2, this->attention_scale, w2);
+ AddTo(hiddenStates, w2, this->attention_scale);
}

Data logits, curLogit;
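All three forward paths in minicpm.cpp receive the same change: the separate Mul that scaled the residual branch by attention_scale is folded into AddTo, which now takes the scale as a third argument and does the scale-and-accumulate in a single pass over the tensor. Below is a minimal illustrative sketch of that fusion, using plain std::vector stand-ins rather than fastllm's Data type; the alpha semantics of AddTo are inferred from this diff, not from fastllm's API documentation.

// Illustrative sketch only; names are stand-ins, not fastllm's Mul/AddTo.
#include <cstddef>
#include <vector>

// Before: two passes over the residual branch.
void scaleThenAdd(std::vector<float> &hidden, std::vector<float> &branch, float scale) {
    for (float &x : branch) x *= scale;               // Mul(branch, scale, branch)
    for (std::size_t i = 0; i < hidden.size(); ++i)
        hidden[i] += branch[i];                       // AddTo(hidden, branch)
}

// After: one fused pass, matching AddTo(hiddenStates, branch, this->attention_scale).
void fusedScaleAdd(std::vector<float> &hidden, const std::vector<float> &branch, float scale) {
    for (std::size_t i = 0; i < hidden.size(); ++i)
        hidden[i] += scale * branch[i];               // AddTo(hidden, branch, scale)
}

Fusing the two operators removes one full read/write pass over the activation tensor per residual connection (two per layer), which is consistent with the modest end-to-end speedup of roughly 0.4% stated in the commit message.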
