diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index f89152d7..ed020e97 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -21,6 +21,39 @@ void showError(cudaError_t result, char const* const message, const char* const
     }
 }
 
+
+#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])
+
+typedef union __align__(16) {
+    uint2 in;
+    uint8_t out[8];
+} union_char8;
+
+typedef union __align__(16) {
+    uint32_t in;
+    uint8_t out[4];
+} union_char4;
+
+typedef union __align__(16) _union_half_4 {
+    uint2 in;
+    half out[4];
+    half2 out2[2];
+    __device__ _union_half_4() {
+        // Do nothing
+    }
+} union_half4;
+
+typedef union __align__(16) _union_half_8 {
+    uint4 in;
+    half out[8];
+    half2 out2[4];
+    __device__ _union_half_8() {
+        // Do nothing
+    }
+} union_half8;
+
+const size_t ST128_FP16_COUNT = 8;
+
 static std::map<int, cublasHandle_t> s_fastllmCublasHandleMap;
 cublasHandle_t getFastllmCublasHandle() {
     int id = -1;
@@ -59,10 +92,40 @@ __global__ void FastllmCudaFloat2HalfKernel(float* a, half *b, int len) {
 }
 
 __global__ void FastllmCudaInt82HalfKernel(uint8_t* a, float *scales, uint8_t *zeros, half *b, int len, int per) {
+#ifdef CUDA_NO_TENSOR_CORE
+    float scalesBuffer[2];
+    uint8_t zerosBuffer[2];
+    int threshold = ST128_FP16_COUNT;
+    int index = (threadIdx.x + blockIdx.x * blockDim.x) * ST128_FP16_COUNT;
+    for (int idx = index; idx < len; idx += (gridDim.x * blockDim.x) * ST128_FP16_COUNT) {
+        int startIdx = idx / per;
+        int endIdx = (idx + ST128_FP16_COUNT - 1) / per;
+        scalesBuffer[1] = scalesBuffer[0] = scales[startIdx];
+        zerosBuffer[1] = zerosBuffer[0] = zeros[startIdx];
+        if (endIdx > startIdx) {
+            threshold = (idx + ST128_FP16_COUNT - 1) % per;
+            scalesBuffer[1] = scales[endIdx];
+            zerosBuffer[1] = zeros[endIdx];
+        }
+        // read
+        union_char8 aBuffer[2];
+        half bBuffer[ST128_FP16_COUNT];
+        aBuffer[0].in = *reinterpret_cast<uint2 *>(a + idx);
+        // process
+        for (int i=0; i < ST128_FP16_COUNT; i++) {
+            float scale = scalesBuffer[i < threshold ? 0 : 1];
+            uint8_t zero = zerosBuffer[i < threshold ? 0 : 1];
+            if (idx + i < len)
+                bBuffer[i] = __float2half(scale * ((float)aBuffer[0].out[i] - zero));
+        }
+        reinterpret_cast<uint4 *>(b)[idx / ST128_FP16_COUNT] = *reinterpret_cast<uint4 *>(bBuffer);
+    }
+#else
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
         b[idx] = __float2half(scales[idx / per] * ((float)a[idx] - zeros[idx / per]));
    }
+#endif
 }
 
 __global__ void FastllmCudaInt4Group2HalfKernel(uint8_t* a, float *scales, float *mins, half *b, int len, int per,
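The hunk above dequantizes int8 weights in 8-element tiles: one 64-bit uint2 load
pulls in 8 quantized bytes, and one 128-bit uint4 store writes back 8 halves
through the aligned unions defined earlier. A minimal standalone sketch of the
same union trick; demo_half8 and DemoByteToHalf8 are illustrative names, not
part of the patch:

#include <cuda_fp16.h>
#include <cstdint>

typedef union __align__(16) {
    uint4 in;     // one 128-bit vector register
    half out[8];  // the same bits viewed as 8 fp16 lanes
} demo_half8;

// Each thread converts 8 bytes and writes them back with a single 128-bit
// store; dst is assumed 16-byte aligned and len a multiple of 8.
__global__ void DemoByteToHalf8(const uint8_t *src, half *dst, int len) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
    if (idx + 7 < len) {
        demo_half8 buf;
        for (int i = 0; i < 8; i++) {
            buf.out[i] = __float2half((float)src[idx + i]);
        }
        // One ST.128 instead of eight scalar 16-bit stores.
        reinterpret_cast<uint4 *>(dst)[idx / 8] = buf.in;
    }
}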
@@ -80,6 +143,38 @@ __global__ void FastllmCudaInt4Group2HalfKernel(uint8_t* a, float *scales, float
 
 __global__ void FastllmCudaInt42HalfKernel(uint8_t* a, float *scales, float *mins, half *b, int len, int per) {
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
+#ifdef CUDA_NO_TENSOR_CORE
+    float2 scalesBuffer;
+    float2 minBuffer;
+    int threshold = ST128_FP16_COUNT;
+    for (int index = idx * ST128_FP16_COUNT; index < len; index += (gridDim.x * blockDim.x) * ST128_FP16_COUNT) {
+        int startIdx = index / per;
+        int endIdx = (index + ST128_FP16_COUNT - 1) / per;
+        scalesBuffer.x = scalesBuffer.y = __ldg(scales + startIdx);
+        minBuffer.x = minBuffer.y = __ldg(mins + startIdx);
+        if (endIdx > startIdx) {
+            threshold = (idx + ST128_FP16_COUNT - 1) % per;
+            scalesBuffer.y = __ldg(scales + endIdx);
+            minBuffer.y = __ldg(mins + endIdx);
+        }
+        // read
+        union_char4 aBuffer;
+        union_half8 bBuffer;
+        aBuffer.in = *reinterpret_cast<uint32_t *>(a + index / 2);
+        // process
+        for (int i = 0; i < ST128_FP16_COUNT / 2; i++) {
+            if (index + i * 2 + 1 < len) {
+                float scale = i * 2 < threshold ? scalesBuffer.x : scalesBuffer.y;
+                float min = i * 2 < threshold ? minBuffer.x : minBuffer.y;
+                bBuffer.out[i * 2] = __float2half(scale * (aBuffer.out[i] >> 4) + min);
+                bBuffer.out[i * 2 + 1] = __float2half(scale * (aBuffer.out[i] & 0xF) + min);
+            }
+            // if (a[index + i] != aBuffer.out[i] && index < 100)
+            //     printf("%d - %d : %d\n", index + i, a[index + i], aBuffer.out[i]);
+        }
+        reinterpret_cast<uint4 *>(b)[idx] = bBuffer.in;
+    }
+#else
     if (idx < len) {
         if (idx % 2 == 1) {
             b[idx] = __float2half(scales[idx / per] * (a[idx / 2] & 0xF) + mins[idx / per]);
@@ -87,6 +182,7 @@ __global__ void FastllmCudaInt42HalfKernel(uint8_t* a, float *scales, float *min
             b[idx] = __float2half(scales[idx / per] * (a[idx / 2] >> 4) + mins[idx / per]);
         }
     }
+#endif
 }
 
 __global__ void FastllmCudaHalf2FlotaKernel(half* a, float *b, int len) {
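In the vectorized path each packed byte expands to two fp16 outputs: the high
nibble lands on the even index, the low nibble on the odd index, matching the
scalar branch below it. A hypothetical host-side helper (DequantInt4Pair is not
in the patch) spelling out that order:

#include <cstdint>

inline void DequantInt4Pair(uint8_t packed, float scale, float minv, float out[2]) {
    out[0] = scale * (float)(packed >> 4) + minv;   // even element: high nibble
    out[1] = scale * (float)(packed & 0xF) + minv;  // odd element: low nibble
}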
@@ -806,25 +902,51 @@ template <int THREAD_PER_BLOCK, int PART>
 __global__ void FastllmGemvFp32Fp16Kernel2(float *A, half *B, float *C, float *bias, int m, int k) {
     __shared__ float sdata[THREAD_PER_BLOCK];
     unsigned int tid = threadIdx.x;
+    const half zero = __float2half_rn(0.0);
+    float4 regA;
+    union_half4 regB;
 
     // 1. compute
     int st = blockIdx.x * PART;
     int end = st + PART;
     for (int p = st; p < end; p++) {
         sdata[tid] = 0;
+        const half *baseB = B + p * m;
+#ifdef CUDA_NO_TENSOR_CORE
+#pragma unroll
+        for (int i = tid*4; i < m; i += THREAD_PER_BLOCK*4) {
+            regA = FETCH_FLOAT4(A[i]);
+            regB.in = *reinterpret_cast<const uint2 *>(baseB + i);
+            float sum = 0.0f;
+            if (i < m)
+                sum += regA.x * __low2float(regB.out2[0]);
+            if (i + 1 < m)
+                sum += regA.y * __high2float(regB.out2[0]);
+            if (i + 2 < m)
+                sum += regA.z * __low2float(regB.out2[1]);
+            if (i + 3 < m)
+                sum += regA.w * __high2float(regB.out2[1]);
+            sdata[tid] += sum;
+        }
+#else
         for (int i = tid; i < m; i += THREAD_PER_BLOCK) {
             sdata[tid] += A[i] * (float)B[p * m + i];
         }
+#endif
         __syncthreads();
-        for (unsigned int s = 1; s < THREAD_PER_BLOCK; s *= 2) {
-            if ((tid & (2 * s - 1)) == 0) {
-                sdata[tid] += sdata[tid + s];
+        float diff = 0.0f;
+        for (unsigned int s = THREAD_PER_BLOCK/2; s > 0; s >>= 1) {
+            if (tid < s) {
+                float other = sdata[tid + s] - diff;
+                float sumTmp = sdata[tid] + other;
+                diff = (sumTmp - sdata[tid]) - other;
+                sdata[tid] = sumTmp;
             }
             __syncthreads();
         }
 
         if (tid == 0) {
-            C[p] = sdata[0] + bias[p];
+            C[p] = sdata[0] + __ldg(bias + p);
         }
         __syncthreads();
     }
@@ -843,25 +965,51 @@ __global__ void FastllmGemvInt8Kernel2(float *A, uint8_t *B, float *C,
     }
     __syncthreads();*/
 
+    float4 regA;
+    union_char4 regB;
+
     // 2. compute
     int st = blockIdx.x * PART;
     int end = st + PART;
     for (int p = st; p < end; p++) {
         sdata[tid] = 0;
         uint8_t zero = zeros[p];
+        const uint8_t *baseB = B + p * m;
+#ifdef CUDA_NO_TENSOR_CORE
+#pragma unroll
+        for (int i = tid*4; i < m; i += THREAD_PER_BLOCK*4) {
+            regA = FETCH_FLOAT4(A[i]);
+            regB.in = *reinterpret_cast<const uint32_t *>(baseB + i);
+            float sum = 0.0f;
+            if (i < m)
+                sum += regA.x * (float)(regB.out[0] - zero);
+            if (i + 1 < m)
+                sum += regA.y * (float)(regB.out[1] - zero);
+            if (i + 2 < m)
+                sum += regA.z * (float)(regB.out[2] - zero);
+            if (i + 3 < m)
+                sum += regA.w * (float)(regB.out[3] - zero);
+            sdata[tid] += sum;
+        }
+#else
         for (int i = tid; i < m; i += THREAD_PER_BLOCK) {
             sdata[tid] += A[i] * (B[p * m + i] - zero);
         }
+#endif
         __syncthreads();
-        for (unsigned int s = 1; s < THREAD_PER_BLOCK; s *= 2) {
-            if ((tid & (2 * s - 1)) == 0) {
-                sdata[tid] += sdata[tid + s];
+        float diff = 0.0f;
+        for (unsigned int s = THREAD_PER_BLOCK/2; s > 0; s >>= 1) {
+            if (tid < s) {
+                float other = sdata[tid + s] - diff;
+                float sumTmp = sdata[tid] + other;
+                diff = (sumTmp - sdata[tid]) - other;
+                sdata[tid] = sumTmp;
             }
             __syncthreads();
         }
 
         if (tid == 0) {
-            C[p] = sdata[0] * scales[p] + bias[p];
+            C[p] = sdata[0] * __ldg(scales + p) + __ldg(bias + p);
         }
         __syncthreads();
     }
@@ -1020,6 +1168,47 @@ __global__ void FastllmGemvInt4NoZeroKernel2(float *A, uint8_t *B, float *C,
     }
 }
 
+template <int THREAD_PER_BLOCK, int PART>
+__global__ void FastllmGemvInt4NoZeroKernel1(float *A, uint8_t *B, float *C,
+                                             float *bias, float *scales, float *mins,
+                                             int m, int k) {
+    __shared__ float sdata[THREAD_PER_BLOCK];
+    unsigned int tid = threadIdx.x;
+
+    // 1. compute
+    int st = blockIdx.x * PART;
+    int end = st + PART;
+    for (int p = st; p < end; p++) {
+        sdata[tid] = 0;
+        const uint8_t *baseB = B + p * m / 2;
+        float minv = __ldg(mins + p) / __ldg(scales + p);
+        for (int i = tid * 2; i < m / 2; i += THREAD_PER_BLOCK * 2) {
+            float4 aBuffer = FETCH_FLOAT4(A[i * 2]);
+            uint16_t bBuffer = *reinterpret_cast<const uint16_t *>(baseB + i);
+            sdata[tid] += aBuffer.x * (minv + ((bBuffer >> 4) & 15)) + aBuffer.y * (minv + (bBuffer & 15));
+            sdata[tid] += aBuffer.z * (minv + (bBuffer >> 12)) + aBuffer.w * (minv + ((bBuffer >> 8) & 15));
+        }
+        __syncthreads();
+
+        float diff = 0.0f;
+        for (unsigned int s = THREAD_PER_BLOCK/2; s > 0; s >>= 1) {
+            if (tid < s) {
+                float other = sdata[tid + s] - diff;
+                float sumTmp = sdata[tid] + other;
+                diff = (sumTmp - sdata[tid]) - other;
+                sdata[tid] = sumTmp;
+            }
+            __syncthreads();
+        }
+        //if (tid <= 32)
+        //    warpReduce(sdata, tid);
+        if (tid == 0) {
+            C[p] = sdata[0] * scales[p] + bias[p];
+        }
+        __syncthreads();
+    }
+}
+
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmSplitBatchKernel(uint8_t *input, uint8_t **outputs, int outer, int channels, int inner) {
     int bid = blockIdx.x / outer, oid = blockIdx.x % outer;
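The reduction rewrite in these gemv kernels changes two things at once:
sequential addressing (tid < s) keeps the active threads contiguous instead of
the strided (tid & (2 * s - 1)) pattern, and a Kahan-style compensation term
recovers the low-order bits that plain fp32 accumulation drops. A standalone
sketch of the same scheme; CompensatedBlockSum is illustrative, not from the
patch, and THREAD_PER_BLOCK must be a power of two, as the halving loop assumes:

#include <cuda_runtime.h>

template <int THREAD_PER_BLOCK>
__global__ void CompensatedBlockSum(const float *in, float *out) {
    __shared__ float sdata[THREAD_PER_BLOCK];
    unsigned int tid = threadIdx.x;
    sdata[tid] = in[blockIdx.x * THREAD_PER_BLOCK + tid];
    __syncthreads();

    float diff = 0.0f;  // running compensation: error left over from prior adds
    for (unsigned int s = THREAD_PER_BLOCK / 2; s > 0; s >>= 1) {
        if (tid < s) {
            float other = sdata[tid + s] - diff;   // fold the lost bits back in
            float sumTmp = sdata[tid] + other;     // the naive partial sum
            diff = (sumTmp - sdata[tid]) - other;  // bits this add just lost
            sdata[tid] = sumTmp;
        }
        __syncthreads();
    }
    if (tid == 0) {
        out[blockIdx.x] = sdata[0];
    }
}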
@@ -1416,12 +1605,13 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
         FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaFp16Input, len);
 
         len = k * m;
-        FastllmCudaInt82HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight.cudaData,
-                                                                                         cudaScales,
-                                                                                         cudaZeropoints,
-                                                                                         cudaFp16Weight, len, m);
-
 #ifdef CUDA_NO_TENSOR_CORE
+        int gridSize = (len - 1) / (threadPerBlock * ST128_FP16_COUNT) + 1;
+        FastllmCudaInt82HalfKernel <<< gridSize, threadPerBlock>>>((uint8_t*)weight.cudaData,
+                                                                   cudaScales,
+                                                                   cudaZeropoints,
+                                                                   cudaFp16Weight, len, m);
+
         status = cublasGemmEx(fastllmCublasHandle,
                               CUBLAS_OP_T, CUBLAS_OP_N,
                               k, n, m,
@@ -1431,6 +1621,11 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
                               cudaOutput, CType, k,
                               ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
 #else
+        FastllmCudaInt82HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight.cudaData,
+                                                                                         cudaScales,
+                                                                                         cudaZeropoints,
+                                                                                         cudaFp16Weight, len, m);
+
         status = cublasGemmEx(fastllmCublasHandle,
                               CUBLAS_OP_T, CUBLAS_OP_N,
                               k, n, m,
@@ -1684,12 +1879,12 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
                                                                                        len);
 
         len = k * m;
-        FastllmCudaInt42HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t *) weight.cudaData,
-                                                                                         cudaScales,
-                                                                                         cudaMins,
-                                                                                         cudaFp16Weight, len, m);
-
 #ifdef CUDA_NO_TENSOR_CORE
+        int gridSize = (len - 1) / (threadPerBlock * 4) + 1;
+        FastllmCudaInt42HalfKernel <<< gridSize, threadPerBlock>>>((uint8_t *) weight.cudaData,
+                                                                   cudaScales, cudaMins,
+                                                                   cudaFp16Weight, len, m);
+
         status = cublasGemmEx(fastllmCublasHandle,
                               CUBLAS_OP_T, CUBLAS_OP_N,
                               k, n, m,
@@ -1699,6 +1894,11 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
                               cudaOutput, CType, k,
                               ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
 #else
+        FastllmCudaInt42HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t *) weight.cudaData,
+                                                                                         cudaScales,
+                                                                                         cudaMins,
+                                                                                         cudaFp16Weight, len, m);
+
         status = cublasGemmEx(fastllmCublasHandle,
                               CUBLAS_OP_T, CUBLAS_OP_N,
                               k, n, m,
@@ -1730,7 +1930,7 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
 #endif
     } else {
         for (int i = 0; i < n; i++) {
-            FastllmGemvInt4NoZeroKernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
+            FastllmGemvInt4NoZeroKernel1<256, 1> <<< k, 256 >>>(cudaInput + i * m,
                                                                 (uint8_t *) weight.cudaData,
                                                                 cudaOutput + i * k,
                                                                 cudaBiasData,
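Both launchers now size the dequantization grid by the per-thread tile rather
than by element count. A hypothetical helper (DequantGridSize is not in the
patch) spelling out the arithmetic, with threadPerBlock assumed to be the 256
used elsewhere in this file:

const int kTileElems = 8;  // mirrors ST128_FP16_COUNT from the patch

inline int DequantGridSize(int len, int threadPerBlock) {
    return (len - 1) / (threadPerBlock * kTileElems) + 1;  // ceil(len / tile)
}

Note that the Int4NoZero launcher divides by 4 rather than by ST128_FP16_COUNT;
because FastllmCudaInt42HalfKernel bounds-checks and uses a grid-stride loop,
the roughly 2x larger grid stays correct and only leaves the surplus threads
idle.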