Merge pull request #442 from TylunasLi/llama
Support grouped-query attention (GQA) in Llama; add support for InternLM2 (书生2) models
ztxz16 authored Apr 1, 2024
2 parents bf8f25f + 57b03e7 commit d3dfc0a
Showing 18 changed files with 713 additions and 122 deletions.
11 changes: 7 additions & 4 deletions CMakeLists.txt
@@ -25,18 +25,18 @@ set(CMAKE_BUILD_TYPE "Release")
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread --std=c++17 -O2")
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
-set(CMAKE_CXX_FLAGS_DEBUG "/MTd /Zi /Ob0 /Od /RTC1")
-set(CMAKE_CXX_FLAGS_RELEASE "/MT /O2 /Ob1 /Gy /DNDEBUG")
+string(REPLACE "/Ob2" "/Ob1 /Gy" CMAKE_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNOMINMAX /std:c++17 /arch:AVX2 /source-charset:utf-8")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread --std=c++17 -O2 -march=native")
endif()


-message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
+message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/model.cpp src/executor.cpp
src/devices/cpu/cpudevice.cpp src/devices/cpu/cpudevicebatch.cpp
-src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp src/models/glm.cpp src/models/minicpm.cpp)
+src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp
+src/models/glm.cpp src/models/minicpm.cpp src/models/internlm2.cpp)

include_directories(include)
include_directories(include/utils)
@@ -69,6 +69,9 @@ if (USE_IVCOREX)
endif()

if (PY_API)
+if(POLICY CMP0148)
+cmake_policy(SET CMP0148 NEW)
+endif()
set(PYBIND third_party/pybind11)
add_subdirectory(${PYBIND})
add_compile_definitions(PY_API)
24 changes: 24 additions & 0 deletions docs/llama_cookbook.md
@@ -113,6 +113,19 @@ python3 tools/llamalike2flm.py internlm-7b-int4.flm int4 internlm/internlm-chat-
python3 tools/llamalike2flm.py internlm-7b-int4.flm float16 internlm/internlm-chat-7b # export the internlm-chat-7b float16 model
```

* internlm/[internlm2-chat-1_8b](https://huggingface.co/internlm/internlm2-chat-1_8b)
* internlm/[internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b)
* internlm/[internlm2-chat-20b](https://huggingface.co/internlm/internlm2-chat-20b)

Convert using the `llamalike2flm.py` script:

``` sh
cd build
python3 tools/llamalike2flm.py internlm2-1.8b-fp16.flm float16 internlm/internlm2-chat-1_8b # export the 1.8B float16 model
python3 tools/llamalike2flm.py internlm2-7b-fp16.flm float16 internlm/internlm2-chat-7b # export the chat-7b float16 model
python3 tools/llamalike2flm.py internlm2-7b-int8.flm int8 internlm/internlm2-chat-7b # export the chat-7b int8 model
```
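
After exporting, the `.flm` file can be driven from C++ as well. A minimal sketch, assuming the `CreateLLMModelFromFile`, `MakeInput`, and `Response` entry points declared in `include/model.h` and `include/models/basellm.h`; double-check the streaming-callback signature against your checkout:

```cpp
#include "model.h"

#include <iostream>
#include <string>

int main() {
    // Load the exported model (path is an example taken from the commands above).
    auto model = fastllm::CreateLLMModelFromFile("internlm2-7b-fp16.flm");
    // MakeInput applies the chat template that was baked into the .flm file.
    std::string prompt = model->MakeInput("", 0, "你好");
    model->Response(prompt, [](int index, const char *content) {
        if (index >= 0) std::cout << content << std::flush; // stream tokens as they arrive
    });
    std::cout << std::endl;
    return 0;
}
```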

### XVERSE

* xverse/[XVERSE-13B-Chat](https://huggingface.co/xverse/XVERSE-13B-Chat)
@@ -225,6 +238,17 @@ The XVERSE-13B-Chat V1 release requires NFKC normalization of the input, which fastllm does not yet support
user_role="[|Human|]:", bot_role="\n[|AI|]:", history_sep="\n", dtype=dtype)
```

### Yi

* 01-ai/[Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)

* 01-ai/[Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)

```python
torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="",
user_role="<|im_start|>user\n", bot_role="<|im_end|><|im_start|>assistant\n", history_sep="<|im_end|>\n", dtype=dtype)
```
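
With these settings, fastllm assembles a first-round prompt roughly as `pre_prompt + user_role + query + bot_role` (the pattern `basellm::MakeInput` follows for llama-like models), so a single turn renders approximately as shown below (`{query}` is a placeholder):

```
<|im_start|>user
{query}<|im_end|><|im_start|>assistant
```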

### WizardCoder

* [WizardCoder-Python-7B-V1.0](https://huggingface.co/WizardLM/WizardCoder-Python-7B-V1.0)
2 changes: 2 additions & 0 deletions example/Win32Demo/fastllm-gpu.vcxproj
@@ -205,6 +205,7 @@
<ClInclude Include="..\..\include\models\chatglm.h" />
<ClInclude Include="..\..\include\models\factoryllm.h" />
<ClInclude Include="..\..\include\models\glm.h" />
<ClInclude Include="..\..\include\models\internlm2.h" />
<ClInclude Include="..\..\include\models\llama.h" />
<ClInclude Include="..\..\include\models\minicpm.h" />
<ClInclude Include="..\..\include\models\moss.h" />
@@ -224,6 +225,7 @@
<ClCompile Include="..\..\src\models\basellm.cpp" />
<ClCompile Include="..\..\src\models\chatglm.cpp" />
<ClCompile Include="..\..\src\models\glm.cpp" />
<ClCompile Include="..\..\src\models\internlm2.cpp" />
<ClCompile Include="..\..\src\models\llama.cpp" />
<ClCompile Include="..\..\src\models\minicpm.cpp" />
<ClCompile Include="..\..\src\models\moss.cpp" />
6 changes: 6 additions & 0 deletions example/Win32Demo/fastllm-gpu.vcxproj.filters
@@ -66,6 +66,9 @@
<ClInclude Include="..\..\include\models\glm.h">
<Filter>头文件\models</Filter>
</ClInclude>
<ClInclude Include="..\..\include\models\internlm2.h">
<Filter>头文件\models</Filter>
</ClInclude>
<ClInclude Include="..\..\include\models\llama.h">
<Filter>头文件\models</Filter>
</ClInclude>
@@ -119,6 +122,9 @@
<ClCompile Include="..\..\src\models\glm.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
<ClCompile Include="..\..\src\models\internlm2.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
<ClCompile Include="..\..\src\models\llama.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
2 changes: 2 additions & 0 deletions example/Win32Demo/fastllm.vcxproj
@@ -181,6 +181,7 @@
<ClInclude Include="..\..\include\models\chatglm.h" />
<ClInclude Include="..\..\include\models\factoryllm.h" />
<ClInclude Include="..\..\include\models\glm.h" />
<ClInclude Include="..\..\include\models\internlm2.h" />
<ClInclude Include="..\..\include\models\llama.h" />
<ClInclude Include="..\..\include\models\minicpm.h" />
<ClInclude Include="..\..\include\models\moss.h" />
@@ -198,6 +199,7 @@
<ClCompile Include="..\..\src\models\basellm.cpp" />
<ClCompile Include="..\..\src\models\chatglm.cpp" />
<ClCompile Include="..\..\src\models\glm.cpp" />
<ClCompile Include="..\..\src\models\internlm2.cpp" />
<ClCompile Include="..\..\src\models\llama.cpp" />
<ClCompile Include="..\..\src\models\minicpm.cpp" />
<ClCompile Include="..\..\src\models\moss.cpp" />
6 changes: 6 additions & 0 deletions example/Win32Demo/fastllm.vcxproj.filters
@@ -60,6 +60,9 @@
<ClInclude Include="..\..\include\models\glm.h">
<Filter>头文件\models</Filter>
</ClInclude>
<ClInclude Include="..\..\include\models\internlm2.h">
<Filter>头文件\models</Filter>
</ClInclude>
<ClInclude Include="..\..\include\models\llama.h">
<Filter>头文件\models</Filter>
</ClInclude>
@@ -107,6 +110,9 @@
<ClCompile Include="..\..\src\models\glm.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
<ClCompile Include="..\..\src\models\internlm2.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
<ClCompile Include="..\..\src\models\llama.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
4 changes: 2 additions & 2 deletions include/fastllm.h
@@ -483,9 +483,9 @@ namespace fastllm {

void CatDirect(Data &input0, const Data &input1, int axis); // Append input1's data directly after input0 (input0 must have been pre-expanded with enough space)

-void MatMul(const Data &input0, const Data &input1, Data &output, float alpha = 1.0);
+void MatMul(const Data &input0, const Data &input1, Data &output, float alpha = 1.0, int group = 1);

-void MatMulTransB(const Data &input0, const Data &input1, Data &output, float alpha = 1.0);
+void MatMulTransB(const Data &input0, const Data &input1, Data &output, float alpha = 1.0, int group = 1);

void Softmax(const Data &input, Data &output, int axis);

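The new `group` parameter is the heart of the grouped-query-attention support: `input0` (the query side) may carry `group` times as many matrices as `input1` (the shared key/value side), and each run of `group` consecutive `input0` matrices is multiplied against the same `input1` matrix. A reference loop for the intended semantics, as an illustrative sketch rather than the library's optimized kernel:

```cpp
#include <cstddef>
#include <vector>

// Reference semantics of MatMul(input0, input1, output, alpha, group):
// a holds batch1 * group matrices of shape [n x m] (query heads),
// b holds batch1 matrices of shape [m x k] (shared key/value heads);
// each block of `group` consecutive a-matrices reuses one b-matrix.
void GroupedBatchMatMulRef(const std::vector<float> &a, const std::vector<float> &b,
                           std::vector<float> &c, int batch1, int group,
                           int n, int m, int k, float alpha) {
    for (int b1 = 0; b1 < batch1; b1++) {
        const float *B = b.data() + (size_t)b1 * m * k; // shared across the group
        for (int g = 0; g < group; g++) {
            int b0 = b1 * group + g; // index of the query-side matrix
            const float *A = a.data() + (size_t)b0 * n * m;
            float *C = &c[(size_t)b0 * n * k];
            for (int i = 0; i < n; i++)
                for (int j = 0; j < k; j++) {
                    float s = 0.0f;
                    for (int x = 0; x < m; x++)
                        s += A[(size_t)i * m + x] * B[(size_t)x * k + j];
                    C[(size_t)i * k + j] = alpha * s;
                }
        }
    }
}
```

`MatMulTransB` follows the same convention with `input1` read transposed, which is exactly the Q*K^T step of attention when keys are stored row-major per head.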
54 changes: 54 additions & 0 deletions include/models/internlm2.h
@@ -0,0 +1,54 @@
//
// Created by tylunasli on 3/14/24.
//

#ifndef FASTLLM_INTERNLM2_H
#define FASTLLM_INTERNLM2_H

#include "basellm.h"
#include "llama.h"
#include "cmath"

#include <iostream>

namespace fastllm {
class Internlm2Model : public LlamaModel {
public:
Internlm2Model(); // Constructor

virtual void InitParams(); // Initialize parameter info

// Inference
virtual int Forward(
const Data &inputIds,
const Data &attentionMask,
const Data &positionIds,
std::vector <std::pair <Data, Data> > &pastKeyValues,
const GenerationConfig &generationConfig = GenerationConfig(),
const LastTokensManager &lastTokens = LastTokensManager(),
std::vector <float> *logits = nullptr);

std::vector <int> ForwardBatch(
int batch,
const Data &inputIds,
const Data &attentionMask,
const Data &positionIds,
std::vector <std::pair <Data, Data> > &pastKeyValues,
const GenerationConfig &generationConfig = GenerationConfig(),
const LastTokensManager &lastTokens = LastTokensManager(),
std::vector <std::vector <float>*> *logits = nullptr);

std::vector <int> ForwardBatch(
int batch,
const Data &inputIds,
const std::vector <Data*> &attentionMask,
const std::vector <Data*> &positionIds,
const std::vector <int> &seqLens,
std::vector <std::pair <Data*, Data*> > &pastKeyValues,
const std::vector <GenerationConfig> &generationConfigs,
const LastTokensManager &lastTokens = LastTokensManager(),
std::vector <std::vector <float>*> *logits = nullptr);
};
}

#endif //FASTLLM_INTERNLM2_H
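`Internlm2Model` inherits the full Llama forward pipeline and chiefly needs to pick up the GQA hyperparameters during initialization. A hypothetical sketch of what the `InitParams` override does; this is not the committed source (that lives in `src/models/internlm2.cpp`), and the `weight.dicts` config location is an assumption:

```cpp
#include "models/internlm2.h"

#include <cstdlib>

// Sketch only: pull GQA-related hyperparameters out of the model's config dictionary.
void fastllm::Internlm2Model::InitParams() {
    LlamaModel::InitParams(); // reuse the common Llama config parsing
    auto &dicts = this->weight.dicts; // assumed location of config.json key/values
    if (dicts.find("num_key_value_heads") != dicts.end()) {
        // Fewer KV heads than query heads means grouped-query attention.
        num_key_value_heads = atoi(dicts["num_key_value_heads"].c_str());
    }
    if (dicts.find("rms_norm_eps") != dicts.end()) {
        rms_norm_eps = atof(dicts["rms_norm_eps"].c_str());
    }
}
```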
6 changes: 5 additions & 1 deletion include/models/llama.h
@@ -76,14 +76,18 @@ namespace fastllm {

virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // 根据当前回复更新history

-std::pair<std::vector<float>, std::vector<float>> UpdateRotaryPosEmb(float base, float factor); // Update the rotary position embedding
+std::pair<std::vector<float>, std::vector<float>> UpdateRotaryPosEmb(float base, float factor, int seqLen = 0); // Update the rotary position embedding

protected:
RoPEType rope_type = RoPEType::BASE;

float rope_base = 10000.f;

float rope_factor = 1.f;

+int num_key_value_heads = num_attention_heads;

+float rms_norm_eps = 1e-6;
};
}

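With the default `num_key_value_heads = num_attention_heads`, existing multi-head models keep a group size of 1; GQA models derive their group from the head-count ratio. A worked example (head counts assumed from internlm2-chat-7b's published config; verify against the checkpoint's config.json):

```cpp
#include <cassert>

int main() {
    // Assumed internlm2-chat-7b values, for illustration only.
    int num_attention_heads = 32; // query heads
    int num_key_value_heads = 8;  // shared key/value heads
    assert(num_attention_heads % num_key_value_heads == 0);
    int group = num_attention_heads / num_key_value_heads;
    // group == 4: every 4 query heads attend through the same KV head.
    return group == 4 ? 0 : 1;
}
```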
24 changes: 15 additions & 9 deletions src/devices/cpu/cpudevice.cpp
@@ -2148,7 +2148,9 @@ namespace fastllm {
int input1Spatial = input1.Count(input1.dims.size() - 2);
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;
-AssertInFastLLM(batch0 == batch1, "MatMul's shape error.\n");
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+AssertInFastLLM(batch0 == batch1 * group, "MatMul: input0.dims[1] should be equal to input1.dims[0] * group.\n");
+// AssertInFastLLM(batch0 == batch1, "MatMul's shape error.\n");

std::vector <int> dims = input0.dims;
dims.back() = input1.dims[input1.dims.size() - 1];
@@ -2165,18 +2167,19 @@

output.Allocate();

-float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0;
-int input0Spatial = input0.Count(input0.dims.size() - 2);
+float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0f;
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+int input0Spatial = input0.Count(input0.dims.size() - 2) * group;
int input1Spatial = input1.Count(input1.dims.size() - 2);
int input0Stride = input0.strides[input0.dims.size() - 2];
int input1Stride = input1.strides[input1.dims.size() - 2];
-int n = input0.dims[input0.dims.size() - 2];
+int n = input0.dims[input0.dims.size() - 2] * group;
int m = input0.dims.back();
int k = input1.dims[input1.dims.size() - 1];
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;

-int outputSpatial = output.Count(output.dims.size() - 2);
+int outputSpatial = output.Count(output.dims.size() - 2) * group;
int threadNum = GetThreads();
#ifdef _WIN64
threadNum = 1;
@@ -2241,7 +2244,9 @@ namespace fastllm {
int input1Spatial = input1.Count(input1.dims.size() - 2);
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;
-AssertInFastLLM(batch0 == batch1, "MatMulTransB's shape error.\n");
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+AssertInFastLLM(batch0 == batch1 * group, "MatMulTransB: input0.dims[0] should be equal to input1.dims[0] * group.\n");
+// AssertInFastLLM(batch0 == batch1, "MatMulTransB's shape error.\n");

std::vector <int> dims = input0.dims;
dims.back() = input1.dims[input1.dims.size() - 2];
@@ -2258,17 +2263,18 @@
output.Allocate();

float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0;
-int input0Spatial = input0.Count(input0.dims.size() - 2);
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+int input0Spatial = input0.Count(input0.dims.size() - 2) * group;
int input1Spatial = input1.Count(input1.dims.size() - 2);
int input0Stride = input0.strides[input0.dims.size() - 2];
int input1Stride = input1.strides[input1.dims.size() - 2];
-int n = input0.dims[input0.dims.size() - 2];
+int n = input0.dims[input0.dims.size() - 2] * group;
int m = input0.dims.back();
int k = input1.dims[input1.dims.size() - 2];
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;

-int outputSpatial = output.Count(output.dims.size() - 2);
+int outputSpatial = output.Count(output.dims.size() - 2) * group;
int threadNum = GetThreads();
#ifdef _WIN64
threadNum = 1;
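Note that the CPU path never loops over groups explicitly: it scales `n`, `input0Spatial`, and `outputSpatial` by `group`, so each of the `batch1` key/value matrices is multiplied against all of its query rows in a single GEMM. This works because the `group` query matrices that share one KV matrix are contiguous in memory; a standalone check of that index equivalence:

```cpp
#include <cassert>
#include <cstddef>

// The folded and naive offset calculations agree because the `group` query
// matrices belonging to one KV batch sit back-to-back in memory.
int main() {
    int batch1 = 2, group = 4, n = 3, m = 5;
    for (int b1 = 0; b1 < batch1; b1++)
        for (int g = 0; g < group; g++)
            for (int i = 0; i < n; i++)
                for (int x = 0; x < m; x++) {
                    size_t naive  = ((size_t)(b1 * group + g) * n + i) * m + x;
                    size_t folded = ((size_t)b1 * (group * n) + (size_t)g * n + i) * m + x;
                    assert(naive == folded);
                }
    return 0;
}
```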
34 changes: 20 additions & 14 deletions src/devices/cuda/cudadevice.cpp
@@ -362,7 +362,9 @@ namespace fastllm {
int input1Spatial = input1.Count(input1.dims.size() - 2);
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;
-AssertInFastLLM(batch0 == batch1, "MatMul's shape error.\n");
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+AssertInFastLLM(batch0 == batch1 * group, "MatMul: input0.dims[1] should be equal to input1.dims[0] * group.\n");
+// AssertInFastLLM(batch0 == batch1, "MatMul's shape error.\n");

std::vector <int> dims = input0.dims;
dims.back() = input1.dims[input1.dims.size() - 1];
@@ -379,21 +381,22 @@

output.Allocate();

-float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : -1;
-int input0Spatial = input0.Count(input0.dims.size() - 2);
+float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0f;
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+int input0Spatial = input0.Count(input0.dims.size() - 2) * group;
int input1Spatial = input1.Count(input1.dims.size() - 2);
int input0Stride = input0.strides[input0.dims.size() - 2];
int input1Stride = input1.strides[input1.dims.size() - 2];
-int n = input0.dims[input0.dims.size() - 2];
+int n = input0.dims[input0.dims.size() - 2] * group;
int m = input0.dims.back();
int k = input1.dims[input1.dims.size() - 1];
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;

-int outputSpatial = output.Count(output.dims.size() - 2);
+int outputSpatial = output.Count(output.dims.size() - 2) * group;
FastllmCudaBatchMatMul(input0, input1, output,
-                       input0Spatial, input1Spatial, outputSpatial, input0Stride, input1Stride,
-                       batch0, n, m, k, alpha);
+                       input0Spatial, input1Spatial, outputSpatial, input0Stride, input1Stride,
+                       batch1, n, m, k, alpha);
}

void CudaMatMulTransBOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
@@ -413,7 +416,9 @@
int input1Spatial = input1.Count(input1.dims.size() - 2);
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;
-AssertInFastLLM(batch0 == batch1, "MatMulTransB's shape error.\n");
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+AssertInFastLLM(batch0 == batch1 * group, "MatMulTransB: input0.dims[0] should be equal to input1.dims[0] * group.\n");
+// AssertInFastLLM(batch0 == batch1, "MatMulTransB's shape error.\n");

std::vector <int> dims = input0.dims;
dims.back() = input1.dims[input1.dims.size() - 2];
@@ -429,21 +434,22 @@

output.Allocate();

-float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : -1;
-int input0Spatial = input0.Count(input0.dims.size() - 2);
+float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0f;
+int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
+int input0Spatial = input0.Count(input0.dims.size() - 2) * group;
int input1Spatial = input1.Count(input1.dims.size() - 2);
int input0Stride = input0.strides[input0.dims.size() - 2];
int input1Stride = input1.strides[input1.dims.size() - 2];
-int n = input0.dims[input0.dims.size() - 2];
+int n = input0.dims[input0.dims.size() - 2] * group;
int m = input0.dims.back();
int k = input1.dims[input1.dims.size() - 2];
int batch0 = input0.Count(0) / input0Spatial;
int batch1 = input1.Count(0) / input1Spatial;

-int outputSpatial = output.Count(output.dims.size() - 2);
+int outputSpatial = output.Count(output.dims.size() - 2) * group;
FastllmCudaBatchMatMulTransB(input0, input1, output,
-                       input0Spatial, input1Spatial, outputSpatial, input0Stride, input1Stride,
-                       batch0, n, m, k, alpha);
+                       input0Spatial, input1Spatial, outputSpatial, input0Stride, input1Stride,
+                       batch1, n, m, k, alpha);
}

bool CudaSoftMaxOp::CanRun(const std::string &opType, const fastllm::DataDict &datas,
8 changes: 4 additions & 4 deletions src/fastllm.cpp
@@ -2135,16 +2135,16 @@
}, {}, {{"axis", axis}});
}

-void MatMul(const Data &input0, const Data &input1, Data &output, float alpha) {
+void MatMul(const Data &input0, const Data &input1, Data &output, float alpha, int group) {
curExecutor->Run("MatMul", {
{"input0", (Data*)&input0}, {"input1", (Data*)&input1}, {"output", &output}
-}, {{"alpha", alpha}}, {});
+}, {{"alpha", alpha}}, {{"group", group}});
}

-void MatMulTransB(const Data &input0, const Data &input1, Data &output, float alpha) {
+void MatMulTransB(const Data &input0, const Data &input1, Data &output, float alpha, int group) {
curExecutor->Run("MatMulTransB", {
{"input0", (Data*)&input0}, {"input1", (Data*)&input1}, {"output", &output}
-}, {{"alpha", alpha}}, {});
+}, {{"alpha", alpha}}, {{"group", group}});
}

void Softmax(const Data &input, Data &output, int axis) {
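Taken together, the wrappers support a grouped attention step directly. A hedged sketch of the call pattern (tensor contents are zeros purely to exercise the shapes; `Allocate(0.0f)` and `Softmax(..., -1)` mirror usage seen elsewhere in the codebase rather than a documented public API):

```cpp
#include "fastllm.h"

#include <cmath>

int main() {
    using fastllm::Data;
    const int numHeads = 4, numKVHeads = 2, seqLen = 3, headDim = 8;
    const int group = numHeads / numKVHeads; // 2 query heads per KV head
    // Zero-filled tensors, just to exercise the grouped shapes.
    Data q(fastllm::DataType::FLOAT32, {numHeads, seqLen, headDim});
    Data k(fastllm::DataType::FLOAT32, {numKVHeads, seqLen, headDim});
    Data v(fastllm::DataType::FLOAT32, {numKVHeads, seqLen, headDim});
    q.Allocate(0.0f); k.Allocate(0.0f); v.Allocate(0.0f);
    Data scores, context;
    // scores: [numHeads, seqLen, seqLen] = q * k^T; each KV head serves `group` query heads.
    fastllm::MatMulTransB(q, k, scores, 1.0f / sqrtf((float)headDim), group);
    fastllm::Softmax(scores, scores, -1);
    // context: [numHeads, seqLen, headDim] = scores * v.
    fastllm::MatMul(scores, v, context, 1.0f, group);
    return 0;
}
```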