Skip to content

Commit

Permalink
优化internlm2
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Sep 5, 2024
1 parent 0bce71b commit e69c439
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/models/internlm2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,46 @@ namespace fastllm {
const fastllm::Data &positionIds, std::vector<std::pair<Data, Data>> &pastKeyValues,
const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
std::vector <std::vector <float>*> *retLogits) {
// One-time (lazy) fusion of each layer's feed-forward w1 and w3 weights into a single
// "mlp.gateup_proj" weight: w1's rows followed by w3's rows, with the per-channel
// quantization metadata concatenated in the same order.
//
// The work is split into two passes so that the weight map is never left half-merged:
// previously a layer that failed the INT4_GROUP check part-way through the loop left the
// earlier layers with w1/w3 already erased while mergeSwiglu was reported as false.
if (!mergeSwiglu) {
    bool canMerge = true;

    // Pass 1: validate every layer without modifying any existing weight.
    for (int i = 0; i < block_cnt; i++) {
        const std::string layerPrefix = "model.layers." + std::to_string(i);
        const std::string swigluWeightName = layerPrefix + ".mlp.gateup_proj.weight";

        if (weight.weight.find(swigluWeightName) != weight.weight.end()) {
            // This layer already carries a fused weight; nothing to validate.
            // NOTE(review): the original code sets mergeQKV (not a swiglu flag) here, which
            // looks like a copy-paste from a QKV-merging path. Kept as-is because other uses
            // of mergeQKV are not visible from this block -- confirm and remove if unintended.
            mergeQKV = true;
            continue;
        }

        Data &w1 = weight.weight[layerPrefix + ".feed_forward.w1.weight"];
        Data &w3 = weight.weight[layerPrefix + ".feed_forward.w3.weight"];

        // Row-wise concatenation only makes sense when both halves share the same storage
        // type and the same number of columns.
        const bool layoutMismatch = (w1.dataType != w3.dataType) || (w1.dims[1] != w3.dims[1]);
        // Group-quantized INT4 weights must have a column count divisible by groupCnt.
        const bool w1BadGroup = (w1.dataType == DataType::INT4_GROUP && w1.dims[1] % w1.groupCnt != 0);
        const bool w3BadGroup = (w3.dataType == DataType::INT4_GROUP && w3.dims[1] % w3.groupCnt != 0);
        if (layoutMismatch || w1BadGroup || w3BadGroup) {
            canMerge = false;
            break;
        }
    }

    // Pass 2: every layer is known to be mergeable, so the fusion can run to completion.
    if (canMerge) {
        for (int i = 0; i < block_cnt; i++) {
            const std::string layerPrefix = "model.layers." + std::to_string(i);
            const std::string w1WeightName = layerPrefix + ".feed_forward.w1.weight";
            const std::string w3WeightName = layerPrefix + ".feed_forward.w3.weight";
            const std::string swigluWeightName = layerPrefix + ".mlp.gateup_proj.weight";

            if (weight.weight.find(swigluWeightName) != weight.weight.end()) {
                continue; // already fused
            }

            Data &w1 = weight.weight[w1WeightName], &w3 = weight.weight[w3WeightName];

            weight.weight[swigluWeightName] = Data(w1.dataType, {w1.dims[0] + w3.dims[0], w1.dims[1]});
            Data &swiglu = weight.weight[swigluWeightName];
            swiglu.name = swigluWeightName;
            swiglu.Allocate();

            // Raw payload: w1's bytes first, then w3's bytes directly after.
            memcpy(swiglu.cpuData, w1.cpuData, w1.GetBytes());
            memcpy(swiglu.cpuData + w1.GetBytes(), w3.cpuData, w3.GetBytes());

            // Scalar quantization settings are taken from w1; per-channel tables are
            // concatenated in the same w1-then-w3 order as the payload.
            swiglu.perChannelAxis = w1.perChannelAxis;
            swiglu.group = w1.group;
            swiglu.groupCnt = w1.groupCnt;
            swiglu.perChannelsConfigs = AppendVector(w1.perChannelsConfigs, w3.perChannelsConfigs);
            swiglu.zeros = AppendVector(w1.zeros, w3.zeros);
            swiglu.scales = AppendVector(w1.scales, w3.scales);
            swiglu.mins = AppendVector(w1.mins, w3.mins);

            // The separate halves are no longer needed once the fused copy exists.
            weight.weight.erase(w1WeightName);
            weight.weight.erase(w3WeightName);
        }
    }

    this->mergeSwiglu = canMerge;
}

int maxLen = inputIds.dims[1];
Data hiddenStates;
Data attenInput;
Expand Down

0 comments on commit e69c439

Please sign in to comment.