
Commit

Support Relu
huangsheng-tf committed Jan 13, 2025
1 parent df71273 commit f05c92f
Showing 8 changed files with 99 additions and 4 deletions.
4 changes: 4 additions & 0 deletions include/devices/cpu/cpudevice.h
@@ -125,6 +125,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CpuReluOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CpuGeluOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
4 changes: 4 additions & 0 deletions include/devices/cuda/cudadevice.h
@@ -94,6 +94,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaReluOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaGeluOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
1 change: 1 addition & 0 deletions include/devices/cuda/fastllm-cuda.cuh
@@ -36,6 +36,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
const fastllm::Data &mask, const fastllm::Data &output, int group, float scale, int maskType);
bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaGelu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaRelu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaSwiglu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaAdd(const fastllm::Data &input, float v, fastllm::Data &output);
4 changes: 4 additions & 0 deletions include/fastllm.h
@@ -281,6 +281,8 @@ namespace fastllm {

Data (DataType type, const std::vector <int> &dims); // Constructor

Data (DataType type, const std::vector <int> &dims, DataDevice device, void *ptr); // Constructor: fake Data that wraps an existing data pointer

// Constructor: copies the values from data after construction
// data holds the raw values; if type is not float, they need to be quantized
Data (DataType type, const std::vector <int> &dims, const std::vector <float> &data);
@@ -569,6 +571,8 @@ namespace fastllm {

void TanH(const Data &input, Data &output);

void Relu(const Data &input, Data &output);

void Gelu(const Data &input, Data &output);

void GeluNew(const Data &input, Data &output);
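
A minimal usage sketch of the new public-API function (illustrative only, not part of the commit; the shape and values are made up and a float32 CPU tensor is assumed):

// Hypothetical example: element-wise ReLU through fastllm::Relu.
fastllm::Data input(fastllm::DataType::FLOAT32, {2, 3},
                    {-1.0f, 0.5f, -0.2f, 3.0f, 0.0f, -4.0f});
fastllm::Data output;
fastllm::Relu(input, output);  // output holds {0, 0.5, 0, 3, 0, 0}
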
22 changes: 21 additions & 1 deletion src/devices/cpu/cpudevice.cpp
@@ -41,6 +41,7 @@ namespace fastllm {
this->ops["SoftMax"] = (BaseOperator*)(new CpuSoftMaxOp());
this->ops["Silu"] = (BaseOperator*)(new CpuSiluOp());
this->ops["TanH"] = (BaseOperator*)(new CpuTanHOp());
this->ops["Relu"] = (BaseOperator*)(new CpuReluOp());
this->ops["Gelu"] = (BaseOperator*)(new CpuGeluOp());
this->ops["GeluNew"] = (BaseOperator*)(new CpuGeluNewOp());
this->ops["Swiglu"] = (BaseOperator*)(new CpuSwigluOp());
@@ -2170,6 +2171,8 @@ namespace fastllm {
Data &weight = *(datas.find("weight")->second);
Data &bias = *(datas.find("bias")->second);

AssertInFastLLM(bias.dataType == DataType::FLOAT32, "Linear's bias' type should be float32.\n");

output.Allocate(0.0f);
int n = input.Count(0) / input.dims.back();
int m = input.dims.back();
@@ -3419,6 +3422,23 @@ namespace fastllm {
return r;
}

void CpuReluOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
output.Allocate();
AssertInFastLLM(input.dataType == DataType::FLOAT32, "Relu error: Data's type should be float32.\n");

float *inputData = (float*)input.cpuData;
float *outputData = (float*)output.cpuData;
int len = input.Count(0);
int i = 0;
for (; i < len; i++) {
float x = inputData[i];
outputData[i] = x > 0 ? x : 0;
}
}

void CpuGeluOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
@@ -3715,7 +3735,7 @@ namespace fastllm {
Data &input1 = *(datas.find("input1")->second);
float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0;

- AssertInFastLLM(input0.dataType == DataType::FLOAT32 || input1.dataType == DataType::FLOAT16,
+ AssertInFastLLM(input0.dataType == DataType::FLOAT32 || input0.dataType == DataType::FLOAT16,
"AddTo error: Data's type should be float32 or float16.\n");
AssertInFastLLM(input0.dims == input1.dims, "AddTo error: input's shape should be same.\n");

14 changes: 13 additions & 1 deletion src/devices/cuda/cudadevice.cpp
@@ -28,6 +28,7 @@ namespace fastllm {
this->ops["MatMul"] = (BaseOperator*)(new CudaMatMulOp());
this->ops["MatMulTransB"] = (BaseOperator*)(new CudaMatMulTransBOp());
this->ops["SoftMax"] = (BaseOperator*)(new CudaSoftMaxOp());
this->ops["Relu"] = (BaseOperator*)(new CudaReluOp());
this->ops["Gelu"] = (BaseOperator*)(new CudaGeluOp());
this->ops["GeluNew"] = (BaseOperator*)(new CudaGeluNewOp());
this->ops["Silu"] = (BaseOperator*)(new CudaSiluOp());
@@ -322,7 +323,9 @@ namespace fastllm {
int n = input.Count(0) / input.dims.back();
int m = input.dims.back();
int k = output.dims.back();
- if (input.dataType == DataType::FLOAT16) {
+ if (bias.dataType != DataType::FLOAT32) {
+     ErrorInFastLLM("Linear error: unsupport bias' dataType.\n");
+ } else if (input.dataType == DataType::FLOAT16) {
if (weight.dataType == DataType::FLOAT16) {
FastllmCudaHalfMatMulFloat16(input, weight, bias, output, n, m, k);
} else if (weight.dataType == DataType::INT8) {
@@ -668,6 +671,15 @@ namespace fastllm {
FastllmCudaGelu(input, output);
}

void CudaReluOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
output.Allocate();
AssertInFastLLM(input.dataType == DataType::FLOAT32, "Relu error: Data's type should be float32\n");
FastllmCudaRelu(input, output);
}

void CudaSwigluOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
29 changes: 28 additions & 1 deletion src/devices/cuda/fastllm-cuda.cu
@@ -432,6 +432,14 @@ __global__ void FastllmCudaBiasKernel(half *a, half *bias, int k) {
}
}

__global__ void FastllmReluKernel(float* a, float *b, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
float x = a[idx];
b[idx] = x > 0 ? x : 0;
}
}

__global__ void FastllmGeluKernel(float* a, float *b, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
@@ -3226,6 +3234,22 @@ void FastllmCudaMemcpy2DDeviceToDeviceBatch(void ** dsts, size_t * dpitchs, voi
DeviceSync();
}

bool FastllmCudaRelu(const fastllm::Data &input, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
int threadPerBlock = std::min(256, len);
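// The launch below rounds the grid size up: e.g. for len = 1000, threadPerBlock = min(256, 1000) = 256
// and (1000 - 1) / 256 + 1 = 4 blocks, so 1024 threads cover 1000 elements; the idx < len guard in
// FastllmReluKernel skips the surplus threads of the last block.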
if (input.dataType == fastllm::DataType::FLOAT32) {
FastllmReluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
} else {
printf("Relu datatype error.\n");
exit(0);
}
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}

bool FastllmCudaGelu(const fastllm::Data &input, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
@@ -3502,7 +3526,10 @@ bool FastllmCudaLayerNorm(const fastllm::Data &input, fastllm::Data &gamma, fast
int inner = input.strides[axis];

if (inner == 1) {
- if (input.dataType == fastllm::DataType::FLOAT32) {
+ if (gamma.dataType != fastllm::DataType::FLOAT32 || beta.dataType != fastllm::DataType::FLOAT32) {
+     printf("layernorm datatype error.\n");
+     exit(0);
+ } else if (input.dataType == fastllm::DataType::FLOAT32) {
if (channels < 64) {
FastllmLayerNormKernelInner1<1> <<< outer, 1 >>>(cudaInput, (float *) gamma.cudaData,
(float *) beta.cudaData, cudaOutput,
25 changes: 24 additions & 1 deletion src/fastllm.cpp
@@ -275,6 +275,23 @@ namespace fastllm {
Resize(dims);
}

Data::Data (DataType type, const std::vector <int> &dims, DataDevice device, void *ptr): Data::Data(type, dims) {
this->isFake = true;
this->expansionSize = this->Count(0);
this->UpdateUnitSize();
this->dataDevice = device;
if (device == DataDevice::CPU) {
this->cpuData = (uint8_t*)ptr;
} else if (this->dataDevice == DataDevice::CUDA) {
#ifdef USE_CUDA
this->cudaData = ptr;
this->dataDeviceIds = {0}; // TODO: support multiple GPUs
#else
ErrorInFastLLM("Error: cuda is not supported.\n");
#endif
}
}
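
A rough usage sketch for this constructor (hypothetical buffer and shape; it assumes a float32 host buffer that outlives the Data object):

// Hypothetical example: wrap an already-allocated host buffer as a fastllm tensor without copying.
std::vector<float> buffer(6, 1.0f);
fastllm::Data view(fastllm::DataType::FLOAT32, {2, 3},
                   fastllm::DataDevice::CPU, buffer.data());
// view.cpuData now points at buffer.data(); isFake is set, so the buffer is expected to stay owned by the caller.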

Data::Data(fastllm::DataType type, const std::vector<int> &dims, const std::vector<float> &data) : Data::Data(type, dims) {
// std::cout<<"value constructor called"<<std::endl;
this->Allocate();
@@ -2590,7 +2607,7 @@ namespace fastllm {
{"input", (Data*)&input}
}, {}, {});
} else {
ErrorInFastLLM("ToDataDevice: Unsupport data type.\n");
ErrorInFastLLM("ToDataType: Unsupport data type.\n");
}
}

@@ -2726,6 +2743,12 @@ namespace fastllm {
}, {}, {});
}

void Relu(const fastllm::Data &input, fastllm::Data &output) {
curExecutor->Run("Relu", {
{"input", (Data*)&input}, {"output", &output}
}, {}, {});
}

void Gelu(const fastllm::Data &input, fastllm::Data &output) {
curExecutor->Run("Gelu", {
{"input", (Data*)&input}, {"output", &output}
