Add some float16 operators so that ChatGLM can run entirely in float16 on the CPU
ztxz16 committed Mar 26, 2024
1 parent 43d7093 commit 9280dcc
Showing 1 changed file with 139 additions and 31 deletions.
170 changes: 139 additions & 31 deletions src/devices/cpu/cpudevice.cpp
@@ -164,15 +164,33 @@ namespace fastllm {
        }
    } fp16tofp32;

+   void Float16ToFloat32(uint16_t *float16, float *float32, int len) {
+       for (int i = 0; i < len; i++) {
+           float32[i] = fp16tofp32.dict[float16[i]];
+       }
+   }
+
+   void Float32ToFloat16(float *float32, uint16_t *float16, int len) {
+       for (int i = 0; i < len; i++) {
+           float16[i] = float_to_half(float32[i]);
+       }
+   }

    void CpuToFloat16::Run(const std::string &opType, const fastllm::DataDict &datas,
                           const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
        Data &data = *(datas.find("input")->second);
        if (data.dataType == DataType::FLOAT16) {
            return;
        }
+       if (data.dims.size() == 0) {
+           data.dataType = DataType::FLOAT16;
+           data.UpdateUnitSize();
+           return;
+       }
        if (data.dataType == DataType::FLOAT32) {
            float *old = (float*)data.cpuData;
            data.dataType = DataType::FLOAT16;
            data.UpdateUnitSize();
            data.cpuData = new uint8_t[data.GetBytes()];
            uint16_t *cur = (uint16_t*)data.cpuData;
            int len = data.Count(0);
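The fp16tofp32.dict lookup above works because binary16 has only 65536 bit patterns: every half value can be decoded once into a 65536-entry float table, after which Float16ToFloat32 is a single indexed load per element. A minimal scalar decoder that could populate such a table is sketched below; it assumes IEEE-754 binary16/binary32 semantics and is not the repository's own implementation (float_to_half and the table initializer sit outside this diff).

    #include <cstdint>
    #include <cstring>

    // Decode one IEEE-754 binary16 value to binary32 (sketch, not fastllm's code).
    static float HalfBitsToFloat(uint16_t h) {
        uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
        uint32_t exp  = (h >> 10) & 0x1Fu;
        uint32_t mant = h & 0x3FFu;
        uint32_t bits;
        if (exp == 0) {
            if (mant == 0) {
                bits = sign;                                   // signed zero
            } else {                                           // subnormal: renormalize
                int e = 113;                                   // 127 - 15 + 1
                while (!(mant & 0x400u)) { mant <<= 1; e--; }
                bits = sign | ((uint32_t)e << 23) | ((mant & 0x3FFu) << 13);
            }
        } else if (exp == 31) {
            bits = sign | 0x7F800000u | (mant << 13);          // Inf / NaN
        } else {
            bits = sign | ((exp + 112u) << 23) | (mant << 13); // rebias 15 -> 127
        }
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

    // Build the table once; afterwards every fp16 -> fp32 conversion is dict[h].
    static float dict[65536];
    static void InitDict() {
        for (uint32_t h = 0; h < 65536u; h++) dict[h] = HalfBitsToFloat((uint16_t)h);
    }

The table costs 256 KB; the reverse direction has no such shortcut (a float32-indexed table would need 2^32 entries), which is presumably why Float32ToFloat16 converts arithmetically via float_to_half instead.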
@@ -191,6 +209,11 @@ namespace fastllm {
        if (data.dataType == DataType::FLOAT32) {
            return;
        }
+       if (data.dims.size() == 0) {
+           data.dataType = DataType::FLOAT32;
+           data.UpdateUnitSize();
+           return;
+       }
        if (data.dataType == DataType::FLOAT16) {
            uint16_t *old = (uint16_t*)data.cpuData;
            data.dataType = DataType::FLOAT32;
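The fold hides the tail of this branch; assuming it mirrors the FLOAT32 to FLOAT16 path of CpuToFloat16::Run above, the in-place dtype switch would finish roughly as in this hypothetical sketch: update the element width, reallocate, convert, then free the old buffer.

    // Hypothetical continuation of the folded branch (not shown in this diff);
    // assumes it mirrors the FLOAT32 -> FLOAT16 path seen earlier.
    data.UpdateUnitSize();                           // element width 2 -> 4 bytes
    data.cpuData = new uint8_t[data.GetBytes()];     // allocate the wider buffer
    Float16ToFloat32(old, (float *) data.cpuData, (int) data.Count(0));
    delete[] old;                                    // release the fp16 storage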
@@ -223,7 +246,9 @@ namespace fastllm {

        AssertInFastLLM(q.dataType == k.dataType && q.dataType == v.dataType,
                        "Attention: q, k, v's datatype should be same.\n");
-       AssertInFastLLM(q.dataType == DataType::FLOAT32, "Attention's input's type should be float32.\n");
+       AssertInFastLLM(q.dataType == DataType::FLOAT32 ||
+                       q.dataType == DataType::FLOAT16,
+                       "Attention's input's type should be float32 or float16.\n");

        std::vector <int> dims = {q.dims[0], q.dims[1], v.dims[2]};
        output.dataType = q.dataType;
@@ -295,6 +320,29 @@ namespace fastllm {
        delete[] temp;
    }

+   void SingleAttentionFloat16(uint16_t *qd, uint16_t *kd, uint16_t *vd, uint16_t *maskd, uint16_t *od,
+                               float scale, int q1, int q2, int k1, int v2) {
+       std::vector <float> fqd, fkd, fvd, fmaskd, fod;
+
+       fqd.resize(q1 * q2);
+       fkd.resize(k1 * q2);
+       fvd.resize(k1 * v2);
+       fmaskd.resize(maskd ? q1 * k1 : 0);
+       fod.resize(q1 * v2);
+
+       Float16ToFloat32(qd, fqd.data(), (int)fqd.size());
+       Float16ToFloat32(kd, fkd.data(), (int)fkd.size());
+       Float16ToFloat32(vd, fvd.data(), (int)fvd.size());
+       if (maskd) {
+           Float16ToFloat32(maskd, fmaskd.data(), (int)fmaskd.size());
+       }
+
+       SingleAttention(fqd.data(), fkd.data(), fvd.data(), maskd ? fmaskd.data() : nullptr, fod.data(),
+                       scale, q1, q2, k1, v2);
+
+       Float32ToFloat16(fod.data(), od, (int)fod.size());
+   }
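SingleAttentionFloat16 widens q, k, v (and the optional mask) to float32, runs the existing float32 kernel, and narrows only the output, so the fp16 path keeps float32 accumulation. SingleAttention itself is outside this diff; a scalar reference consistent with its call shape, assuming row-major q:[q1,q2], k:[k1,q2], v:[k1,v2], an optional mask:[q1,k1] whose nonzero entries block a position (an assumed convention), and out = softmax(scale * QK^T) V, would look like:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Scalar reference attention matching SingleAttention's call shape
    // (a sketch, not the repository's kernel).
    void ReferenceAttention(const float *q, const float *k, const float *v,
                            const float *mask, float *out, float scale,
                            int q1, int q2, int k1, int v2) {
        std::vector<float> score(k1);
        for (int i = 0; i < q1; i++) {
            float maxv = -1e30f;
            for (int j = 0; j < k1; j++) {
                float s = 0.0f;
                for (int t = 0; t < q2; t++) {
                    s += q[i * q2 + t] * k[j * q2 + t];   // q row . k row
                }
                s *= scale;
                if (mask && mask[i * k1 + j] > 0.5f) {
                    s = -1e30f;                           // blocked position
                }
                score[j] = s;
                maxv = std::max(maxv, s);
            }
            float sum = 0.0f;                             // stable softmax
            for (int j = 0; j < k1; j++) {
                score[j] = std::exp(score[j] - maxv);
                sum += score[j];
            }
            for (int c = 0; c < v2; c++) {                // weighted sum of v rows
                float acc = 0.0f;
                for (int j = 0; j < k1; j++) {
                    acc += score[j] * v[j * v2 + c];
                }
                out[i * v2 + c] = acc / sum;
            }
        }
    }

The wrapper trades memory for reuse: three temporary float32 copies per head, and a single narrowing pass over the output.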

    void CpuAttention::Run(const std::string &opType, const fastllm::DataDict &datas,
                           const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
        Data &q = *(datas.find("q")->second);
@@ -306,25 +354,51 @@ namespace fastllm {
        float scale = floatParams.find("scale") != floatParams.end() ? floatParams.find("scale")->second : 1.0;
        output.Allocate();
        int q0 = q.dims[0], q1 = q.dims[1], q2 = q.dims[2], k0 = k.dims[0], k1 = k.dims[1], v2 = v.dims[2];
-       float *qd = (float*)q.cpuData;
-       float *kd = (float*)k.cpuData;
-       float *vd = (float*)v.cpuData;
-       float *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (float*)mask.cpuData : nullptr;
-       float *od = (float*)output.cpuData;
-       int batch = (maskd != nullptr && mask.dims.size() == 3) ? mask.dims[0] : 1;
-       batch = intParams.find("mask___batch") != intParams.end() ? intParams.find("mask___batch")->second : batch;
-       int maskStride = (maskd != nullptr) ? (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)) : 0;
-       std::fill(od, od + output.Count(0), 0.0f);
-       auto pool = GetPool();
-       std::vector<std::future<void> > futures;
-       for (int o = 0; o < q0; o++) {
-           futures.push_back(pool->Submit(SingleAttention,
-               qd + o * q.strides[0], kd + (o / group) * k.strides[0], vd + (o / group) * v.strides[0],
-               maskd + (o / (q0 / batch)) * maskStride, od + o * output.strides[0], scale,
-               q1, q2, k1, v2));
-       }
-       for (int o = 0; o < futures.size(); o++) {
-           futures[o].get();

+       if (q.dataType == DataType::FLOAT32) {
+           float *qd = (float*)q.cpuData;
+           float *kd = (float*)k.cpuData;
+           float *vd = (float*)v.cpuData;
+           float *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (float*)mask.cpuData : nullptr;
+           float *od = (float*)output.cpuData;
+           int batch = (maskd != nullptr && mask.dims.size() == 3) ? mask.dims[0] : 1;
+           batch = intParams.find("mask___batch") != intParams.end() ? intParams.find("mask___batch")->second : batch;
+           int maskStride = (maskd != nullptr) ? (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)) : 0;
+           std::fill(od, od + output.Count(0), 0.0f);
+           auto pool = GetPool();
+           std::vector<std::future<void> > futures;
+           for (int o = 0; o < q0; o++) {
+               futures.push_back(pool->Submit(SingleAttention,
+                   qd + o * q.strides[0], kd + (o / group) * k.strides[0], vd + (o / group) * v.strides[0],
+                   maskd + (o / (q0 / batch)) * maskStride, od + o * output.strides[0], scale,
+                   q1, q2, k1, v2));
+           }
+           for (int o = 0; o < futures.size(); o++) {
+               futures[o].get();
+           }
+       } else if (q.dataType == DataType::FLOAT16) {
+           uint16_t *qd = (uint16_t*)q.cpuData;
+           uint16_t *kd = (uint16_t*)k.cpuData;
+           uint16_t *vd = (uint16_t*)v.cpuData;
+           uint16_t *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (uint16_t*)mask.cpuData : nullptr;
+           uint16_t *od = (uint16_t*)output.cpuData;
+           int batch = (maskd != nullptr && mask.dims.size() == 3) ? mask.dims[0] : 1;
+           batch = intParams.find("mask___batch") != intParams.end() ? intParams.find("mask___batch")->second : batch;
+           int maskStride = (maskd != nullptr) ? (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)) : 0;
+           std::fill(od, od + output.Count(0), float_to_half(0.0f));
+           auto pool = GetPool();
+           std::vector<std::future<void> > futures;
+           for (int o = 0; o < q0; o++) {
+               futures.push_back(pool->Submit(SingleAttentionFloat16,
+                   qd + o * q.strides[0], kd + (o / group) * k.strides[0], vd + (o / group) * v.strides[0],
+                   maskd + (o / (q0 / batch)) * maskStride, od + o * output.strides[0], scale,
+                   q1, q2, k1, v2));
+           }
+           for (int o = 0; o < futures.size(); o++) {
+               futures[o].get();
+           }
+       } else {
+           ErrorInFastLLM("Attention error: unsupported dataType.\n");
+       }
    }
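Both branches map query head o to k/v head o / group, the usual grouped-query layout in which several query heads share one key/value head, and to mask slice o / (q0 / batch). A standalone illustration of the head mapping, with arbitrary sizes:

    #include <cstdio>

    // Arbitrary example: 8 query heads sharing 4 kv heads (group = 2), as in
    // the o / group indexing above.
    int main() {
        int q0 = 8, group = 2;
        for (int o = 0; o < q0; o++) {
            std::printf("query head %d -> kv head %d\n", o, o / group);
        }
        return 0;
    }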

@@ -361,14 +435,16 @@ namespace fastllm {
        AssertInFastLLM(weight.dims.size() == 2, "Embedding's weight's dim should be 2.\n");
        AssertInFastLLM(weight.dataType == DataType::FLOAT32 ||
                        weight.dataType == DataType::BFLOAT16, "Embedding's weight's type should be float32 or bfloat16.\n");
-       AssertInFastLLM(input.dataType == DataType::FLOAT32, "Embedding's input's type should be float32.\n");
+       AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
+                       input.dataType == DataType::FLOAT16,
+                       "Embedding's input's type should be float32 or float16.\n");

        weight.weightType = WeightType::EMBEDDING;
        int vocabSize = weight.dims[0], embSize = weight.dims[1];
        std::vector <int> dims = input.dims;
        dims.push_back(embSize);

-       output.dataType = DataType::FLOAT32;
+       output.dataType = input.dataType;
        output.Resize(dims);
    }

@@ -382,12 +458,30 @@ namespace fastllm {

        int vocabSize = weight.dims[0], embSize = weight.dims[1];
        uint64_t inputLen = input.Count(0);

        float *inputData = (float*)input.cpuData;
+       float *dstOutputData = (float*)output.cpuData;

+       std::vector <float> tempInputData, tempOutputData;
+       if (input.dataType != DataType::FLOAT32) {
+           tempInputData.resize(inputLen);
+           tempOutputData.resize(inputLen * embSize);
+           inputData = tempInputData.data();
+           dstOutputData = tempOutputData.data();

+           if (input.dataType == DataType::FLOAT16) {
+               for (int i = 0; i < inputLen; i++) {
+                   inputData[i] = half_to_float(((uint16_t*)input.cpuData)[i]);
+               }
+           } else {
+               ErrorInFastLLM("Embedding error: unsupported dataType.\n");
+           }
+       }

        if (GetLowMemMode()) {
            FILE *fi = fopen(weight.fileName.c_str(), "rb");
            if (weight.dataType == DataType::FLOAT32) {
-               float *outputData = (float *) output.cpuData;
+               float *outputData = (float *) dstOutputData;
                for (int i = 0; i < inputLen; i++) {
                    int token = (int) (inputData[i] + 1e-9);
#if defined(_WIN32) or defined(_WIN64)
@@ -398,7 +492,7 @@ namespace fastllm {
                    int ret = fread(outputData + i * embSize, sizeof(float), embSize, fi);
                }
            } else {
-               uint16_t *outputData = (uint16_t *) output.cpuData;
+               uint16_t *outputData = (uint16_t *) dstOutputData;
                uint16_t *weightData = new uint16_t[embSize];
                for (int i = 0; i < inputLen; i++) {
                    int token = (int) (inputData[i] + 1e-9);
@@ -418,14 +512,14 @@ namespace fastllm {
            fclose(fi);
        } else {
            if (weight.dataType == DataType::FLOAT32) {
-               float *outputData = (float *) output.cpuData;
+               float *outputData = (float *) dstOutputData;
                float *weightData = (float *) weight.cpuData;
                for (int i = 0; i < inputLen; i++) {
                    int token = (int) (inputData[i] + 1e-9);
                    memcpy(outputData + i * embSize, weightData + token * embSize, embSize * sizeof(float));
                }
            } else {
-               uint16_t *outputData = (uint16_t *) output.cpuData;
+               uint16_t *outputData = (uint16_t *) dstOutputData;
                uint16_t *weightData = (uint16_t *) weight.cpuData;
                for (int i = 0; i < inputLen; i++) {
                    int token = (int) (inputData[i] + 1e-9);
@@ -436,6 +530,16 @@ namespace fastllm {
                }
            }
        }

+       if (output.dataType != DataType::FLOAT32) {
+           if (output.dataType == DataType::FLOAT16) {
+               for (int i = 0; i < inputLen * embSize; i++) {
+                   ((uint16_t*)output.cpuData)[i] = float_to_half(dstOutputData[i]);
+               }
+           } else {
+               ErrorInFastLLM("Embedding error: unsupported dataType.\n");
+           }
+       }
    }
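One property of the format worth noting for this path (a property of binary16, not a claim about this code): with an 11-bit significand, fp16 represents integers exactly only up to 2048, so token ids decoded via half_to_float can round to a neighboring integer beyond that. A standalone check that approximates fp16 by rounding to 11 significant bits (exponent range ignored):

    #include <cmath>
    #include <cstdio>

    // Round x to 11 significant bits, mimicking binary16's significand
    // (10 stored bits + 1 implicit); the fp16 exponent range is ignored here.
    static float RoundToFp16Precision(float x) {
        int e;
        float m = std::frexp(x, &e);          // x = m * 2^e, m in [0.5, 1)
        const float steps = 2048.0f;          // 2^11 significand steps
        return std::ldexp(std::round(m * steps) / steps, e);
    }

    int main() {
        for (int id : {2047, 2048, 2049, 4097}) {
            float back = RoundToFp16Precision((float) id);
            std::printf("%5d -> %7.1f%s\n", id, back,
                        ((float) id == back) ? "" : "  (rounded)");
        }
        return 0;
    }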

    void CpuLayerNormOp::Run(const std::string &opType, const fastllm::DataDict &datas,
@@ -1856,7 +1960,7 @@ namespace fastllm {

        AssertInFastLLM((input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32) ||
                        (input0.dataType == DataType::FLOAT16 && input1.dataType == DataType::FLOAT16),
-                       "CatDirect's input's type should be float32.\n");
+                       "CatDirect's input's type should be float32 or float16.\n");
        AssertInFastLLM(input0.dataDevice == input1.dataDevice, "CatDirect error: inputs should use same device.\n");

        if (input0.dims.size() == 0) {
@@ -2950,11 +3054,11 @@ namespace fastllm {
        int stride = (int)sinData.dims[1];
        for (int l = 0; l < len; l++) {
            for (int b = 0; b < bs; b++) {
-               int index = (int) ((float *) positionIds.cpuData)[(b * 2) * positionIds.dims.back() + l];
-               float *sin = ((float*)sinData.cpuData) + stride * index;
-               float *cos = ((float*)cosData.cpuData) + stride * index;
-
                if (data.dataType == DataType::FLOAT32) {
+                   int index = (int) ((float *) positionIds.cpuData)[(b * 2) * positionIds.dims.back() + l];
+                   float *sin = ((float*)sinData.cpuData) + stride * index;
+                   float *cos = ((float*)cosData.cpuData) + stride * index;
+
                    float *d = (float *) data.cpuData + (l * bs + b) * spatial;
                    for (int i = 0; i < n; i++) {
                        int j = 0;
@@ -2966,6 +3070,10 @@ namespace fastllm {
                        d += m;
                    }
                } else if (data.dataType == DataType::FLOAT16) {
+                   int index = (int) half_to_float(((uint16_t *) positionIds.cpuData)[(b * 2) * positionIds.dims.back() + l]);
+                   float *sin = ((float*)sinData.cpuData) + stride * index;
+                   float *cos = ((float*)cosData.cpuData) + stride * index;
+
                    uint16_t *d = (uint16_t *) data.cpuData + (l * bs + b) * spatial;
                    for (int i = 0; i < n; i++) {
                        int j = 0;
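The fp16 branch's only new work before the folded rotation loop is decoding the position id with half_to_float and selecting the matching sin/cos rows; the same 2048-integer-exactness caveat from the embedding path applies to fp16-encoded position ids. The rotation itself, whose exact pairing layout is folded here, is the standard RoPE pair update; in generic form:

    #include <utility>

    // Standard RoPE update for one rotated coordinate pair at a given table
    // entry: a plain 2-D rotation. A sketch of the per-pair math only; the
    // folded loop's exact (i, j) pairing is not visible in this diff.
    inline std::pair<float, float> RotatePair(float x, float y,
                                              float sinv, float cosv) {
        return { x * cosv - y * sinv,
                 x * sinv + y * cosv };
    }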