diff --git a/src/models/internlm2.cpp b/src/models/internlm2.cpp
index e0234a1..354f1ee 100644
--- a/src/models/internlm2.cpp
+++ b/src/models/internlm2.cpp
@@ -47,6 +47,50 @@ namespace fastllm {
             const fastllm::Data &positionIds, std::vector <std::pair <Data, Data> > &pastKeyValues,
             const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
             std::vector <std::vector <float>*> *retLogits) {
+        // Fuse each layer's w1 (gate) and w3 (up) projections into one weight so SwiGLU needs a single GEMM.
+        if (!mergeSwiglu) {
+            bool canMerge = true;
+            for (int i = 0; i < block_cnt; i++) {
+                std::string w1WeightName = "model.layers." + std::to_string(i) + ".feed_forward.w1.weight";
+                std::string w3WeightName = "model.layers." + std::to_string(i) + ".feed_forward.w3.weight";
+                std::string swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.gateup_proj.weight";
+
+                // A pre-merged weight already exists; record that and stop.
+                if (weight.weight.find(swigluWeightName) != weight.weight.end()) {
+                    mergeSwiglu = true;
+                    break;
+                }
+                Data &w1 = weight.weight[w1WeightName], &w3 = weight.weight[w3WeightName];
+                // INT4 group quantization needs the input dim to divide evenly into groups.
+                if ((w1.dataType == DataType::INT4_GROUP && w1.dims[1] % w1.groupCnt != 0) ||
+                    (w3.dataType == DataType::INT4_GROUP && w3.dims[1] % w3.groupCnt != 0)) {
+                    canMerge = false;
+                    break;
+                }
+
+                weight.weight[swigluWeightName] = Data(w1.dataType, {w1.dims[0] + w3.dims[0], w1.dims[1]});
+                Data &swiglu = weight.weight[swigluWeightName];
+                swiglu.name = swigluWeightName;
+                swiglu.Allocate();
+                memcpy(swiglu.cpuData, w1.cpuData, w1.GetBytes());
+                memcpy(swiglu.cpuData + w1.GetBytes(), w3.cpuData, w3.GetBytes());
+
+                // Per-output-channel quantization metadata concatenates along the same axis.
+                swiglu.perChannelAxis = w1.perChannelAxis;
+                swiglu.group = w1.group;
+                swiglu.groupCnt = w1.groupCnt;
+                swiglu.perChannelsConfigs = AppendVector(w1.perChannelsConfigs, w3.perChannelsConfigs);
+                swiglu.zeros = AppendVector(w1.zeros, w3.zeros);
+                swiglu.scales = AppendVector(w1.scales, w3.scales);
+                swiglu.mins = AppendVector(w1.mins, w3.mins);
+
+                weight.weight.erase(w1WeightName);
+                weight.weight.erase(w3WeightName);
+            }
+
+            this->mergeSwiglu = canMerge;
+        }
+
         int maxLen = inputIds.dims[1];
         Data hiddenStates;
         Data attenInput;
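
Why the byte-level concatenation above is safe: both weights are row-major `[out, in]` matrices over the same input, so stacking w1's rows on top of w3's rows lets one matmul produce `[gate; up]` in a single pass, which SwiGLU then combines as `silu(gate) * up`. Below is a minimal standalone sketch of that equivalence with plain floats; the `MatVec`/`Silu` helpers are illustrative, not fastllm's API.

```cpp
#include <cassert>
#include <cmath>
#include <cstring>
#include <vector>

// y[o] = sum_i x[i] * w[o * in + i] for a row-major [out, in] weight,
// matching the layout the diff assumes when it memcpy's w1 then w3.
static std::vector<float> MatVec(const std::vector<float> &w, int out, int in,
                                 const std::vector<float> &x) {
    std::vector<float> y(out, 0.0f);
    for (int o = 0; o < out; o++)
        for (int i = 0; i < in; i++)
            y[o] += x[i] * w[o * in + i];
    return y;
}

static float Silu(float v) { return v / (1.0f + std::exp(-v)); }

int main() {
    const int in = 4, out = 3;
    std::vector<float> w1 = {1, 0, 0, 0,  0, 1, 0, 0,  0, 0, 1, 0};  // gate proj, 3x4
    std::vector<float> w3 = {0, 0, 0, 1,  0, 0, 1, 0,  0, 1, 0, 0};  // up proj, 3x4
    std::vector<float> x  = {0.5f, -1.0f, 2.0f, 3.0f};

    // Merge exactly as the diff does: one [2*out, in] buffer, w1's bytes then w3's.
    std::vector<float> gateup(2 * out * in);
    std::memcpy(gateup.data(), w1.data(), w1.size() * sizeof(float));
    std::memcpy(gateup.data() + w1.size(), w3.data(), w3.size() * sizeof(float));

    // One pass over the merged weight equals the two separate projections, concatenated.
    std::vector<float> merged = MatVec(gateup, 2 * out, in, x);
    std::vector<float> gate   = MatVec(w1, out, in, x);
    std::vector<float> up     = MatVec(w3, out, in, x);
    for (int o = 0; o < out; o++) {
        assert(merged[o] == gate[o]);
        assert(merged[out + o] == up[o]);
        float act = Silu(merged[o]) * merged[out + o];  // silu(gate) * up, as swiglu computes
        (void)act;
    }
    return 0;
}
```

This also shows why the quantization metadata (`scales`, `zeros`, `mins`, `perChannelsConfigs`) can simply be appended: quantization here is per output channel (`perChannelAxis`), the merge only stacks output rows, so the per-row metadata concatenates in the same order.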