diff --git a/include/devices/cuda/fastllm-cuda.cuh b/include/devices/cuda/fastllm-cuda.cuh
index f94c47fc..7b820614 100644
--- a/include/devices/cuda/fastllm-cuda.cuh
+++ b/include/devices/cuda/fastllm-cuda.cuh
@@ -71,6 +71,11 @@ bool FastllmCudaBatchMatMulTransBBatch(void **i0s, void **i1s, void **os,
 bool FastllmCudaBatchMatMulBatch(void **i0s, void **i1s, void **os,
                                  int *ns, int *ms, int *ks,
                                  int *i0Strides, int *i1Strides, float alpha, int batch);
+
+bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v,
+                              const fastllm::Data &mask, const fastllm::Data &output, int group, float scale);
+bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
+
 void FastllmCudaSetDevice(int gpu_id);
 #ifdef __cplusplus
 }
diff --git a/include/fastllm.h b/include/fastllm.h
index e21255b3..c69b776b 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -236,6 +236,7 @@ namespace fastllm {
 
         void *cudaData = nullptr;
         std::vector <void*> extraCudaData;
+        std::vector <void*> extraCudaHalfData;
 
         void *deviceData = nullptr;
         std::vector <void*> extraDeviceData;
diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp
index 9eb59d96..0f8db2e1 100644
--- a/src/devices/cpu/cpudevice.cpp
+++ b/src/devices/cpu/cpudevice.cpp
@@ -1894,7 +1894,7 @@ namespace fastllm {
 
         AssertInFastLLM((input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32) ||
                         (input0.dataType == DataType::FLOAT16 && input1.dataType == DataType::FLOAT16),
-                        "Cat's input's type should be float32.\n");
+                        "Cat's input's type should be float32 or float16.\n");
         AssertInFastLLM(input0.dims.size() == input1.dims.size(), "Cat Error: input's shape's size should be same.");
 
         int dimsLen = input0.dims.size();
@@ -3070,7 +3070,7 @@ namespace fastllm {
                     d += m;
                 }
             } else if (data.dataType == DataType::FLOAT16) {
-                int index = (int) half_to_float(((uint16_t *) positionIds.cpuData)[(b * 2) * positionIds.dims.back() + l]);
+                int index = (int) ((float *) positionIds.cpuData)[(b * 2) * positionIds.dims.back() + l];
                 float *sin = ((float*)sinData.cpuData) + stride * index;
                 float *cos = ((float*)cosData.cpuData) + stride * index;
diff --git a/src/devices/cuda/cudadevice.cpp b/src/devices/cuda/cudadevice.cpp
index db2dafc6..d6ba08c6 100644
--- a/src/devices/cuda/cudadevice.cpp
+++ b/src/devices/cuda/cudadevice.cpp
@@ -83,7 +83,9 @@ namespace fastllm {
 
         AssertInFastLLM(q.dataType == k.dataType && q.dataType == v.dataType,
                         "Attention: q, k, v's datatype should be same.\n");
-        AssertInFastLLM(q.dataType == DataType::FLOAT32, "Attention's input's type should be float32.\n");
+        AssertInFastLLM(q.dataType == DataType::FLOAT32 ||
+                        q.dataType == DataType::FLOAT16,
+                        "Attention's input's type should be float32 or float16.\n");
 
         std::vector <int> dims = {q.dims[0], q.dims[1], v.dims[2]};
         output.dataType = q.dataType;
@@ -101,7 +103,12 @@ namespace fastllm {
        int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
        float scale = floatParams.find("scale") != floatParams.end() ? floatParams.find("scale")->second : 1.0;
        output.Allocate();
-        FastllmCudaAttention(q, k, v, mask, output, group, scale);
+
+        if (q.dataType == DataType::FLOAT32) {
+            FastllmCudaAttention(q, k, v, mask, output, group, scale);
+        } else if (q.dataType == DataType::FLOAT16) {
+            FastllmCudaHalfAttention(q, k, v, mask, output, group, scale);
+        }
     }
 
     void CudaCopyKVCacheOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
floatParams.find("scale")->second : 1.0; output.Allocate(); - FastllmCudaAttention(q, k, v, mask, output, group, scale); + + if (q.dataType == DataType::FLOAT32) { + FastllmCudaAttention(q, k, v, mask, output, group, scale); + } else if (q.dataType == DataType::FLOAT16) { + FastllmCudaHalfAttention(q, k, v, mask, output, group, scale); + } } void CudaCopyKVCacheOp::Reshape(const std::string &opType, const fastllm::DataDict &datas, @@ -139,6 +146,11 @@ namespace fastllm { Data &input = *(datas.find("input")->second); Data &weight = *(datas.find("weight")->second); Data &output = *(datas.find("output")->second); + + AssertInFastLLM(input.dataType == DataType::FLOAT32 || + input.dataType == DataType::FLOAT16, + "RMSNorm error: datatype should be float32 or float16."); + output.Allocate(); float eps = floatParams.find("eps") != floatParams.end() ? floatParams.find("eps")->second : 1e-5; @@ -182,7 +194,7 @@ namespace fastllm { std::vector dims = input.dims; dims.back() = weight.dims[0]; - output.dataType = DataType::FLOAT32; + output.dataType = input.dataType; output.Resize(dims); } @@ -203,20 +215,30 @@ namespace fastllm { int m = input.dims.back(); int k = output.dims.back(); - if (weight.dataType == DataType::FLOAT32) { - FastllmCudaMatMulFloat32(input, weight, bias, output, n, m, k); - } else if (weight.dataType == DataType::FLOAT16) { - FastllmCudaMatMulFloat16(input, weight, bias, output, n, m, k); - } else if (weight.dataType == DataType::INT8) { - FastllmCudaMatMulFloatInt8(input, weight, bias, output, n, m, k); - } else if (weight.dataType == DataType::INT4) { - FastllmCudaMatMulFloatInt4(input, weight, bias, output, n, m, k); - } else if (weight.dataType == DataType::INT4_NOZERO) { - FastllmCudaMatMulFloatInt4NoZero(input, weight, bias, output, n, m, k); - } else if (weight.dataType == DataType::INT4_GROUP) { - FastllmCudaMatMulFloatInt4Group(input, weight, bias, output, n, m, k); + if (input.dataType == DataType::FLOAT16) { + if (weight.dataType == DataType::FLOAT16) { + FastllmCudaHalfMatMulFloat16(input, weight, bias, output, n, m, k); + } else { + ErrorInFastLLM("Linear error: unsupport weight's dataType.\n"); + } + } else if (input.dataType == DataType::FLOAT32) { + if (weight.dataType == DataType::FLOAT32) { + FastllmCudaMatMulFloat32(input, weight, bias, output, n, m, k); + } else if (weight.dataType == DataType::FLOAT16) { + FastllmCudaMatMulFloat16(input, weight, bias, output, n, m, k); + } else if (weight.dataType == DataType::INT8) { + FastllmCudaMatMulFloatInt8(input, weight, bias, output, n, m, k); + } else if (weight.dataType == DataType::INT4) { + FastllmCudaMatMulFloatInt4(input, weight, bias, output, n, m, k); + } else if (weight.dataType == DataType::INT4_NOZERO) { + FastllmCudaMatMulFloatInt4NoZero(input, weight, bias, output, n, m, k); + } else if (weight.dataType == DataType::INT4_GROUP) { + FastllmCudaMatMulFloatInt4Group(input, weight, bias, output, n, m, k); + } else { + ErrorInFastLLM("Linear error: unsupport weight's dataType.\n"); + } } else { - ErrorInFastLLM("Linear error: unsupport weight's dataType.\n"); + ErrorInFastLLM("Linear error: unsupport input's dataType.\n"); } } @@ -275,8 +297,9 @@ namespace fastllm { int axis = intParams.find("axis") != intParams.end() ? 
intParams.find("axis")->second : -1; - AssertInFastLLM(input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32, - "Cat's input's type should be float32.\n"); + AssertInFastLLM((input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32) || + (input0.dataType == DataType::FLOAT16 && input1.dataType == DataType::FLOAT16), + "Cat's input's type should be float32 or float16.\n"); AssertInFastLLM(input0.dataDevice == input1.dataDevice, "CatDirect error: inputs should use same device.\n"); if (input0.dims.size() == 0) { @@ -475,7 +498,9 @@ namespace fastllm { Data &input = *(datas.find("input")->second); Data &output = *(datas.find("output")->second); output.Allocate(); - AssertInFastLLM(input.dataType == DataType::FLOAT32, "Swiglu error: Data's type should be float32.\n"); + AssertInFastLLM(input.dataType == DataType::FLOAT32 || + input.dataType == DataType::FLOAT16, + "Swiglu error: Data's type should be float32.\n"); FastllmCudaSwiglu(input, output); } @@ -495,7 +520,9 @@ namespace fastllm { output.Allocate(); float v = floatParams.find("v") != floatParams.end() ? floatParams.find("v")->second : 1.0; - AssertInFastLLM(input.dataType == DataType::FLOAT32, "Mul error: Data's type should be float32.\n"); + AssertInFastLLM(input.dataType == DataType::FLOAT32 || + input.dataType == DataType::FLOAT16, + "Mul error: Data's type should be float32 or float16.\n"); FastllmCudaMul(input, v, output); } @@ -505,8 +532,9 @@ namespace fastllm { Data &input1 = *(datas.find("input1")->second); float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0; - AssertInFastLLM(input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32, - "AddTo error: Data's type should be float32.\n"); + AssertInFastLLM((input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32) || + (input0.dataType == DataType::FLOAT16 && input1.dataType == DataType::FLOAT16), + "AddTo error: Data's type should be float32 or float16.\n"); AssertInFastLLM(input0.dims == input1.dims, "AddTo error: input's shape should be same.\n"); FastllmCudaAddTo(input0, input1, alpha); } @@ -583,7 +611,9 @@ namespace fastllm { axis.push_back(((int32_t *) axisData.cpuData)[i]); } - AssertInFastLLM(input.dataType == DataType::FLOAT32, "Permute error: datatype should be float32."); + AssertInFastLLM(input.dataType == DataType::FLOAT32 || + input.dataType == DataType::FLOAT16, + "Permute error: datatype should be float32 or float16."); AssertInFastLLM(axis.size() == input.dims.size(), "Permute error: axis's size should be equal to data's shape's size."); bool same = false; diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu index 314be891..1a9bd44c 100644 --- a/src/devices/cuda/fastllm-cuda.cu +++ b/src/devices/cuda/fastllm-cuda.cu @@ -104,6 +104,14 @@ __global__ void FastllmCudaBiasKernel(float *a, float *bias, int k) { } } +__global__ void FastllmCudaBiasKernel(half *a, half *bias, int k) { + half *now = a + blockIdx.x * k; + int stride = blockDim.x; + for (int i = threadIdx.x; i < k; i += stride) { + now[i] = __hadd(now[i], bias[i]); + } +} + __global__ void FastllmGeluKernel(float* a, float *b, int len) { int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < len) { @@ -129,6 +137,15 @@ __global__ void FastllmSwigluKernel(float* a, float *b, int len, int spatial, in } } +__global__ void FastllmSwigluKernel(half* a, half *b, int len, int spatial, int mid) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if 
@@ -136,6 +153,13 @@ __global__ void FastllmMulKernel(float* a, float *b, float v, int len) {
     }
 }
 
+__global__ void FastllmMulKernel(half* a, half *b, half v, int len) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx < len) {
+        b[idx] = __hmul(a[idx], v);
+    }
+}
+
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmMulBatchKernel(float** pointer, int batch, float v) {
     float *input = pointer[blockIdx.x];
@@ -153,6 +177,13 @@ __global__ void FastllmAddToKernel(float* a, float *b, float alpha, int len) {
     }
 }
 
+__global__ void FastllmAddToKernel(half* a, half *b, half alpha, int len) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx < len) {
+        a[idx] = __hadd(a[idx], __hmul(b[idx], alpha));
+    }
+}
+
 __global__ void FastllmMulToKernel(float* a, float *b, float alpha, int len) {
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
@@ -173,6 +204,19 @@ __global__ void FastllmAttentionMaskKernel(float* a, float *b, float maskValue,
     }
 }
 
+template <int THREAD_PER_BLOCK>
+__global__ void FastllmAttentionMaskKernel(half *a, half *b, half maskValue, int n, int m, int spatial) {
+    int on = blockIdx.x / m;
+    int om = blockIdx.x % m;
+    int o = on * m + om;
+    int idx = threadIdx.x;
+    for (int i = idx; i < spatial; i += THREAD_PER_BLOCK) {
+        if (__half2float(b[on * spatial + i]) > 0.99) {
+            a[o * spatial + i] = maskValue;
+        }
+    }
+}
+
 template <int THREAD_PER_BLOCK>
 __global__ void SimpleMask(float* a, float *b, float maskValue, int spatial) {
     int i = threadIdx.x + blockIdx.x * blockDim.x;
@@ -183,6 +227,16 @@ __global__ void SimpleMask(float* a, float *b, float maskValue, int spatial) {
     }
 }
 
+template <int THREAD_PER_BLOCK>
+__global__ void SimpleMask(half* a, half *b, half maskValue, int spatial) {
+    int i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i < spatial) {
+        if (__half2float(b[i]) > 0.99) {
+            a[i] = maskValue;
+        }
+    }
+}
+
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmAlibiMaskKernel(float* a, float *b, float maskValue, int n, int m, int spn, int spm, int spatial) {
     int on = blockIdx.x / m;
@@ -225,7 +279,8 @@ __global__ void FastllmTransposeByRowKernel(uint8_t *dst, uint8_t *ori, int n, i
     }
 }
 
-__global__ void FastllmPermuteKernel(float *dst, float *ori, int *temp, int axisLen, int len) {
+template <typename T>
+__global__ void FastllmPermuteKernel(T *dst, T *ori, int *temp, int axisLen, int len) {
     int i = threadIdx.x + blockIdx.x * blockDim.x;
     if (i < len) {
         int old = 0;
@@ -258,29 +313,6 @@ __global__ void FastllmLlamaRotatePosition2DKernel(float *data, float *positionI
 
 __global__ void FastllmNearlyRotatePosition2DKernel(float *data, float *positionIds, float *sin, float *cos,
                                                     int len, int bs, int spatial, int n, int m, int partStride, int sinCosStride, int rotateDim) {
-/*
-    int len = data.dims[0], bs = data.dims[1];
-    int spatial = data.Count(2);
-    int n = data.dims[2], m = data.dims[3];
-    int stride = (int)sinData.dims[1];
-    for (int l = 0; l < len; l++) {
-        for (int b = 0; b < bs; b++) {
-            int index = (int) ((float *) positionIds.cpuData)[(b * 2) * positionIds.dims.back() + l];
-            float *sin = ((float*)sinData.cpuData) + stride * index;
-            float *cos = ((float*)cosData.cpuData) + stride * index;
-            float *d = (float *) data.cpuData + (l * bs + b) * spatial;
-            for (int i = 0; i < n; i++) {
-                int j = 0;
-                for (; j < rotaryDim; j += 2) {
-                    float a = d[j], b = d[j + 1];
-                    d[j] = a * cos[j / 2] - b * sin[j / 2];
-                    d[j + 1] = a * sin[j / 2] + b * cos[j / 2];
-                }
-                d += m;
-            }
-        }
-    }
-*/
     int o = (blockIdx.x / n);
     int l = o / bs;
     int b = o % bs;
@@ -296,6 +328,23 @@ __global__ void FastllmNearlyRotatePosition2DKernel(float *data, float *position
         d[i * m + 1] = va * curSin + vb * curCos;
 }
 
+__global__ void FastllmNearlyRotatePosition2DKernel(half *data, float *positionIds, float *sin, float *cos,
+                                                    int len, int bs, int spatial, int n, int m, int partStride, int sinCosStride, int rotateDim) {
+    int o = (blockIdx.x / n);
+    int l = o / bs;
+    int b = o % bs;
+    int j = threadIdx.x;
+    int index = (int) (positionIds[b * 2 * partStride + l]);
+
+    float curSin = sin[index * sinCosStride + j];
+    float curCos = cos[index * sinCosStride + j];
+    half *d = (half *) data + o * spatial + j * 2;
+    int i = blockIdx.x % n;
+    float va = __half2float(d[i * m]), vb = __half2float(d[i * m + 1]);
+    d[i * m] = __float2half(va * curCos - vb * curSin);
+    d[i * m + 1] = __float2half(va * curSin + vb * curCos);
+}
+
 __global__ void FastllmRotatePosition2DKernel(float *data, float *positionIds, float *sin, float *cos,
                                               int len, int bs, int spatial, int n, int m, int partStride, int sinCosStride, int rotateDim) {
     int o = (blockIdx.x / n) / 2;
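The half variant above applies the same rotary rotation as the float kernel, converting to float for the sin/cos multiply and back to half for the store. The underlying math is the pairwise rotation visible in the commented-out CPU loop that the previous hunk deletes; a host-side sketch of that rotation (reference only):

    #include <vector>

    // Rotate one head vector d of length rotaryDim (even): each adjacent pair
    // (a, b) becomes (a*cos - b*sin, a*sin + b*cos), with sin/cos taken from
    // the table row selected by the token's position id.
    void RotatePairsReference(std::vector<float> &d, const std::vector<float> &sinRow,
                              const std::vector<float> &cosRow, int rotaryDim) {
        for (int j = 0; j + 1 < rotaryDim; j += 2) {
            float a = d[j], b = d[j + 1];
            d[j]     = a * cosRow[j / 2] - b * sinRow[j / 2];
            d[j + 1] = a * sinRow[j / 2] + b * cosRow[j / 2];
        }
    }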
@@ -369,12 +418,73 @@ __device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int
     }
 }
 
+template <int THREAD_PER_BLOCK>
+__device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int channels) {
+    __shared__ half sdata[THREAD_PER_BLOCK];
+    __shared__ half maxV;
+
+    // 1. Each thread scans its slice of the row
+    unsigned int tid = threadIdx.x;
+    half maxValue = input[tid];
+    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
+        maxValue = __hmax(maxValue, input[i]);
+    }
+    sdata[tid] = maxValue;
+    __syncthreads();
+
+    // 2. Reduce to get the max
+    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sdata[tid] = __hmax(sdata[tid], sdata[tid + s]);
+        }
+        __syncthreads();
+    }
+
+    // 3. Record the max
+    if (tid == 0) {
+        maxV = sdata[0];
+    }
+    __syncthreads();
+
+    // 4. Exponentiate and sum
+    half sum = 0;
+    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
+        output[i] = hexp(__hsub(input[i], maxV));
+        sum = __hadd(sum, output[i]);
+    }
+    sdata[tid] = sum;
+    __syncthreads();
+
+    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sdata[tid] = __hadd(sdata[tid], sdata[tid + s]);
+        }
+        __syncthreads();
+    }
+    if (tid == 0) {
+        if (fabs(__half2float(sdata[0])) < 1e-6) {
+            sdata[0] = __float2half(0.1);
+        }
+    }
+    __syncthreads();
+
+    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
+        output[i] = __hdiv(output[i], sdata[0]);
+    }
+}
+
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmSoftmaxKernelInner1(float* input, float *output, int outer, int channels) {
     int o = blockIdx.x;
     FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels);
 }
 
+template <int THREAD_PER_BLOCK>
+__global__ void FastllmSoftmaxKernelInner1(half* input, half *output, int outer, int channels) {
+    int o = blockIdx.x;
+    FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels);
+}
+
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer) {
     int o = blockIdx.x;
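The FP16 softmax mirrors the float path: subtract the per-row max before exponentiating (range safety matters even more in half precision), normalize by the sum, and fall back to a small constant when the sum underflows. A host-side float sketch of the same steps (reference only):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Numerically stable softmax over one row of `channels` values.
    void SoftmaxReference(const std::vector<float> &in, std::vector<float> &out, int channels) {
        float maxValue = in[0];
        for (int i = 1; i < channels; i++) {
            maxValue = std::max(maxValue, in[i]);
        }
        float sum = 0.0f;
        for (int i = 0; i < channels; i++) {
            out[i] = std::exp(in[i] - maxValue);   // shift by the max so exp() cannot overflow
            sum += out[i];
        }
        if (std::fabs(sum) < 1e-6f) {
            sum = 0.1f;                            // same underflow guard as the kernel above
        }
        for (int i = 0; i < channels; i++) {
            out[i] /= sum;
        }
    }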
@@ -420,6 +530,44 @@ __global__ void FastllmRMSNormKernelInner1(float *input, float *weight, float *o
     }
 }
 
+template <int THREAD_PER_BLOCK>
+__global__ void FastllmRMSNormKernelInner1(half *input, float *weight, half *output, int outer, int channels, float eps) {
+    int o = blockIdx.x;
+    input = input + o * channels;
+    output = output + o * channels;
+
+    __shared__ float sdata2[THREAD_PER_BLOCK];
+    __shared__ float scale;
+
+    // 1. Each thread accumulates a partial sum of squares
+    unsigned int tid = threadIdx.x;
+    float sum2 = 0.0;
+    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
+        float x = __half2float(input[i]);
+        sum2 += x * x;
+    }
+    sdata2[tid] = sum2;
+    __syncthreads();
+
+    // 2. Reduce the sum
+    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sdata2[tid] += sdata2[tid + s];
+        }
+        __syncthreads();
+    }
+
+    // 3. Compute the scale
+    if (tid == 0) {
+        scale = 1.0 / sqrt(sdata2[0] / channels + eps);
+    }
+    __syncthreads();
+
+    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
+        output[i] = __float2half(__half2float(input[i]) * scale * weight[i]);
+    }
+}
+
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmLayerNormKernelInner1(float *input, float *gamma, float *beta, float *output, int outer, int channels) {
     int o = blockIdx.x;
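Note that the FP16 RMSNorm kernel above keeps the sum of squares in float, which avoids most of the precision loss a pure-half reduction would incur. The formula it implements, sketched on the host for a single row (reference only):

    #include <cmath>
    #include <vector>

    // RMSNorm over one row: out[i] = x[i] * weight[i] / sqrt(mean(x^2) + eps).
    void RMSNormReference(const std::vector<float> &x, const std::vector<float> &weight,
                          std::vector<float> &out, float eps) {
        float sum2 = 0.0f;
        for (size_t i = 0; i < x.size(); i++) {
            sum2 += x[i] * x[i];
        }
        float scale = 1.0f / std::sqrt(sum2 / x.size() + eps);
        for (size_t i = 0; i < x.size(); i++) {
            out[i] = x[i] * scale * weight[i];
        }
    }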
@@ -1884,7 +2032,11 @@ bool FastllmCudaSwiglu(const fastllm::Data &input, fastllm::Data &output) {
     int spatial = input.Count(input.dims.size() - 1), mid = spatial / 2;
 
     int threadPerBlock = std::min(256, len);
-    FastllmSwigluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len, spatial, mid);
+    if (input.dataType == fastllm::DataType::FLOAT32) {
+        FastllmSwigluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len, spatial, mid);
+    } else if (input.dataType == fastllm::DataType::FLOAT16) {
+        FastllmSwigluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)cudaInput, (half*)cudaOutput, len, spatial, mid);
+    }
 
     FastllmCudaFinishInput(input, cudaInput);
     FastllmCudaFinishOutput(output, cudaOutput);
@@ -1896,7 +2048,13 @@ bool FastllmCudaMul(const fastllm::Data &input, float v, fastllm::Data &output)
     float *cudaInput = (float *) FastllmCudaPrepareInput(input);
     float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
     int threadPerBlock = std::min(256, len);
-    FastllmMulKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, v, len);
+
+    if (input.dataType == fastllm::DataType::FLOAT32) {
+        FastllmMulKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, v, len);
+    } else {
+        FastllmMulKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)cudaInput, (half*)cudaOutput, __float2half_rn(v), len);
+    }
+
     FastllmCudaFinishInput(input, cudaInput);
     FastllmCudaFinishOutput(output, cudaOutput);
     return true;
@@ -1908,7 +2066,12 @@ bool FastllmCudaAddTo(fastllm::Data &input0, const fastllm::Data &input1, float
     float *input1Data = (float *) FastllmCudaPrepareInput(input1);
 
     int threadPerBlock = std::min(256, len);
-    FastllmAddToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
+    if (input0.dataType == fastllm::DataType::FLOAT32) {
+        FastllmAddToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
+    } else if (input0.dataType == fastllm::DataType::FLOAT16) {
+        FastllmAddToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)cudaData, (half*)input1Data, __float2half_rn(alpha), len);
+    }
+
     FastllmCudaFinishInput(input1, input1Data);
     FastllmCudaFinishOutput(input0, cudaData);
     return true;
@@ -2039,12 +2202,28 @@ bool FastllmCudaRMSNorm(const fastllm::Data &input, fastllm::Data &weight, fastl
     int outer = input.Count(0) / input.Count(axis);
     int channels = input.dims[axis];
 
-    if (channels < 64) {
-        FastllmRMSNormKernelInner1<1> <<< outer, 1 >>>(cudaInput, (float *) weight.cudaData, cudaOutput, outer, channels, eps);
-    } else if (channels < 512) {
-        FastllmRMSNormKernelInner1<64> <<< outer, 64 >>>(cudaInput, (float *) weight.cudaData, cudaOutput, outer, channels, eps);
-    } else {
-        FastllmRMSNormKernelInner1<512> <<< outer, 512 >>>(cudaInput, (float *) weight.cudaData, cudaOutput, outer, channels, eps);
+    if (input.dataType == fastllm::DataType::FLOAT32) {
+        if (channels < 64) {
+            FastllmRMSNormKernelInner1<1> <<< outer, 1 >>>(cudaInput, (float *) weight.cudaData, cudaOutput, outer,
+                                                           channels, eps);
+        } else if (channels < 512) {
+            FastllmRMSNormKernelInner1<64> <<< outer, 64 >>>(cudaInput, (float *) weight.cudaData, cudaOutput, outer,
+                                                             channels, eps);
+        } else {
+            FastllmRMSNormKernelInner1<512> <<< outer, 512 >>>(cudaInput, (float *) weight.cudaData, cudaOutput, outer,
+                                                               channels, eps);
+        }
+    } else if (input.dataType == fastllm::DataType::FLOAT16) {
+        if (channels < 64) {
+            FastllmRMSNormKernelInner1<1> <<< outer, 1 >>>((half*)cudaInput, (float*) weight.cudaData, (half*)cudaOutput, outer,
+                                                           channels, eps);
+        } else if (channels < 512) {
+            FastllmRMSNormKernelInner1<64> <<< outer, 64 >>>((half*)cudaInput, (float*) weight.cudaData, (half*)cudaOutput, outer,
+                                                             channels, eps);
+        } else {
+            FastllmRMSNormKernelInner1<512> <<< outer, 512 >>>((half*)cudaInput, (float*) weight.cudaData, (half*)cudaOutput, outer,
+                                                               channels, eps);
+        }
     }
 
     FastllmCudaFinishInput(input, cudaInput);
@@ -2111,8 +2290,8 @@ bool FastllmCudaPermute(fastllm::Data &input, const std::vector <int> &axis) {
         exit(0);
     }
     int len = input.Count(0);
-    float *tempData = (float *)FastllmCudaMalloc(len * sizeof(float));
-    cudaMemcpy(tempData, input.cudaData, len * sizeof(float), cudaMemcpyDeviceToDevice);
+    uint8_t *tempData = (uint8_t *)FastllmCudaMalloc(len * input.unitSize);
+    cudaMemcpy(tempData, input.cudaData, len * input.unitSize, cudaMemcpyDeviceToDevice);
 
     std::vector<int> new_dims;
     for (int i = 0; i < axis.size(); i++) {
@@ -2149,9 +2328,17 @@ bool FastllmCudaPermute(fastllm::Data &input, const std::vector <int> &axis) {
     int *cudaTemp = (int *) FastllmCudaMalloc(temp.size() * sizeof(int));
     cudaMemcpy(cudaTemp, temp.data(), temp.size() * sizeof(int), cudaMemcpyHostToDevice);
     int threadPerBlock = std::min(256, len);
-    FastllmPermuteKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>((float *) input.cudaData,
-                                                                                tempData, cudaTemp,
-                                                                                (int) axis.size(), len);
+    if (input.unitSize == 4) {
+        FastllmPermuteKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>(
+            (float *) input.cudaData, (float *) tempData, cudaTemp, (int) axis.size(), len);
+    } else if (input.unitSize == 2) {
+        FastllmPermuteKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>(
+            (uint16_t *) input.cudaData, (uint16_t *) tempData, cudaTemp, (int) axis.size(), len);
+    } else if (input.unitSize == 1) {
+        FastllmPermuteKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>(
+            (uint8_t *) input.cudaData, (uint8_t *) tempData, cudaTemp, (int) axis.size(), len);
+    }
+
     FastllmCudaFree(cudaTemp);
 }
@@ -2296,6 +2483,130 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
     return true;
 }
 
+bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v,
+                              const fastllm::Data &mask, const fastllm::Data &output, int group, float scale) {
+    int q0 = q.dims[0], q1 = q.dims[1], q2 = q.dims[2], k0 = k.dims[0], k1 = k.dims[1], v2 = v.dims[2];
+    half *qd = (half*)q.cudaData;
+    half *kd = (half*)k.cudaData;
+    half *vd = (half*)v.cudaData;
+    half *maskd = mask.dims.size() > 0 ? (half*)mask.cudaData : nullptr;
+    half *od = (half*)output.cudaData;
+    int batch = (mask.dims.size() == 3) ? mask.dims[0] : 1;
+    int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));
+
+    half beta = __float2half_rn(0.0f), one = __float2half_rn(1.0f), hscale = __float2half_rn(scale);
+
+    if (q1 > 1024) {
+        half *qk = (half *) FastllmCudaMalloc(q1 * k1 * sizeof(half));
+        auto fastllmCublasHandle = getFastllmCublasHandle();
+        cublasStatus_t status;
+        for (int i = 0; i < q0; i++) {
+            status = cublasHgemmStridedBatched(fastllmCublasHandle,
+                                               CUBLAS_OP_T, CUBLAS_OP_N,
+                                               k1, q1, q2, &hscale,
+                                               kd + (i / group) * k.Count(1), k.strides[1], k.Count(1),
+                                               qd + i * q.Count(1), q.strides[1], q.Count(1),
+                                               &beta,
+                                               qk, k1, k1 * q1, 1);
+            if (status != CUBLAS_STATUS_SUCCESS) {
+                printf("status = %d\n", (int) status);
+                printf("Error: cublas error during MatMulTransB in Attention operator.\n");
+                throw ("cublas error");
+                exit(0);
+            }
+
+            if (maskd) {
+                SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, __float2half_rn(-10000), q1 * k1);
+            }
+
+            int outer = q1;
+            if (k1 < 8) {
+                FastllmSoftmaxKernelInner1<1> <<< outer, 1 >>>(qk, qk, outer, k1);
+            } else if (k1 < 64) {
+                FastllmSoftmaxKernelInner1<8> <<< outer, 8 >>>(qk, qk, outer, k1);
+            } else if (k1 < 512) {
+                FastllmSoftmaxKernelInner1<64> <<< outer, 64 >>>(qk, qk, outer, k1);
+            } else {
+                FastllmSoftmaxKernelInner1<256> <<< outer, 256 >>>(qk, qk, outer, k1);
+            }
+
+            status = cublasHgemmStridedBatched(fastllmCublasHandle,
+                                               CUBLAS_OP_N, CUBLAS_OP_N,
+                                               v2, q1, k1, &one,
+                                               vd + (i / group) * v.Count(1), v.strides[1], v.Count(1),
+                                               qk, k1, k1 * q1,
+                                               &beta,
+                                               od + i * v2 * q1, v2, v2 * q1, 1);
+            if (status != CUBLAS_STATUS_SUCCESS) {
+                printf("status = %d\n", (int) status);
+                printf("Error: cublas error during MatMul in Attention operator.\n");
+                throw ("cublas error");
+                exit(0);
+            }
+        }
+
+        FastllmCudaFree(qk);
+        DeviceSync();
+        return true;
+    }
+
+    if (true) {
+        half *qk = (half *) FastllmCudaMalloc(q0 * q1 * k1 * sizeof(half));
+        half *temp = (half *) FastllmCudaMalloc(q0 * q1 * k1 * sizeof(half));
+        auto fastllmCublasHandle = getFastllmCublasHandle();
+        cublasStatus_t status;
+
+        status = cublasHgemmStridedBatched(fastllmCublasHandle,
+                                           CUBLAS_OP_T, CUBLAS_OP_N,
+                                           k1, q1 * group, q2, &hscale,
+                                           kd, k.strides[1], k.Count(1),
+                                           qd, q.strides[1], q.Count(1) * group,
+                                           &beta,
+                                           qk, k1, k1 * q1 * group, q0 / group);
+        if (status != CUBLAS_STATUS_SUCCESS) {
+            printf("status = %d\n", (int) status);
+            printf("Error: cublas error during MatMulTransB in Attention operator.\n");
+            throw ("cublas error");
+            exit(0);
+        }
+
+        if (maskd) {
+            int spatial = q1 * k1, n = batch, m = q0 / batch;
+            FastllmAttentionMaskKernel <256> <<< n * m, 256>>>(qk, maskd, __float2half_rn(-10000), n, m, spatial);
+        }
+
+        int outer = q0 * q1;
+        if (k1 < 8) {
+            FastllmSoftmaxKernelInner1<1> <<< outer, 1 >>>(qk, temp, outer, k1);
+        } else if (k1 < 64) {
+            FastllmSoftmaxKernelInner1<8> <<< outer, 8 >>>(qk, temp, outer, k1);
+        } else if (k1 < 512) {
+            FastllmSoftmaxKernelInner1<64> <<< outer, 64 >>>(qk, temp, outer, k1);
+        } else {
+            FastllmSoftmaxKernelInner1<256> <<< outer, 256 >>>(qk, temp, outer, k1);
+        }
+
+        status = cublasHgemmStridedBatched(fastllmCublasHandle,
+                                           CUBLAS_OP_N, CUBLAS_OP_N,
+                                           v2, q1 * group, k1, &one,
+                                           vd, v.strides[1], v.Count(1),
+                                           temp, k1, k1 * q1 * group,
+                                           &beta,
+                                           od, v2, v2 * q1 * group, q0 / group);
+        if (status != CUBLAS_STATUS_SUCCESS) {
+            printf("status = %d\n", (int) status);
+            printf("Error: cublas error during MatMul in Attention operator.\n");
+            throw ("cublas error");
+            exit(0);
+        }
+        FastllmCudaFree(qk);
+        FastllmCudaFree(temp);
+        DeviceSync();
+        return true;
+    }
+    return true;
+}
+
 bool FastllmCudaBatchMatMul(const fastllm::Data &input0, const fastllm::Data &input1, fastllm::Data &output,
                             int input0Spatial, int input1Spatial, int outputSpatial,
                             int input0Stride, int input1Stride,
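Both branches of FastllmCudaHalfAttention compute the usual composition out = softmax(Q*K^T * scale + mask) * V with cublasHgemmStridedBatched, where query head i reads key/value head i / group (grouped-query attention); the q1 > 1024 branch simply loops over heads so the temporary QK buffer stays at q1 * k1 instead of q0 * q1 * k1. A naive single-head float reference of the same composition, without a mask (illustrative only):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // out[q1 x v2] = softmax(scale * Q[q1 x q2] * K[k1 x q2]^T) * V[k1 x v2]
    void AttentionReference(const std::vector<float> &Q, const std::vector<float> &K,
                            const std::vector<float> &V, std::vector<float> &out,
                            int q1, int q2, int k1, int v2, float scale) {
        std::vector<float> qk(k1);
        for (int i = 0; i < q1; i++) {
            float maxValue = -1e30f, sum = 0.0f;
            for (int j = 0; j < k1; j++) {
                float dot = 0.0f;
                for (int t = 0; t < q2; t++) {
                    dot += Q[i * q2 + t] * K[j * q2 + t];
                }
                qk[j] = dot * scale;
                maxValue = std::max(maxValue, qk[j]);
            }
            for (int j = 0; j < k1; j++) {
                qk[j] = std::exp(qk[j] - maxValue);
                sum += qk[j];
            }
            for (int j = 0; j < v2; j++) {
                float acc = 0.0f;
                for (int t = 0; t < k1; t++) {
                    acc += qk[t] / sum * V[t * v2 + j];
                }
                out[i * v2 + j] = acc;
            }
        }
    }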
@@ -2394,9 +2705,16 @@ bool FastllmCudaNearlyRotatePosition2D(fastllm::Data &data, const fastllm::Data
     int spatial = data.Count(2);
     int len = data.dims[0], bs = data.dims[1];
     int n = data.dims[2], m = data.dims[3];
-    FastllmNearlyRotatePosition2DKernel <<< outer * n, std::min(rotaryDim, m / 4) >>> (cudaData, cudaPositionIds, cudaSin, cudaCos,
-                                                                                       len, bs, spatial, n, m,
-                                                                                       (int)positionIds.dims.back(), (int)sinData.dims[1], rotaryDim);
+
+    if (data.dataType == fastllm::DataType::FLOAT32) {
+        FastllmNearlyRotatePosition2DKernel <<< outer * n, std::min(rotaryDim, m / 4) >>> (cudaData, cudaPositionIds, cudaSin, cudaCos,
+                                                                                           len, bs, spatial, n, m,
+                                                                                           (int)positionIds.dims.back(), (int)sinData.dims[1], rotaryDim);
+    } else if (data.dataType == fastllm::DataType::FLOAT16) {
+        FastllmNearlyRotatePosition2DKernel <<< outer * n, std::min(rotaryDim, m / 4) >>> ((half*)cudaData, cudaPositionIds, cudaSin, cudaCos,
+                                                                                           len, bs, spatial, n, m,
+                                                                                           (int)positionIds.dims.back(), (int)sinData.dims[1], rotaryDim);
+    }
 
     FastllmCudaFinishInput(positionIds, cudaPositionIds);
     FastllmCudaFinishInput(sinData, cudaSin);
@@ -2700,6 +3018,58 @@ bool FastllmCudaBatchMatMulBatch(void **i0s, void **i1s, void **os,
     return true;
 }
 
+bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
+    if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) {
+        half *cudaBiasData;
+        cudaError_t state = cudaSuccess;
+        state = cudaMalloc(&cudaBiasData, k * sizeof(half));
+        if (bias.dims.size() > 0) {
+            float *tempBiasData;
+            state = cudaMalloc(&tempBiasData, k * sizeof(float));
+            state = cudaMemcpy(tempBiasData, (uint8_t *) bias.cudaData, k * sizeof(float), cudaMemcpyDeviceToDevice);
+            int threadPerBlock = std::min(256, k);
+            FastllmCudaFloat2HalfKernel <<< (k - 1) / threadPerBlock + 1, threadPerBlock>>>(tempBiasData, cudaBiasData, k);
+            state = cudaFree(tempBiasData);
+        } else {
+            state = cudaMemset(cudaBiasData, __float2half_rn(0.0), k * sizeof(half));
+        }
+        checkCudaErrors("Error: CUDA error when moving bias to device!", state);
+        weight.extraCudaHalfData.push_back((void *) cudaBiasData);
+    }
+
+    half *cudaBiasData = (half *) weight.extraCudaHalfData[0];
+    half *cudaInput = (half *) FastllmCudaPrepareInput(input);
+    half *cudaOutput = (half *) FastllmCudaPrepareOutput(output);
+
+    __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
+    auto fastllmCublasHandle = getFastllmCublasHandle();
+    cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
+    cublasStatus_t status;
+
+    int len = n * m;
+    int threadPerBlock = std::min(256, len);
+
+    status = cublasGemmEx(fastllmCublasHandle,
+                          CUBLAS_OP_T, CUBLAS_OP_N,
+                          k, n, m,
+                          &h_alpha, (half *) weight.cudaData, AType,
+                          m, cudaInput, BType,
+                          m, &h_beta,
+                          cudaOutput, CType,
+                          k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+    if (status != CUBLAS_STATUS_SUCCESS) {
+        printf("Error: cublas error.\n");
+        throw ("cublas error");
+        exit(0);
+    }
+
+    FastllmCudaBiasKernel <<< n, 256 >>>(cudaOutput, (half *) weight.extraCudaHalfData[0], k);
+
+    FastllmCudaFinishInput(input, cudaInput);
+    FastllmCudaFinishOutput(output, cudaOutput);
+    return true;
+}
+
 void FastllmCudaSetDevice(int gpu_id) {
     cudaSetDevice(gpu_id);
 }
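FastllmCudaHalfMatMulFloat16 follows the existing FP32 path: on first use it caches a half-precision copy of the bias in weight.extraCudaHalfData, then calls cublasGemmEx with CUBLAS_OP_T on the weight, so the result is output[n x k] = input[n x m] * weight[k x m]^T + bias. A plain C++ sketch of that GEMM semantics (reference only):

    #include <vector>

    // output[n x k] = input[n x m] * weight[k x m]^T + bias[k]
    void LinearReference(const std::vector<float> &input, const std::vector<float> &weight,
                         const std::vector<float> &bias, std::vector<float> &output,
                         int n, int m, int k) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < k; j++) {
                float acc = bias.empty() ? 0.0f : bias[j];
                for (int t = 0; t < m; t++) {
                    acc += input[i * m + t] * weight[j * m + t];
                }
                output[i * k + j] = acc;
            }
        }
    }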
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index b1da3da8..af16fcae 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -326,6 +326,8 @@ namespace fastllm {
             this->unitSize = 4;
             this->unitSizeDiv = 1;
         }
+
+        this->expansionBytes = (this->expansionSize * this->unitSize - 1) / this->unitSizeDiv + 1;
     }
 
     void Data::Resize(const std::vector <int> &dims) {
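The new expansionBytes line rounds the allocation size up for sub-byte element types: bytes = ceil(expansionSize * unitSize / unitSizeDiv), written with integer arithmetic. For example, FLOAT16 (unitSize = 2, unitSizeDiv = 1) with expansionSize 4096 gives 8192 bytes, while a 4-bit type gives 2048 bytes for the same element count, assuming the usual convention of unitSize = 1 and unitSizeDiv = 2 for 4-bit data. A quick check of the formula:

    #include <cassert>

    // bytes = ceil(expansionSize * unitSize / unitSizeDiv), exactly as in the hunk above.
    int ExpansionBytes(int expansionSize, int unitSize, int unitSizeDiv) {
        return (expansionSize * unitSize - 1) / unitSizeDiv + 1;
    }

    int main() {
        assert(ExpansionBytes(4096, 2, 1) == 8192);   // FLOAT16
        assert(ExpansionBytes(4096, 1, 2) == 2048);   // 4-bit, assuming unitSize = 1, unitSizeDiv = 2
        assert(ExpansionBytes(5, 1, 2) == 3);         // odd element counts round up
        return 0;
    }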