Skip to content

Commit

Permalink
增加int4 GEMV算子的向量化访存机制,优化单条推理速度
Browse files Browse the repository at this point in the history
  • Loading branch information
cgli authored and TylunasLi committed Apr 19, 2024
1 parent f34c2b0 commit f56a048
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion src/devices/cuda/fastllm-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,47 @@ __global__ void FastllmGemvInt4NoZeroKernel2(float *A, uint8_t *B, float *C,
}
}

// Int4 "no-zero-point" GEMV kernel with vectorized memory access.
// Each block computes PART consecutive outputs:
//   C[p] = bias[p] + scales[p] * sum_{j<m} A[j] * (q[p][j] + mins[p] / scales[p])
// where q[p][j] is the j-th 4-bit quantized weight of row p of B
// (two weights per byte, high nibble first).
// Expected launch (see the host-side caller in this file): <<< k / PART, THREAD_PER_BLOCK >>>.
// Preconditions (not checked): m is a multiple of 4 so the float4 / uint16_t vector
// loads stay in bounds; A is assumed 16-byte aligned for FETCH_FLOAT4 — TODO confirm
// against the caller's allocator.
template <int THREAD_PER_BLOCK, int PART>
__global__ void FastllmGemvInt4NoZeroKernel1(float *A, uint8_t *B, float *C,
                                             float *bias, float *scales, float *mins,
                                             int m, int k) {
    __shared__ float sdata[THREAD_PER_BLOCK];  // per-thread partial dot products
    unsigned int tid = threadIdx.x;

    // 1. Compute: accumulate partial dot products for each of this block's PART rows.
    int st = blockIdx.x * PART;
    int end = st + PART;
    for (int p = st; p < end; p++) {
        sdata[tid] = 0;
        const uint8_t *baseB = B + p * m / 2;  // row p of B: m/2 packed bytes
        // Fold the per-row minimum into each quantized value: since
        // w = scales[p] * q + mins[p], summing A[j] * (q + mins[p] / scales[p])
        // lets us multiply by scales[p] just once, after the reduction.
        // __ldg routes the per-row constants through the read-only data cache.
        float minv = __ldg(mins + p) / __ldg(scales + p);
        // i indexes bytes of B; each step consumes 2 bytes (4 nibbles) of B via one
        // uint16_t load and the matching 4 floats of A via one float4 load.
        // i is always even, so baseB + i is 2-byte aligned and A + i*2 is 16-byte
        // aligned whenever A itself is.
        for (int i = tid * 2; i < m / 2; i += THREAD_PER_BLOCK * 2) {
            float4 aBuffer = FETCH_FLOAT4(A[i * 2]);
            uint16_t bBuffer = *reinterpret_cast<const uint16_t *>(baseB + i);
            // Little-endian: the low byte of bBuffer is baseB[i] (nibbles for
            // A[i*2], A[i*2+1]) and the high byte is baseB[i+1] (nibbles for
            // A[i*2+2], A[i*2+3]); within each byte the high nibble comes first.
            sdata[tid] += aBuffer.x * (minv + ((bBuffer >> 4) & 15)) + aBuffer.y * (minv + (bBuffer & 15));
            sdata[tid] += aBuffer.z * (minv + (bBuffer >> 12)) + aBuffer.w * (minv + ((bBuffer >> 8) & 15));
        }
        __syncthreads();  // all partials must be written before the reduction reads them

        // 2. Tree reduction in shared memory with Kahan-style compensation:
        // `diff` carries the rounding error of this thread's add at the previous
        // level and is subtracted from the next operand it folds in.
        float diff = 0.0f;
        for (unsigned int s = THREAD_PER_BLOCK/2; s > 0; s >>= 1) {
            if (tid < s) {
                float other = sdata[tid + s] - diff;
                float sumTmp = sdata[tid] + other;
                diff = (sumTmp - sdata[tid]) - other;  // recovered low-order error
                sdata[tid] = sumTmp;
            }
            __syncthreads();
        }
        //if (tid <= 32)
        //warpReduce(sdata, tid);
        if (tid == 0) {
            // Undo the scale factoring and add the bias for output row p.
            C[p] = sdata[0] * scales[p] + bias[p];
        }
        __syncthreads();  // sdata is reused by the next iteration of p
    }
}

template <int THREAD_PER_BLOCK>
__global__ void FastllmSplitBatchKernel(uint8_t *input, uint8_t **outputs, int outer, int channels, int inner) {
int bid = blockIdx.x / outer, oid = blockIdx.x % outer;
Expand Down Expand Up @@ -1842,7 +1883,7 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
#endif
} else {
for (int i = 0; i < n; i++) {
FastllmGemvInt4NoZeroKernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
FastllmGemvInt4NoZeroKernel1<256, 1> <<< k, 256 >>>(cudaInput + i * m,
(uint8_t *) weight.cudaData,
cudaOutput + i * k,
cudaBiasData,
Expand Down

0 comments on commit f56a048

Please sign in to comment.