From 6a85c423ac3be235dc839832cdb7434bed7508e2 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 4 Feb 2025 18:09:03 +0800 Subject: [PATCH] [feat] support ffpa-l1 registers double buffers (#70) * Update README.md * Update README.md * Update env.py * Update prefill.cuh * Update ffpa_attn_templates_L1.cuh * Update launch_templates.cuh * Update README.md --- README.md | 4 +- csrc/cuffpa/ffpa_attn_templates_L1.cuh | 228 +++++++++++++++++++------ csrc/cuffpa/launch_templates.cuh | 17 +- env.py | 12 ++ include/cuffpa/prefill.cuh | 9 + 5 files changed, 214 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index d4fb9af..fd1f44e 100644 --- a/README.md +++ b/README.md @@ -68,9 +68,9 @@ By leveraging this approach, we can achieve better performance for large headdim |📚Feature |📚Feature |📚Feature |📚Feature| |:---:|:---:|:---:|:---:| -|✔️Tensor Cores|✔️Loop over N/D |✔️Tile Block(Br, Bc) |✔️**MMA(m16n8k16)**| +|✔️Tensor Cores |✔️**MMA(m16n8k16)** |✔️Tile Block(Br, Bc) |✔️Tile MMA/Warp | |✔️**Split Q**(FA-2)|✔️Pack LDST(128 bits)|✔️SMEM **Swizzle/Pad** |✔️Copy Async | -|✔️Tile MMA/Warp |✔️QKV Multi-Stages(1~4) |✔️Collective Store(**Shfl**)|✔️**Prefetch QKV** g2s | +|✔️**Reg Double Buffers** |✔️QKV **Multi-Stages(1~4)** |✔️Collective Store(**Shfl**)|✔️**Prefetch QKV** g2s | |✔️**QKV Fine-grained Tiling**|✔️**Shared QKV** SMEM|✔️Mixed MMA Acc|✔️**Persist Q** s2r/g2s| - 📚 case: FFPA `L1` kernel template signature: [ffpa_attn_templates_L1.cuh](csrc/cuffpa/ffpa_attn_templates_L1.cuh) diff --git a/csrc/cuffpa/ffpa_attn_templates_L1.cuh b/csrc/cuffpa/ffpa_attn_templates_L1.cuh index d47f326..4971f86 100644 --- a/csrc/cuffpa/ffpa_attn_templates_L1.cuh +++ b/csrc/cuffpa/ffpa_attn_templates_L1.cuh @@ -25,6 +25,7 @@ template< const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V. const int kPersistQs2r, // Persist load Q s2r for headdim < 320, more registers, but still keep O(1) SRAM. const int kPersistQg2s, // Persist load Q g2s for headdim < 320, more SRAM, but still keep register usage. + const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping. const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4) const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4) const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding @@ -48,7 +49,7 @@ ffpa_mma_stages_split_q_L1_large_d_template( kMmaTileSeqLenP, kMmaTileHeadDimV, kWarpTileSeqLenQ, kWarpTileSeqLenK, kWarpTileSeqLenP, kWarpTileHeadDimV, kMmaAccFloat32QK, kMmaAccFloat32PV, kOStorageAccFloat32, kPrefetchQK, kPrefetchPV, kShareSmemQKV, kPersistQs2r, - kPersistQg2s, kStageQK, kStagePV, kPadQ, kPadK, kPadV + kPersistQg2s, kRegPipeKV, kStageQK, kStagePV, kPadQ, kPadK, kPadV >(); constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; @@ -119,14 +120,20 @@ ffpa_mma_stages_split_q_L1_large_d_template( // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- // e.g, 64, !kPersistQs2r -> [1][4] 4 regs, kPersistQs2r -> [1][4*4] 16 regs. uint32_t R_Q[kWarpTileSeqLenQ][(kPersistQs2r) ? (kHeadDim / kMmaAtomK) : 1][4]; - uint32_t R_K[kWarpTileSeqLenK][2]; // [8][2] - uint32_t R_V[2]; // [2], S=Q@K, only use 2 32bits registers. + // R_K [8][2] w/o registers ping pong buffers, [2][2] w/ registers ping pong buffers. + uint32_t R_K[(kRegPipeKV) ? 
2: kWarpTileSeqLenK][2]; // [8][2] or [2][2] + // R_V [2][2] w registers ping pong buffers, [1][2] w/o registers ping pong buffers. + uint32_t R_V[(kRegPipeKV) ? 2: 1][2]; // [1][2], S=Q@K, only use 2 32bits registers. // e.g [1][8][2], MMA Acc fp16; [1][8][4], MMA Acc fp32; uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][(kMmaAccFloat32QK) ? 4 : 2]; uint32_t R_O[(kMmaAccFloat32PV) ? 4 : 2]; // registers for O=PV[Br,d]=P@V, [4 or 2] uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][(kOStorageAccFloat32) ? 4 : 2]; utils::fill_3D_regs(R_D, 0); + + // Additional load/store controllers for kRegPipeKV + uint32_t reg_st_idx = 0; + uint32_t reg_ld_idx = 1; // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. // : for K^T[d,seqlen] with K^T_tile[d,Bc] @@ -247,7 +254,7 @@ ffpa_mma_stages_split_q_L1_large_d_template( __syncthreads(); } - // QK s2r + // Q s2r static_assert(kWarpTileSeqLenQ == 1); { if constexpr (kPersistQs2r) { @@ -273,15 +280,23 @@ ffpa_mma_stages_split_q_L1_large_d_template( } } - #pragma unroll - for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // K s2r + reg_st_idx = 0; + reg_ld_idx = 1; + if constexpr (!kRegPipeKV) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + prefill::sync_fetch_qkv_frags_s2r< + 0, 2, K_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadK>( + smem_K_base_ptr, &R_K[j][0], warp_KV, j, 0, smem_sel + ); + } + } else { + // kRegPipeKV is enabled, load first K tile frags from kWarpTileSeqLenK. prefill::sync_fetch_qkv_frags_s2r< 0, 2, K_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadK>( - smem_K_base_ptr, &R_K[j][0], warp_KV, j, 0, smem_sel + smem_K_base_ptr, &R_K[reg_st_idx][0], warp_KV, 0, 0, smem_sel ); - } - if constexpr (kStageQK < 2) { - __syncthreads(); } // kShareSmemQKV: Prefetch V g2s before last Q@K^T iteration. @@ -308,29 +323,47 @@ ffpa_mma_stages_split_q_L1_large_d_template( const int q_offset = (kPersistQs2r) ? (tile_K_d) : 0; // (tile_K_d) #pragma unroll for (int j = 0; j < kWarpTileSeqLenK; ++j) { + reg_st_idx ^= 1; // 0->1 + reg_ld_idx ^= 1; // 1->0 + if constexpr (kRegPipeKV) { + // load next (j+1) K tile frags + if ((j + 1) < kWarpTileSeqLenK) { + prefill::sync_fetch_qkv_frags_s2r< + 0, 2, K_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadK>( + smem_K_base_ptr, &R_K[reg_st_idx][0], warp_KV, (j + 1), + 0, smem_sel + ); + } + } + const int k_offset = (kRegPipeKV) ? reg_ld_idx : j; if constexpr (kMmaAccFloat32QK) { mma::m16n8k16_f16f16f32( - &R_S[0][j][0], &R_S[0][j][1], &R_S[0][j][2], &R_S[0][j][3], - &R_Q[0][q_offset][0], &R_Q[0][q_offset][1], - &R_Q[0][q_offset][2], &R_Q[0][q_offset][3], - &R_K[j][0], &R_K[j][1] + &R_S[0][j][0], &R_S[0][j][1], &R_S[0][j][2], &R_S[0][j][3], + &R_Q[0][q_offset][0], &R_Q[0][q_offset][1], + &R_Q[0][q_offset][2], &R_Q[0][q_offset][3], + &R_K[k_offset][0], &R_K[k_offset][1] ); } else { mma::m16n8k16_f16f16f16( - &R_S[0][j][0], &R_S[0][j][1], - &R_Q[0][q_offset][0], &R_Q[0][q_offset][1], - &R_Q[0][q_offset][2], &R_Q[0][q_offset][3], - &R_K[j][0], &R_K[j][1] + &R_S[0][j][0], &R_S[0][j][1], + &R_Q[0][q_offset][0], &R_Q[0][q_offset][1], + &R_Q[0][q_offset][2], &R_Q[0][q_offset][3], + &R_K[k_offset][0], &R_K[k_offset][1] ); } } } - + if constexpr (kStageQK > 1) { if (tile_K_d < (kHeadDim / kMmaAtomK - 1)) { cp_async::wait_group<(kStageQK - 2)>(); __syncthreads(); } + } + if constexpr (kStageQK < 2) { + // must wait all MMAs ready before next iteration + // if kStageQK == 1 to avoid K smem overwrite. 
+ __syncthreads(); } } // end loop over d, S=Q@K^T __syncthreads(); @@ -423,15 +456,22 @@ ffpa_mma_stages_split_q_L1_large_d_template( } } - utils::fill_1D_regs(R_O, 0); - #pragma unroll - for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { + // reinit controllers + reg_st_idx = 0; + reg_ld_idx = 1; + // kRegPipeKV V s2r + if constexpr(kRegPipeKV) { + // load first tile_V_Bc V tile frags from (Bc / kMmaAtomK). prefill::sync_fetch_qkv_frags_s2r< 1, 2, V_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadV>( - smem_V_base_ptr, &R_V[0], warp_KV, (j % 2), tile_V_Bc, + smem_V_base_ptr, &R_V[reg_st_idx][0], warp_KV, (j % 2), 0, smem_sel_v ); + } + utils::fill_1D_regs(R_O, 0); + #pragma unroll + for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { // kShareSmemQKV: Prefetch next QK g2s before last P@V iteration. if constexpr ((kShareSmemQKV) && kPrefetchQK && kStageQK > 1) { if (j == (kWarpTileHeadDimV - 1) && tile_V_Bc == (Bc / kMmaAtomK - 1) @@ -460,23 +500,44 @@ ffpa_mma_stages_split_q_L1_large_d_template( } } // end if kPrefetchQKV && kStageQK > 1 + // V s2r + reg_st_idx ^= 1; // 0->1 + reg_ld_idx ^= 1; // 1->0 + if constexpr(!kRegPipeKV) { + prefill::sync_fetch_qkv_frags_s2r< + 1, 2, V_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadV>( + smem_V_base_ptr, &R_V[0][0], warp_KV, (j % 2), tile_V_Bc, + smem_sel_v + ); + } else { + // load next (tile_V_Bc + 1) V tile frags + if ((tile_V_Bc + 1) < (Bc / kMmaAtomK)) { + prefill::sync_fetch_qkv_frags_s2r< + 1, 2, V_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadV>( + smem_V_base_ptr, &R_V[reg_st_idx][0], warp_KV, (j % 2), + (tile_V_Bc + 1), smem_sel_v + ); + } + } + // Compute P[Br,Bc]@V[Bc,d] = O[Br,d] const int p_offset = tile_V_Bc * 2; // MMA(Warp) selected, 0, 2, 4, 6 + const int v_offset = (kRegPipeKV) ? reg_ld_idx : 0; if constexpr (kMmaAccFloat32PV) { // MMA accumulate with F32 dtype for high precision. mma::m16n8k16_f16f16f32( &R_O[0], &R_O[1], &R_O[2], &R_O[3], &R_S[0][p_offset][0], &R_S[0][p_offset][1], &R_S[0][p_offset + 1][0], &R_S[0][p_offset + 1][1], - &R_V[0], &R_V[1] + &R_V[v_offset][0], &R_V[v_offset][1] ); } else { // MMA accumulate with F16 dtype for high throughput. mma::m16n8k16_f16f16f16( &R_O[0], &R_O[1], - &R_S[0][p_offset][0], &R_S[0][p_offset][1], - &R_S[0][p_offset + 1][0], &R_S[0][p_offset + 1][1], - &R_V[0], &R_V[1] + &R_S[0][p_offset][0], &R_S[0][p_offset][1], + &R_S[0][p_offset + 1][0], &R_S[0][p_offset + 1][1], + &R_V[v_offset][0], &R_V[v_offset][1] ); } } // end for V Bc. @@ -553,6 +614,7 @@ template< const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V. const int kPersistQs2r, // Persist load Q s2r for headdim <= 128, more registers. const int kPersistVs2r, // Persist load V s2r for headdim <= 128, more registers. + const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping. 
const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4) const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4) const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding @@ -581,7 +643,7 @@ ffpa_mma_stages_split_q_L1_small_d_template( kMmaTileSeqLenP, kMmaTileHeadDimV, kWarpTileSeqLenQ, kWarpTileSeqLenK, kWarpTileSeqLenP, kWarpTileHeadDimV, kMmaAccFloat32QK, kMmaAccFloat32PV, kOStorageAccFloat32, kPrefetchQK, kPrefetchPV, kShareSmemQKV, kPersistQs2r, - kPersistVs2r, kStageQK, kStagePV, kPadQ, kPadK, kPadV + kPersistVs2r, kRegPipeKV, kStageQK, kStagePV, kPadQ, kPadK, kPadV >(); constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; @@ -650,8 +712,10 @@ ffpa_mma_stages_split_q_L1_small_d_template( // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- // e.g, 64, !kPersistQs2r -> [1][4] 4 regs, kPersistQs2r -> [1][4][4] 16 regs. uint32_t R_Q[kWarpTileSeqLenQ][(kPersistQs2r) ? (kHeadDim / kMmaAtomK) : 1][4]; - uint32_t R_K[kWarpTileSeqLenK][2]; // [8][2] - uint32_t R_V[kWarpTileSeqLenP][(kPersistVs2r) ? (Bc / kMmaAtomK): 1][2]; // [1][4][2], e.g Bc=64, S=Q@K + // R_K [8][2] w/o registers ping pong buffers, [2][2] w/ registers ping pong buffers. + uint32_t R_K[(kRegPipeKV) ? 2: kWarpTileSeqLenK][2]; // [8][2] or [2][2] + // R_V [2][2] w registers ping pong buffers, [1][2] w/o registers ping pong buffers. + uint32_t R_V[(kPersistVs2r) ? (Bc / kMmaAtomK): ((kRegPipeKV) ? 2: 1)][2]; // [4][2], e.g Bc=64, S=Q@K // e.g [1][8][2], MMA Acc fp16; [1][8][4], MMA Acc fp32; O=PV[Br,d]=P@V, [4 or 2] uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][(kMmaAccFloat32QK) ? 4 : 2]; uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][(kMmaAccFloat32PV) ? 4 : 2]; @@ -659,6 +723,10 @@ ffpa_mma_stages_split_q_L1_small_d_template( utils::fill_3D_regs(R_D, 0); + // Additional load/store controllers for kRegPipeKV + uint32_t reg_st_idx = 0; + uint32_t reg_ld_idx = 1; + if constexpr (kPersistQs2r) { cp_async::wait_group<0>(); __syncthreads(); @@ -728,7 +796,7 @@ ffpa_mma_stages_split_q_L1_small_d_template( (kMmaAccFloat32QK) ? 4 : 2>(R_S, 0); #pragma unroll for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { - // QK s2r + // Q s2r static_assert(kWarpTileSeqLenQ == 1); { if constexpr (!kPersistQs2r) { @@ -739,33 +807,57 @@ ffpa_mma_stages_split_q_L1_small_d_template( } } - #pragma unroll - for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // K s2r + reg_st_idx = 0; + reg_ld_idx = 1; + if constexpr (!kRegPipeKV) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + prefill::sync_fetch_qkv_frags_s2r< + 0, 2, K_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadK>( + smem_K_base_ptr, &R_K[j][0], warp_KV, j, 0, tile_K_d + ); + } + } else { + // kRegPipeKV is enabled, load first K tile frags from kWarpTileSeqLenK. prefill::sync_fetch_qkv_frags_s2r< 0, 2, K_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadK>( - smem_K_base_ptr, &R_K[j][0], warp_KV, j, 0, tile_K_d + smem_K_base_ptr, &R_K[reg_st_idx][0], warp_KV, 0, 0, tile_K_d ); - } - + } + // Q@K^T MMA compute static_assert(kWarpTileSeqLenQ == 1); { // kWarpTileSeqLenQ = 1 const int q_offset = (kPersistQs2r) ? 
(tile_K_d) : 0; // (tile_K_d) #pragma unroll for (int j = 0; j < kWarpTileSeqLenK; ++j) { + reg_st_idx ^= 1; // 0->1 + reg_ld_idx ^= 1; // 1->0 + if constexpr (kRegPipeKV) { + // load next (j+1) K tile frags + if ((j + 1) < kWarpTileSeqLenK) { + prefill::sync_fetch_qkv_frags_s2r< + 0, 2, K_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadK>( + smem_K_base_ptr, &R_K[reg_st_idx][0], warp_KV, (j + 1), + 0, tile_K_d + ); + } + } + const int k_offset = (kRegPipeKV) ? reg_ld_idx : j; if constexpr (kMmaAccFloat32QK) { mma::m16n8k16_f16f16f32( - &R_S[0][j][0], &R_S[0][j][1], &R_S[0][j][2], &R_S[0][j][3], - &R_Q[0][q_offset][0], &R_Q[0][q_offset][1], - &R_Q[0][q_offset][2], &R_Q[0][q_offset][3], - &R_K[j][0], &R_K[j][1] + &R_S[0][j][0], &R_S[0][j][1], &R_S[0][j][2], &R_S[0][j][3], + &R_Q[0][q_offset][0], &R_Q[0][q_offset][1], + &R_Q[0][q_offset][2], &R_Q[0][q_offset][3], + &R_K[k_offset][0], &R_K[k_offset][1] ); } else { mma::m16n8k16_f16f16f16( - &R_S[0][j][0], &R_S[0][j][1], - &R_Q[0][q_offset][0], &R_Q[0][q_offset][1], - &R_Q[0][q_offset][2], &R_Q[0][q_offset][3], - &R_K[j][0], &R_K[j][1] + &R_S[0][j][0], &R_S[0][j][1], + &R_Q[0][q_offset][0], &R_Q[0][q_offset][1], + &R_Q[0][q_offset][2], &R_Q[0][q_offset][3], + &R_K[k_offset][0], &R_K[k_offset][1] ); } } @@ -844,30 +936,60 @@ ffpa_mma_stages_split_q_L1_small_d_template( for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { prefill::sync_fetch_qkv_frags_s2r< 1, 2, V_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadV>( - smem_V_base_ptr, &R_V[0][tile_V_Bc][0], warp_KV, (j % 2), + smem_V_base_ptr, &R_V[tile_V_Bc][0], warp_KV, (j % 2), tile_V_Bc, tile_V_d ); } } + // kRegPipeKV and kPersistVs2r can not both enabled. + static_assert((kRegPipeKV & kPersistVs2r) == 0); + // reinit controllers + reg_st_idx = 0; + reg_ld_idx = 1; + // kRegPipeKV V s2r + if constexpr(kRegPipeKV) { + // load first tile_V_Bc V tile frags from (Bc / kMmaAtomK). + prefill::sync_fetch_qkv_frags_s2r< + 1, 2, V_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadV>( + smem_V_base_ptr, &R_V[reg_st_idx][0], warp_KV, (j % 2), 0, + tile_V_d + ); + } + #pragma unroll for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { + // V s2r + reg_st_idx ^= 1; // 0->1 + reg_ld_idx ^= 1; // 1->0 if constexpr (!kPersistVs2r) { - prefill::sync_fetch_qkv_frags_s2r< - 1, 2, V_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadV>( - smem_V_base_ptr, &R_V[0][0][0], warp_KV, (j % 2), tile_V_Bc, - tile_V_d - ); + if constexpr(!kRegPipeKV) { + prefill::sync_fetch_qkv_frags_s2r< + 1, 2, V_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadV>( + smem_V_base_ptr, &R_V[0][0], warp_KV, (j % 2), tile_V_Bc, + tile_V_d + ); + } else { + // load next (tile_V_Bc + 1) V tile frags + if ((tile_V_Bc + 1) < (Bc / kMmaAtomK)) { + prefill::sync_fetch_qkv_frags_s2r< + 1, 2, V_tile_size, kMmaAtomM, kMmaAtomN, kMmaAtomK, kPadV>( + smem_V_base_ptr, &R_V[reg_st_idx][0], warp_KV, (j % 2), + (tile_V_Bc + 1), tile_V_d + ); + } + } } // Compute P[Br,Bc]@V[Bc,d] = O[Br,d] const int p_offset = tile_V_Bc * 2; // MMA(Warp) selected, 0, 2, 4, 6 - const int v_offset = (kPersistVs2r) ? tile_V_Bc : 0; + const int v_offset = ((kPersistVs2r) ? tile_V_Bc : + ((kRegPipeKV) ? reg_ld_idx : 0)); if constexpr (kMmaAccFloat32PV) { // MMA accumulate with F32 dtype for high precision. 
mma::m16n8k16_f16f16f32( &R_O[0][j][0], &R_O[0][j][1], &R_O[0][j][2], &R_O[0][j][3], &R_S[0][p_offset][0], &R_S[0][p_offset][1], &R_S[0][p_offset + 1][0], &R_S[0][p_offset + 1][1], - &R_V[0][v_offset][0], &R_V[0][v_offset][1] + &R_V[v_offset][0], &R_V[v_offset][1] ); } else { // MMA accumulate with F16 dtype for high throughput. @@ -875,7 +997,7 @@ ffpa_mma_stages_split_q_L1_small_d_template( &R_O[0][j][0], &R_O[0][j][1], &R_S[0][p_offset][0], &R_S[0][p_offset][1], &R_S[0][p_offset + 1][0], &R_S[0][p_offset + 1][1], - &R_V[0][v_offset][0], &R_V[0][v_offset][1] + &R_V[v_offset][0], &R_V[v_offset][1] ); } } // end for V Bc. diff --git a/csrc/cuffpa/launch_templates.cuh b/csrc/cuffpa/launch_templates.cuh index 39165e2..2c348c9 100644 --- a/csrc/cuffpa/launch_templates.cuh +++ b/csrc/cuffpa/launch_templates.cuh @@ -128,6 +128,15 @@ static constexpr int getConfigPersistVs2r() { return kPersistVs2r; } +static constexpr int getConfigRegistersPipeKV() { +#ifdef ENABLE_FFPA_REGISTERS_PIPE_KV + constexpr int kRegPipeKV = 1; +#else + constexpr int kRegPipeKV = 0; +#endif + return kRegPipeKV; +} + static constexpr int getConfigPadQ() { #ifdef ENABLE_FFPA_SMEM_SWIZZLE_Q constexpr int kPadQ = 0; @@ -283,11 +292,12 @@ void launch_ffpa_mma_L1_template(torch::Tensor Q, constexpr int kStageQK = kStage; // <= 4 constexpr int kStagePV = kStage; // <= 4 // Prefetch QKV, Persist Q g2s/s2r, Shared QKV smem. + constexpr int kShareSmemQKV = getConfigShareSmemQKV(); constexpr int kPrefetchQK = getConfigPrefetchQKV(); constexpr int kPrefetchPV = getConfigPrefetchQKV(); constexpr int kPersistQs2r = getConfigPersistQs2r(); constexpr int kPersistQg2s = getConfigPersistQg2s(); - constexpr int kShareSmemQKV = getConfigShareSmemQKV(); + constexpr int kRegPipeKV = getConfigRegistersPipeKV(); // QKV smem swizzle, 0 for smem swizzle, !0 for smem padding. constexpr int kPadQ = getConfigPadQ(); constexpr int kPadK = getConfigPadK(); @@ -353,6 +363,9 @@ TEMPLATE_FUNC<<>>( \ kShareSmemQKV, kPersistQs2r, kPersistVs2r, + // Force disable KV registers ping pong buffers + // while V s2r is enabled. + (kPersistVs2r) ? 0 : kRegPipeKV, 1, /*kStageQK unused*/ 1, /*kStagePV unused*/ kPadQ, @@ -386,6 +399,7 @@ TEMPLATE_FUNC<<>>( \ // need too many register, thus, introduce performance drops. (kPersistQg2s || kHeadDim > 256) ? 0 : kPersistQs2r, kPersistQg2s, + kRegPipeKV, kStageQK, kStagePV, kPadQ, @@ -420,6 +434,7 @@ TEMPLATE_FUNC<<>>( \ // need too many register, thus, introduce performance drops. (kPersistQg2s || kHeadDim > 256) ? 0 : kPersistQs2r, kPersistQg2s, + kRegPipeKV, kStageQK, kStagePV, kPadQ, diff --git a/env.py b/env.py index bb24a95..ca543f9 100644 --- a/env.py +++ b/env.py @@ -95,6 +95,11 @@ class ENV(object): int(os.environ.get("ENABLE_FFPA_PERSIST_V_S2R", ENABLE_FFPA_PERSIST_KV_G2S)) ) + # Registers Ping pong double buffers for ldmatrix & mma computation overlapping. 
+ ENABLE_FFPA_REGISTERS_PIPE_KV = bool( + int(os.environ.get("ENABLE_FFPA_REGISTERS_PIPE_KV", 0)) + ) + # if True: grid(N/Br, H, B) else: grid(N/Br, B * H) ENBALE_FFPA_LAUNCH_GRID_DNHB = bool( int(os.environ.get("ENBALE_FFPA_LAUNCH_GRID_DNHB", 0)) @@ -173,6 +178,10 @@ def enable_persist_v_s2r(cls): if cls.enable_persist_kv_g2s(): return cls.ENABLE_FFPA_PERSIST_V_S2R return False + + @classmethod + def enable_registers_pipe_kv(cls): + return cls.ENABLE_FFPA_REGISTERS_PIPE_KV @classmethod def enable_launch_grid_dnhb(cls): @@ -209,6 +218,8 @@ def env_cuda_cflags(cls): extra_env_cflags.append("-DENABLE_FFPA_PERSIST_Q_S2R") if cls.enable_persist_v_s2r(): extra_env_cflags.append("-DENABLE_FFPA_PERSIST_V_S2R") + if cls.enable_registers_pipe_kv(): + extra_env_cflags.append("-DENABLE_FFPA_REGISTERS_PIPE_KV") if cls.enable_launch_grid_dnhb(): extra_env_cflags.append("-DENBALE_FFPA_LAUNCH_GRID_DNHB") @@ -263,6 +274,7 @@ def formatenv(name, value): formatenv("ENABLE_FFPA_SMEM_SWIZZLE_Q", cls.enable_smem_swizzle_q()) formatenv("ENABLE_FFPA_SMEM_SWIZZLE_K", cls.enable_smem_swizzle_k()) formatenv("ENABLE_FFPA_SMEM_SWIZZLE_V", cls.enable_smem_swizzle_v()) + formatenv("ENABLE_FFPA_REGISTERS_PIPE_KV", cls.enable_registers_pipe_kv()) formatenv("ENBALE_FFPA_LAUNCH_GRID_DNHB", cls.enable_launch_grid_dnhb()) pretty_print_line() diff --git a/include/cuffpa/prefill.cuh b/include/cuffpa/prefill.cuh index 21ed073..55c18b8 100755 --- a/include/cuffpa/prefill.cuh +++ b/include/cuffpa/prefill.cuh @@ -30,6 +30,7 @@ template< const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V. const int kPersistQs2r, // Persist load Q s2r for headdim < 320, more registers, but still keep O(1) SRAM. const int kPersistQg2s, // Persist load Q g2s for headdim < 320, more SRAM, but still keep register usage. + const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping. const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4) const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4) const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding @@ -63,6 +64,8 @@ __device__ __forceinline__ void check_large_d_compiling_states() { static_assert((kPersistQg2s & kPersistQs2r) == 0); // kPersistQg2s and kShareSmemQKV can not both enabled for large d kernel.. static_assert((kPersistQg2s & kShareSmemQKV) == 0); + // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping. + static_assert(kRegPipeKV == 0 || kRegPipeKV == 1); // May apply different multi stages policy for QK and V. static_assert(kStageQK < 5 && kStageQK > 0); // QK (<=4) static_assert(kStagePV < 5 && kStagePV > 0); // V (<=4) @@ -93,6 +96,7 @@ template< const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V. const int kPersistQs2r, // Persist load Q s2r for headdim <= 128, more registers. const int kPersistVs2r, // Persist load V s2r for headdim <= 128, more registers. + const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping. 
   const int kStageQK,  // <= 4, may apply different multi stages policy for QK and V (<=4)
   const int kStagePV,  // <= 4, may apply different multi stages policy for QK and V (<=4)
   const int kPadQ,     // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
@@ -126,6 +130,11 @@ __device__ __forceinline__ void check_small_d_compiling_states() {
     // kPersistQs2r must be enabled is set kShareSmemQKV as 1
     static_assert(kPersistQs2r == 1);
   }
+  // Registers Ping pong double buffers for ldmatrix s2r & mma
+  // computation overlapping.
+  static_assert(kRegPipeKV == 0 || kRegPipeKV == 1);
+  // kRegPipeKV and kPersistVs2r can not both be enabled.
+  static_assert((kRegPipeKV & kPersistVs2r) == 0);
   // May apply different multi stages policy for QK and V.
   static_assert(kStageQK < 5 && kStageQK > 0); // QK (<=4)
   static_assert(kStagePV < 5 && kStagePV > 0); // V (<=4)
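
---

For readers unfamiliar with the register ping-pong scheme this patch enables via `kRegPipeKV`, the standalone sketch below mimics only the scheduling pattern: the same `reg_st_idx`/`reg_ld_idx ^= 1` bookkeeping, preloading the first tile, prefetching tile `j+1` into the "store" register buffer while the compute step consumes tile `j` from the "load" buffer. It is a toy illustration and not FFPA code: the kernel name `reduce_tiles`, the `s2r` lambda, the FMA loop standing in for `mma::m16n8k16_*`, and the build flags (`-std=c++17 -arch=sm_80`) are all invented for this example; the real kernels use `prefill::sync_fetch_qkv_frags_s2r` ldmatrix fragments and the MMA helpers shown in the diff above.

```cuda
// toy_reg_pingpong.cu -- illustrative sketch only, NOT part of this patch.
// Build (assumed toolchain/arch): nvcc -std=c++17 -arch=sm_80 toy_reg_pingpong.cu -o toy_reg_pingpong
#include <cstdio>
#include <cuda_runtime.h>

// Each thread reduces NTILES tiles of TILE floats staged in shared memory.
// With PINGPONG=1 the next tile is fetched s2r into the "store" register
// buffer while the compute stand-in consumes the "load" buffer, using the
// same reg_st_idx/reg_ld_idx ^= 1 bookkeeping that kRegPipeKV adds above.
template <int TILE, int NTILES, int PINGPONG>
__global__ void reduce_tiles(const float* __restrict__ in, float* __restrict__ out) {
  extern __shared__ float smem[];  // NTILES * blockDim.x * TILE floats
  const int tid = threadIdx.x;
  const int per_block = NTILES * blockDim.x * TILE;

  // Stage all tiles g2s once (stand-in for the cp.async multi-stage pipeline).
  for (int i = tid; i < per_block; i += blockDim.x)
    smem[i] = in[blockIdx.x * per_block + i];
  __syncthreads();

  float R[(PINGPONG) ? 2 : 1][TILE];  // register double buffer, as in R_K/R_V
  int reg_st_idx = 0;
  int reg_ld_idx = 1;
  float acc = 0.0f;

  // "ldmatrix s2r" stand-in: copy one tile from smem into a register buffer.
  auto s2r = [&](int buf, int j) {
    #pragma unroll
    for (int t = 0; t < TILE; ++t)
      R[buf][t] = smem[(j * blockDim.x + tid) * TILE + t];
  };

  if constexpr (PINGPONG) s2r(reg_st_idx, 0);  // preload first tile
  #pragma unroll
  for (int j = 0; j < NTILES; ++j) {
    reg_st_idx ^= 1;  // 0->1
    reg_ld_idx ^= 1;  // 1->0
    if constexpr (PINGPONG) {
      if (j + 1 < NTILES) s2r(reg_st_idx, j + 1);  // prefetch next tile s2r
    } else {
      s2r(0, j);  // load-then-compute, no overlap
    }
    const int buf = (PINGPONG) ? reg_ld_idx : 0;
    #pragma unroll
    for (int t = 0; t < TILE; ++t)  // MMA stand-in: consume the loaded tile
      acc = fmaf(R[buf][t], 0.5f, acc);
  }
  out[blockIdx.x * blockDim.x + tid] = acc;
}

int main() {
  constexpr int TILE = 8, NTILES = 4, THREADS = 128, BLOCKS = 4;
  const int n_in = BLOCKS * NTILES * THREADS * TILE;
  const int n_out = BLOCKS * THREADS;
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, n_in * sizeof(float));
  cudaMallocManaged(&out, n_out * sizeof(float));
  for (int i = 0; i < n_in; ++i) in[i] = 1.0f;
  const size_t smem_bytes = NTILES * THREADS * TILE * sizeof(float);
  reduce_tiles<TILE, NTILES, 1><<<BLOCKS, THREADS, smem_bytes>>>(in, out);
  cudaDeviceSynchronize();
  printf("out[0] = %.1f (expected %.1f)\n", out[0], 0.5f * NTILES * TILE);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

Per the env.py and launch_templates.cuh changes in this patch, the real code path is compiled in by exporting ENABLE_FFPA_REGISTERS_PIPE_KV=1 before building (which adds -DENABLE_FFPA_REGISTERS_PIPE_KV to the CUDA cflags), and the small-d launch template force-disables kRegPipeKV whenever kPersistVs2r is enabled, matching the static_assert added in check_small_d_compiling_states.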