Add CUDA implementations of several float16 operators so that ChatGLM can run entirely in float16
ztxz16 committed Mar 27, 2024
1 parent 9280dcc commit a778b09
Showing 6 changed files with 474 additions and 66 deletions.
5 changes: 5 additions & 0 deletions include/devices/cuda/fastllm-cuda.cuh
@@ -71,6 +71,11 @@ bool FastllmCudaBatchMatMulTransBBatch(void **i0s, void **i1s, void **os,
bool FastllmCudaBatchMatMulBatch(void **i0s, void **i1s, void **os,
int *ns, int *ms, int *ks,
int *i0Strides, int *i1Strides, float alpha, int batch);

+bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v,
+                              const fastllm::Data &mask, const fastllm::Data &output, int group, float scale);
+bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);

void FastllmCudaSetDevice(int gpu_id);
#ifdef __cplusplus
}
1 change: 1 addition & 0 deletions include/fastllm.h
@@ -236,6 +236,7 @@ namespace fastllm {

void *cudaData = nullptr;
std::vector <void*> extraCudaData;
+std::vector <void*> extraCudaHalfData;

void *deviceData = nullptr;
std::vector <void*> extraDeviceData;
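
The new extraCudaHalfData member mirrors extraCudaData and presumably serves as a per-tensor cache for half-precision device buffers, so data converted to float16 on the GPU only has to be converted once. A minimal sketch of that caching pattern, assuming a hypothetical helper and conversion kernel (neither name appears in this commit):

#include <cuda_fp16.h>
#include "fastllm.h"

// Hypothetical fp32 -> fp16 copy kernel (placeholder, not from this commit).
__global__ void Fp32ToFp16Kernel(const float *src, half *dst, int len) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < len) {
        dst[i] = __float2half(src[i]);
    }
}

// Lazily convert a tensor's cached fp32 CUDA buffer to fp16 and remember the
// result in extraCudaHalfData, analogous to how extraCudaData caches buffers.
void *GetOrCreateHalfBuffer(fastllm::Data &data, size_t elementCount) {
    if (!data.extraCudaHalfData.empty()) {
        return data.extraCudaHalfData[0];   // reuse the converted copy
    }
    half *halfBuf = nullptr;
    cudaMalloc((void **) &halfBuf, elementCount * sizeof(half));
    int threads = 256, blocks = (int) ((elementCount + threads - 1) / threads);
    Fp32ToFp16Kernel <<< blocks, threads >>> ((float *) data.cudaData, halfBuf, (int) elementCount);
    data.extraCudaHalfData.push_back(halfBuf);
    return halfBuf;
}
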
4 changes: 2 additions & 2 deletions src/devices/cpu/cpudevice.cpp
@@ -1894,7 +1894,7 @@ namespace fastllm {

AssertInFastLLM((input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32) ||
(input0.dataType == DataType::FLOAT16 && input1.dataType == DataType::FLOAT16),
"Cat's input's type should be float32.\n");
"Cat's input's type should be float32 or float16.\n");
AssertInFastLLM(input0.dims.size() == input1.dims.size(), "Cat Error: input's shape's size should be same.");

int dimsLen = input0.dims.size();
@@ -3070,7 +3070,7 @@ namespace fastllm {
d += m;
}
} else if (data.dataType == DataType::FLOAT16) {
-int index = (int) half_to_float(((uint16_t *) positionIds.cpuData)[(b * 2) * positionIds.dims.back() + l]);
+int index = (int) ((float *) positionIds.cpuData)[(b * 2) * positionIds.dims.back() + l];
float *sin = ((float*)sinData.cpuData) + stride * index;
float *cos = ((float*)cosData.cpuData) + stride * index;

76 changes: 53 additions & 23 deletions src/devices/cuda/cudadevice.cpp
@@ -83,7 +83,9 @@ namespace fastllm {

AssertInFastLLM(q.dataType == k.dataType && q.dataType == v.dataType,
"Attention: q, k, v's datatype should be same.\n");
-AssertInFastLLM(q.dataType == DataType::FLOAT32, "Attention's input's type should be float32.\n");
+AssertInFastLLM(q.dataType == DataType::FLOAT32 ||
+                q.dataType == DataType::FLOAT16,
+                "Attention's input's type should be float32 or float16.\n");

std::vector <int> dims = {q.dims[0], q.dims[1], v.dims[2]};
output.dataType = q.dataType;
@@ -101,7 +103,12 @@ namespace fastllm {
int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
float scale = floatParams.find("scale") != floatParams.end() ? floatParams.find("scale")->second : 1.0;
output.Allocate();
-FastllmCudaAttention(q, k, v, mask, output, group, scale);

+if (q.dataType == DataType::FLOAT32) {
+    FastllmCudaAttention(q, k, v, mask, output, group, scale);
+} else if (q.dataType == DataType::FLOAT16) {
+    FastllmCudaHalfAttention(q, k, v, mask, output, group, scale);
+}
}

void CudaCopyKVCacheOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
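
FastllmCudaHalfAttention itself is implemented in the CUDA source file, whose diff did not load on this page. As an illustration of what the float16 path has to compute, here is a deliberately naive sketch (hypothetical name, one thread per query row, mask and grouped-query handling omitted) that reads and writes fp16 but keeps the dot products, softmax, and weighted sum in fp32:

#include <cuda_fp16.h>

// q: [heads, qlen, dim], k/v: [heads, klen, dim], out: [heads, qlen, dim].
// Launch: grid = dim3(heads, qlen), block = 1 thread,
// dynamic shared memory = klen * sizeof(float).
__global__ void NaiveHalfAttentionKernel(const half *q, const half *k, const half *v,
                                         half *out, int qlen, int klen, int dim,
                                         float scale) {
    extern __shared__ float scores[];
    int head = blockIdx.x, qi = blockIdx.y;
    const half *qRow = q + (head * qlen + qi) * dim;

    float maxScore = -1e30f, sum = 0.0f;
    for (int ki = 0; ki < klen; ki++) {          // scores = scale * (q . k), accumulated in fp32
        const half *kRow = k + (head * klen + ki) * dim;
        float s = 0.0f;
        for (int d = 0; d < dim; d++) {
            s += __half2float(qRow[d]) * __half2float(kRow[d]);
        }
        scores[ki] = s * scale;
        maxScore = fmaxf(maxScore, scores[ki]);
    }
    for (int ki = 0; ki < klen; ki++) {          // softmax in fp32
        scores[ki] = expf(scores[ki] - maxScore);
        sum += scores[ki];
    }
    for (int d = 0; d < dim; d++) {              // weighted sum over v, written back as fp16
        float o = 0.0f;
        for (int ki = 0; ki < klen; ki++) {
            o += scores[ki] * __half2float(v[(head * klen + ki) * dim + d]);
        }
        out[(head * qlen + qi) * dim + d] = __float2half(o / sum);
    }
}
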
@@ -139,6 +146,11 @@ namespace fastllm {
Data &input = *(datas.find("input")->second);
Data &weight = *(datas.find("weight")->second);
Data &output = *(datas.find("output")->second);

+AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
+                input.dataType == DataType::FLOAT16,
+                "RMSNorm error: datatype should be float32 or float16.");

output.Allocate();

float eps = floatParams.find("eps") != floatParams.end() ? floatParams.find("eps")->second : 1e-5;
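
The float16 RMSNorm kernel is also in the unloaded CUDA file. A common pattern, sketched below under the assumption that the norm weight stays in fp32, is to accumulate the mean square in fp32 and convert back to fp16 only when writing; one thread per row is used here for brevity, whereas a real kernel would use a per-row parallel reduction:

#include <cuda_fp16.h>

// input/output: [rows, channels] fp16, weight: [channels] fp32 (assumed).
__global__ void NaiveHalfRMSNormKernel(const half *input, const float *weight,
                                       half *output, int rows, int channels, float eps) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= rows) {
        return;
    }
    const half *x = input + (size_t) row * channels;
    half *y = output + (size_t) row * channels;
    float meanSquare = 0.0f;
    for (int i = 0; i < channels; i++) {         // accumulate in fp32
        float v = __half2float(x[i]);
        meanSquare += v * v;
    }
    float scale = rsqrtf(meanSquare / channels + eps);
    for (int i = 0; i < channels; i++) {
        y[i] = __float2half(__half2float(x[i]) * scale * weight[i]);
    }
}
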
@@ -182,7 +194,7 @@ namespace fastllm {
std::vector <int> dims = input.dims;
dims.back() = weight.dims[0];

-output.dataType = DataType::FLOAT32;
+output.dataType = input.dataType;
output.Resize(dims);
}

@@ -203,20 +215,30 @@ namespace fastllm {
int m = input.dims.back();
int k = output.dims.back();

-if (weight.dataType == DataType::FLOAT32) {
-    FastllmCudaMatMulFloat32(input, weight, bias, output, n, m, k);
-} else if (weight.dataType == DataType::FLOAT16) {
-    FastllmCudaMatMulFloat16(input, weight, bias, output, n, m, k);
-} else if (weight.dataType == DataType::INT8) {
-    FastllmCudaMatMulFloatInt8(input, weight, bias, output, n, m, k);
-} else if (weight.dataType == DataType::INT4) {
-    FastllmCudaMatMulFloatInt4(input, weight, bias, output, n, m, k);
-} else if (weight.dataType == DataType::INT4_NOZERO) {
-    FastllmCudaMatMulFloatInt4NoZero(input, weight, bias, output, n, m, k);
-} else if (weight.dataType == DataType::INT4_GROUP) {
-    FastllmCudaMatMulFloatInt4Group(input, weight, bias, output, n, m, k);
+if (input.dataType == DataType::FLOAT16) {
+    if (weight.dataType == DataType::FLOAT16) {
+        FastllmCudaHalfMatMulFloat16(input, weight, bias, output, n, m, k);
+    } else {
+        ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
+    }
+} else if (input.dataType == DataType::FLOAT32) {
+    if (weight.dataType == DataType::FLOAT32) {
+        FastllmCudaMatMulFloat32(input, weight, bias, output, n, m, k);
+    } else if (weight.dataType == DataType::FLOAT16) {
+        FastllmCudaMatMulFloat16(input, weight, bias, output, n, m, k);
+    } else if (weight.dataType == DataType::INT8) {
+        FastllmCudaMatMulFloatInt8(input, weight, bias, output, n, m, k);
+    } else if (weight.dataType == DataType::INT4) {
+        FastllmCudaMatMulFloatInt4(input, weight, bias, output, n, m, k);
+    } else if (weight.dataType == DataType::INT4_NOZERO) {
+        FastllmCudaMatMulFloatInt4NoZero(input, weight, bias, output, n, m, k);
+    } else if (weight.dataType == DataType::INT4_GROUP) {
+        FastllmCudaMatMulFloatInt4Group(input, weight, bias, output, n, m, k);
+    } else {
+        ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
+    }
} else {
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
ErrorInFastLLM("Linear error: unsupport input's dataType.\n");
}
}
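
For the float16-input / float16-weight branch, FastllmCudaHalfMatMulFloat16 can hand the whole product to cuBLAS. A hedged sketch of that call follows (illustrative function name, cuBLAS 11+, fp32 accumulation; the bias add would be a separate elementwise kernel and is omitted). fastllm stores the weight as [k, m] and computes output[n, k] = input[n, m] * weight^T; because cuBLAS is column-major, the product is issued as output^T = weight * input^T with the weight array passed transposed:

#include <cublas_v2.h>
#include <cuda_fp16.h>

bool SketchHalfMatMulFloat16(cublasHandle_t handle, const half *input,
                             const half *weight, half *output,
                             int n, int m, int k) {
    float alpha = 1.0f, beta = 0.0f;             // fp32 scaling with CUBLAS_COMPUTE_32F
    cublasStatus_t status = cublasGemmEx(
        handle, CUBLAS_OP_T, CUBLAS_OP_N,
        k, n, m,
        &alpha,
        weight, CUDA_R_16F, m,                   // [k, m] row-major -> lda = m
        input,  CUDA_R_16F, m,                   // [n, m] row-major -> ldb = m
        &beta,
        output, CUDA_R_16F, k,                   // [n, k] row-major -> ldc = k
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    return status == CUBLAS_STATUS_SUCCESS;
}
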

@@ -275,8 +297,9 @@ namespace fastllm {

int axis = intParams.find("axis") != intParams.end() ? intParams.find("axis")->second : -1;

-AssertInFastLLM(input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32,
-                "Cat's input's type should be float32.\n");
+AssertInFastLLM((input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32) ||
+                (input0.dataType == DataType::FLOAT16 && input1.dataType == DataType::FLOAT16),
+                "Cat's input's type should be float32 or float16.\n");
AssertInFastLLM(input0.dataDevice == input1.dataDevice, "CatDirect error: inputs should use same device.\n");

if (input0.dims.size() == 0) {
@@ -475,7 +498,9 @@ namespace fastllm {
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
output.Allocate();
AssertInFastLLM(input.dataType == DataType::FLOAT32, "Swiglu error: Data's type should be float32.\n");
AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
input.dataType == DataType::FLOAT16,
"Swiglu error: Data's type should be float32.\n");
FastllmCudaSwiglu(input, output);
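
FastllmCudaSwiglu now also receives fp16 tensors; the kernel itself is in the unloaded CUDA file. A minimal elementwise sketch, assuming the same convention as the fp32 path (each row of width 2 * spatial is split into halves x and y, and out = silu(x) * y), with the activation evaluated in fp32:

#include <cuda_fp16.h>

// input: [outer, 2 * spatial] fp16, output: [outer, spatial] fp16.
__global__ void NaiveHalfSwigluKernel(const half *input, half *output,
                                      int outer, int spatial) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= outer * spatial) {
        return;
    }
    int o = idx / spatial, i = idx % spatial;
    float x = __half2float(input[(size_t) o * spatial * 2 + i]);
    float y = __half2float(input[(size_t) o * spatial * 2 + spatial + i]);
    output[(size_t) o * spatial + i] = __float2half((x / (1.0f + expf(-x))) * y);
}
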
}

Expand All @@ -495,7 +520,9 @@ namespace fastllm {
output.Allocate();

float v = floatParams.find("v") != floatParams.end() ? floatParams.find("v")->second : 1.0;
AssertInFastLLM(input.dataType == DataType::FLOAT32, "Mul error: Data's type should be float32.\n");
AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
input.dataType == DataType::FLOAT16,
"Mul error: Data's type should be float32 or float16.\n");
FastllmCudaMul(input, v, output);
}

@@ -505,8 +532,9 @@ namespace fastllm {
Data &input1 = *(datas.find("input1")->second);
float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0;

-AssertInFastLLM(input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32,
-                "AddTo error: Data's type should be float32.\n");
+AssertInFastLLM((input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32) ||
+                (input0.dataType == DataType::FLOAT16 && input1.dataType == DataType::FLOAT16),
+                "AddTo error: Data's type should be float32 or float16.\n");
AssertInFastLLM(input0.dims == input1.dims, "AddTo error: input's shape should be same.\n");
FastllmCudaAddTo(input0, input1, alpha);
}
@@ -583,7 +611,9 @@ namespace fastllm {
axis.push_back(((int32_t *) axisData.cpuData)[i]);
}

AssertInFastLLM(input.dataType == DataType::FLOAT32, "Permute error: datatype should be float32.");
AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
input.dataType == DataType::FLOAT16,
"Permute error: datatype should be float32 or float16.");
AssertInFastLLM(axis.size() == input.dims.size(), "Permute error: axis's size should be equal to data's shape's size.");

bool same = false;
(Diffs for the remaining two changed files, including the CUDA kernel implementations, did not load on this page.)