From bbc0c14d49f0a9693bdfae6f1c1489251601d458 Mon Sep 17 00:00:00 2001 From: huangyuyang <410644548@qq.com> Date: Fri, 24 May 2024 01:59:25 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=B8=80=E4=BA=9B=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/devices/cuda/cudadevicebatch.cpp | 2 +- src/devices/cuda/fastllm-cuda.cu | 89 ++++++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 6 deletions(-) diff --git a/src/devices/cuda/cudadevicebatch.cpp b/src/devices/cuda/cudadevicebatch.cpp index 79317a13..b31af83c 100644 --- a/src/devices/cuda/cudadevicebatch.cpp +++ b/src/devices/cuda/cudadevicebatch.cpp @@ -304,7 +304,7 @@ namespace fastllm { dpitchs.push_back(input0Stride * unitSize); srcs.push_back(input1.cudaData); spitchs.push_back(input1Stride * unitSize); - widths.push_back(input1.dims[axis] * inner * unitSize); + widths.push_back(inner * unitSize); heights.push_back(outer); } diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu index ed020e97..30f2aefc 100644 --- a/src/devices/cuda/fastllm-cuda.cu +++ b/src/devices/cuda/fastllm-cuda.cu @@ -952,6 +952,61 @@ __global__ void FastllmGemvFp32Fp16Kernel2(float *A, half *B, float *C, float *b } } +template +__global__ void FastllmGemvFp32Fp16Kernel2MultiRow(float *A, half *B, float *C, float *bias, int m, int k) { + __shared__ float sdata[PART][THREAD_PER_BLOCK]; + unsigned int tid = threadIdx.x; + const half zero = __float2half_rn(0.0); + float4 regA; + union_half4 regB; + + // 1. 计算 + int st = blockIdx.x; + int p = st; +#pragma unroll + for (int x = 0; x < PART; x++) sdata[x][tid] = 0; + + const half *baseB = B + p * m; +#pragma unroll + for (int i = tid * 4; i < m; i += THREAD_PER_BLOCK * 4) { +#pragma unroll + for (int x = 0; x < PART; x++) { + regA = FETCH_FLOAT4(A[i + x * m]); + regB.in = *reinterpret_cast(baseB + i); + float sum = 0.0f; + if (i < m) + sum += regA.x * __low2float(regB.out2[0]); + if (i + 1 < m) + sum += regA.y * __high2float(regB.out2[0]); + if (i + 2 < m) + sum += regA.z * __low2float(regB.out2[1]); + if (i + 3 < m) + sum += regA.w * __high2float(regB.out2[1]); + sdata[x][tid] += sum; + } + } + __syncthreads(); + float diff = 0.0f; + for (unsigned int s = THREAD_PER_BLOCK/2; s > 0; s >>= 1) { + if (tid < s) { + #pragma unroll + for (int x = 0; x < PART; x++) { + float other = sdata[x][tid + s] - diff; + float sumTmp = sdata[x][tid] + other; + diff = (sumTmp - sdata[x][tid]) - other; + sdata[x][tid] = sumTmp; + } + } + __syncthreads(); + } + + if (tid == 0) { +#pragma unroll + for (int x = 0; x < PART; x++) C[p + k * x] = sdata[x][0] + __ldg(bias + p); + } + __syncthreads(); +} + template __global__ void FastllmGemvInt8Kernel2(float *A, uint8_t *B, float *C, float *bias, float *scales, uint8_t *zeros, @@ -2012,7 +2067,21 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, float *cudaInput = (float*)FastllmCudaPrepareInput(input); float *cudaOutput = (float*)FastllmCudaPrepareOutput(output); - if (n > 1) { + if (n == 1) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); + } else if (n == 2) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); + } else if (n == 3) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); + } else if (n == 4) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); + } else if (n == 5) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); + } else if (n == 6) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); + } else if (n == 7) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); + } else { auto fastllmCublasHandle = getFastllmCublasHandle(); //cudaDeviceSynchronize(); half *cudaFp16Input, *cudaFp16Output; @@ -2073,8 +2142,6 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, FastllmCudaFree(cudaFp16Input); FastllmCudaFree(cudaFp16Output); #endif - } else { - FastllmGemvFp32Fp16Kernel2<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); } FastllmCudaFinishInput(input, cudaInput); @@ -2292,6 +2359,17 @@ void FastllmCudaMemcpy2DDeviceToDevice(void * dst, size_t dpitch, const void * //cudaDeviceSynchronize(); } +template +__global__ void FastllmMemcpy2DKernel (uint8_t * dst, size_t dpitch, uint8_t * src, + size_t spitch, size_t width, size_t height) { + int id = blockIdx.x; + dst += id * dpitch; + src += id * spitch; + for (int i = threadIdx.x; i < width; i += THREAD_PER_BLOCK) { + dst[i] = src[i]; + } +} + template __global__ void FastllmMemcpyBatchKernel (uint8_t** pointer) { int id = blockIdx.x; @@ -2323,7 +2401,7 @@ void FastllmCudaMemcpy2DDeviceToDeviceBatch(void ** dsts, size_t * dpitchs, voi } } cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * total * 3, cudaMemcpyHostToDevice); - FastllmMemcpyBatchKernel <128> <<>> (pointers); + FastllmMemcpyBatchKernel <256> <<>> (pointers); FastllmCudaFree(pointers); delete[] cpuPointers; @@ -3100,7 +3178,7 @@ bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Da qk[b] = mem + memSum; memSum += s; } - + if (true) { uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8); uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8]; @@ -3121,6 +3199,7 @@ bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Da FastllmCudaFree(pointers); delete[] cpuPointers; } + if (true) { int total = 0; for (int b = 0; b < batch; b++) {