From ca880a0a1d2742e65d67b686f0b4955a6900b9c3 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Mon, 24 Feb 2025 20:19:09 -0500
Subject: [PATCH] Automatic computation of mla metadata

---
 candle-flash-mla/hkernel/flash_api.cu |   6 +-
 candle-flash-mla/src/ffi.rs           |   2 -
 candle-flash-mla/src/lib.rs           | 107 +++++++++++++++-----------
 3 files changed, 67 insertions(+), 48 deletions(-)

diff --git a/candle-flash-mla/hkernel/flash_api.cu b/candle-flash-mla/hkernel/flash_api.cu
index 34061da90..2287b2698 100644
--- a/candle-flash-mla/hkernel/flash_api.cu
+++ b/candle-flash-mla/hkernel/flash_api.cu
@@ -8,18 +8,18 @@
 #include
 #include
 
+#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))
+
 void get_mla_metadata(
     int32_t* seqlens_k_ptr,
     int32_t* tile_scheduler_metadata_ptr, // [num_sm_parts, TileSchedulerMetaDataSize]
     int32_t* num_splits_ptr,              // [batch_size + 1]
-    const int num_heads_per_head_k,
-    const int num_heads_k,
     const int batch_size,
     const int num_sm_parts,
     const cudaStream_t stream
 ) {
     // This should match the logic in the MLA kernel.
-    static constexpr int block_size_m = 64;
+    // static constexpr int block_size_m = 64; MOVED TO lib.rs
     static constexpr int block_size_n = 64;
     static constexpr int fixed_overhead_num_blocks = 5;
 
diff --git a/candle-flash-mla/src/ffi.rs b/candle-flash-mla/src/ffi.rs
index f1a7350e8..f18e1ec81 100644
--- a/candle-flash-mla/src/ffi.rs
+++ b/candle-flash-mla/src/ffi.rs
@@ -54,8 +54,6 @@ extern "C" {
         seqlens_k_ptr: *mut c_int,
         tile_scheduler_metadata_ptr: *mut c_int,
         num_splits_ptr: *mut c_int,
-        num_heads_per_head_k: c_int,
-        num_heads_k: c_int,
         batch_size: c_int,
         num_sm_parts: c_int,
         stream: CUstream,
diff --git a/candle-flash-mla/src/lib.rs b/candle-flash-mla/src/lib.rs
index 30350644d..1079ade83 100644
--- a/candle-flash-mla/src/lib.rs
+++ b/candle-flash-mla/src/lib.rs
@@ -3,6 +3,7 @@ mod ffi;
 use std::f32;
 
 use candle::backend::BackendStorage;
+use candle::cuda::cudarc;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
@@ -11,9 +12,7 @@ use half::bf16;
 pub struct FlashAttn {
     pub softmax_scale: f32,
     pub block_table: Tensor,
-    pub tile_scheduler_metadata: Tensor,
-    pub num_splits: Tensor,
-    pub seqlens_k: Tensor,
+    pub cache_seqlens: Tensor,
 }
 
 impl FlashAttn {
@@ -89,29 +88,6 @@ impl FlashAttn {
            candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }
 
-        if self.tile_scheduler_metadata.dim(1)? != ffi::TILE_SCHEDULER_METADATA_SIZE {
-            candle::bail!(
-                "Expected tile scheduler metadata to match {}: {}",
-                ffi::TILE_SCHEDULER_METADATA_SIZE,
-                self.tile_scheduler_metadata.dim(1)?
-            )
-        }
-        let candle::Storage::Cuda(tile_scheduler_metadata) =
-            &*self.tile_scheduler_metadata.storage_and_layout().0
-        else {
-            candle::bail!("tile_scheduler_metadata must be CUDA")
-        };
-        let tile_scheduler_metadata = tile_scheduler_metadata
-            .as_cuda_slice::<i32>()?
-            .slice(self.tile_scheduler_metadata.layout().start_offset()..);
-
-        let candle::Storage::Cuda(num_splits) = &*self.num_splits.storage_and_layout().0 else {
-            candle::bail!("num_splits must be CUDA")
-        };
-        let num_splits = num_splits
-            .as_cuda_slice::<i32>()?
-            .slice(self.num_splits.layout().start_offset()..);
-
         let candle::Storage::Cuda(block_table) = &*self.block_table.storage_and_layout().0 else {
             candle::bail!("block_table must be CUDA")
         };
         let block_table = block_table
             .as_cuda_slice::<i32>()?
             .slice(self.block_table.layout().start_offset()..);
 
-        let candle::Storage::Cuda(seqlens_k) = &*self.seqlens_k.storage_and_layout().0 else {
-            candle::bail!("seqlens_k must be CUDA")
+        let candle::Storage::Cuda(cache_seqlens) = &*self.cache_seqlens.storage_and_layout().0
+        else {
+            candle::bail!("cache_seqlens must be CUDA")
         };
-        let seqlens_k = seqlens_k
+        let cache_seqlens = cache_seqlens
             .as_cuda_slice::<i32>()?
-            .slice(self.seqlens_k.layout().start_offset()..);
+            .slice(self.cache_seqlens.layout().start_offset()..);
 
         let is_causal = if seqlen_q == 1 { false } else { true };
 
@@ -133,17 +110,48 @@ impl FlashAttn {
         let num_heads = num_heads_k;
         let head_size_k = head_size_q;
 
-        let dst = unsafe { dev.alloc::<bf16>(b_sz * seqlen_q * num_heads * head_size_v) }.w()?;
-        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+        let num_heads_per_head_k = num_heads / num_heads_k;
+
+        // This should match the logic in the MLA kernel.
+        let block_size_m = 64usize;
+        let sm_count = dev
+            .attribute(
+                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
+            )
+            .w()? as usize;
+        let num_sm_parts = sm_count / num_heads_k / num_heads_per_head_k.div_ceil(block_size_m);
+
+        let tile_scheduler_metadata =
+            unsafe { dev.alloc::<i32>(num_sm_parts * ffi::TILE_SCHEDULER_METADATA_SIZE) }.w()?;
+        let num_splits = dev.alloc_zeros::<i32>(b_sz + 1).w()?;
+
+        unsafe {
+            ffi::get_mla_metadata(
+                (*cache_seqlens.device_ptr()) as *mut core::ffi::c_int,
+                (*tile_scheduler_metadata.device_ptr()) as *mut core::ffi::c_int,
+                (*num_splits.device_ptr()) as *mut core::ffi::c_int,
+                b_sz as i32,
+                num_sm_parts as i32,
+                *dev.cu_stream(),
+            );
+        }
+
+        let dst = unsafe {
+            dev.alloc::<bf16>((b_sz + num_sm_parts) * seqlen_q * num_heads * head_size_v)
+        }
+        .w()?;
+        let softmax_lse = dev
+            .alloc_zeros::<f32>((b_sz + num_sm_parts) * num_heads * seqlen_q)
+            .w()?;
 
         assert_eq!(head_size_q, 576);
 
         let params = ffi::FlashFwdMlaParams {
             b: b_sz as i32,
             seqlen_q: seqlen_q as i32,
-            cu_seqlens_k: (*seqlens_k.device_ptr()) as *mut core::ffi::c_int,
+            cu_seqlens_k: (*cache_seqlens.device_ptr()) as *mut core::ffi::c_int,
             h: num_heads as i32,
-            h_h_k_ratio: num_heads as i32 / num_heads_k as i32,
+            h_h_k_ratio: num_heads_per_head_k as i32,
             ngroups: ngroups as i32,
             is_causal,
             d: head_size_q as i32,
@@ -172,7 +180,7 @@ impl FlashAttn {
             page_block_size: page_block_size as i32,
             tile_scheduler_metadata_ptr: (*tile_scheduler_metadata.device_ptr())
                 as *mut core::ffi::c_int,
-            num_sm_parts: self.tile_scheduler_metadata.dim(0)? as i32,
+            num_sm_parts: num_sm_parts as i32,
             num_splits_ptr: (*num_splits.device_ptr()) as *mut core::ffi::c_int,
             oaccum_ptr: (*dst.device_ptr()) as *mut core::ffi::c_void,
             softmax_lseaccum_ptr: (*softmax_lse.device_ptr()) as *mut core::ffi::c_void,
@@ -228,10 +236,9 @@ impl candle::CustomOp3 for FlashAttn {
 ///
 /// * `q: (batch_size, seq_len_q, num_heads_q, head_dim).
 /// * `k_cache`: (num_blocks, page_block_size, num_heads_k, head_dim).
+/// * `v_cache`: (num_blocks, page_block_size, num_heads_k, head_dim_v).
 /// * `block_table`: (batch_size, max_num_blocks_per_seq), i32.
 /// * `cache_seqlens`: (batch_size), i32
-/// * `tile_scheduler_metadata`: (num_sm_parts, TileSchedulerMetaDataSize), i32, returned by get_mla_metadata.
-/// * `num_splits`: (batch_size + 1), i32, returned by get_mla_metadata.
 /// * `softmax_scale: The scale of QK^T before applying softmax.
 ///
 /// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size_v)`.
@@ -241,16 +248,30 @@ pub fn flash_attn_mla(
     v_cache: &Tensor,
     block_table: Tensor,
     cache_seqlens: Tensor,
-    tile_scheduler_metadata: Tensor,
-    num_splits: Tensor,
     softmax_scale: f32,
 ) -> Result<Tensor> {
     let op = FlashAttn {
         softmax_scale,
         block_table,
-        tile_scheduler_metadata,
-        num_splits,
-        seqlens_k: cache_seqlens,
+        cache_seqlens,
     };
-    q.apply_op3(k_cache, v_cache, op)
+
+    let (b_sz, seqlen_q_ori, num_heads, head_size) = q.shape().dims4()?;
+    let (_, _, _, head_size_v) = v_cache.shape().dims4()?;
+
+    let num_heads_k = k_cache.dim(2)?;
+    let ngroups = num_heads / num_heads_k;
+
+    let seqlen_q = seqlen_q_ori * ngroups;
+
+    let q = q
+        .reshape((b_sz, seqlen_q_ori, num_heads_k, ngroups, head_size))?
+        .transpose(2, 3)?
+        .reshape((b_sz, seqlen_q, num_heads, head_size))?;
+
+    let out = q.apply_op3(k_cache, v_cache, op)?;
+
+    out.reshape((b_sz, seqlen_q_ori, ngroups, num_heads_k, head_size_v))?
+        .transpose(2, 3)?
+        .reshape((b_sz, seqlen_q_ori, num_heads, head_size_v))
 }
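
Usage sketch (illustrative, not part of the patch): with this change the caller no longer allocates or passes `tile_scheduler_metadata` and `num_splits`; they are derived inside `flash_attn_mla` from `cache_seqlens` and the device's multiprocessor count. The Rust sketch below shows the new call shape under stated assumptions: a CUDA device, a candle build whose tensors support i32 (the documented dtype of `block_table` and `cache_seqlens`), and DeepSeek-style MLA sizes. All concrete dimensions and the function name `decode_step_example` are hypothetical; the wrapper only asserts `head_dim == 576`.

    use candle::{DType, Device, Result, Tensor};

    fn decode_step_example() -> Result<Tensor> {
        let device = Device::new_cuda(0)?;

        // Illustrative sizes; only head_dim == 576 is checked by the wrapper.
        let (batch, seqlen_q, num_heads_q, num_heads_k) = (2, 1, 16, 1);
        let (num_blocks, page_block_size, head_dim, head_dim_v) = (2, 64, 576, 512);

        let q = Tensor::zeros((batch, seqlen_q, num_heads_q, head_dim), DType::BF16, &device)?;
        let k_cache =
            Tensor::zeros((num_blocks, page_block_size, num_heads_k, head_dim), DType::BF16, &device)?;
        let v_cache =
            Tensor::zeros((num_blocks, page_block_size, num_heads_k, head_dim_v), DType::BF16, &device)?;

        // One cache page per sequence in this toy setup; a real KV-cache manager fills these.
        let block_table = Tensor::from_vec(vec![0i32, 1], (batch, 1), &device)?;
        let cache_seqlens = Tensor::from_vec(vec![17i32, 42], batch, &device)?;

        let softmax_scale = 1.0 / (head_dim as f32).sqrt();

        // tile_scheduler_metadata / num_splits are now computed internally per call.
        candle_flash_mla::flash_attn_mla(&q, &k_cache, &v_cache, block_table, cache_seqlens, softmax_scale)
    }

Dropping the two metadata tensors from the public signature means callers no longer have to re-run the metadata computation by hand whenever `cache_seqlens` or the target GPU changes; the SM count is queried from the device on each call.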