Commit 782f066
Only k_c_k_pe cache, no k/v cache
EricLBuehler committed Feb 26, 2025
1 parent 117286a
Showing 1 changed file with 25 additions and 29 deletions: candle-flash-mla/src/lib.rs

@@ -13,6 +13,7 @@ pub struct FlashAttn {
     pub softmax_scale: f32,
     pub block_table: Tensor,
     pub cache_seqlens: Tensor,
+    pub head_size_v: usize,
 }
 
 impl FlashAttn {
@@ -22,24 +23,23 @@ impl FlashAttn {
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
-        k: &candle::CudaStorage,
-        k_l: &Layout,
-        v: &candle::CudaStorage,
-        v_l: &Layout,
+        k_c_k_pe_cache: &candle::CudaStorage,
+        k_c_k_pe_cache_l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         let dev = q.device();
         let out_shape = q_l.shape().clone();
         let out_l = Layout::contiguous(&out_shape);
 
         let q = q.as_cuda_slice::<bf16>()?;
-        let k = k.as_cuda_slice::<bf16>()?;
-        let v = v.as_cuda_slice::<bf16>()?;
+        let k_c_k_pe_cache = k_c_k_pe_cache.as_cuda_slice::<bf16>()?;
         let q = q.slice(q_l.start_offset()..);
-        let k = k.slice(k_l.start_offset()..);
-        let v = v.slice(v_l.start_offset()..);
+        let k_c_k_pe_cache = k_c_k_pe_cache.slice(k_c_k_pe_cache_l.start_offset()..);
 
+        let v_l = k_c_k_pe_cache_l;
+        let v = &k_c_k_pe_cache;
+
         let q_stride = q_l.stride();
-        let k_stride = k_l.stride();
+        let k_stride = k_c_k_pe_cache_l.stride();
         let v_stride = v_l.stride();
         let o_stride = out_l.stride();
 
@@ -74,9 +74,9 @@ impl FlashAttn {
         let (_, _, _, head_size_v) = v_l.shape().dims4()?;
 
         let max_num_blocks_per_seq = self.block_table.dim(1)?;
-        let num_blocks = k_l.dim(0)?;
-        let page_block_size = k_l.dim(1)?;
-        let num_heads_k = k_l.dim(2)?;
+        let num_blocks = k_c_k_pe_cache_l.dim(0)?;
+        let page_block_size = k_c_k_pe_cache_l.dim(1)?;
+        let num_heads_k = k_c_k_pe_cache_l.dim(2)?;
 
         if head_size_q % 8 != 0 {
             candle::bail!("only supports q/k head sizes that are a multiple of 8")
@@ -117,11 +117,11 @@ impl FlashAttn {
                 q_l.dims()
             );
         }
-        if k_l.dims() != &[num_blocks, page_block_size, num_heads_k, head_size_k] {
+        if k_c_k_pe_cache_l.dims() != &[num_blocks, page_block_size, num_heads_k, head_size_k] {
             candle::bail!(
                 "Expected k shape {:?}, got {:?} instead.",
                 [num_blocks, page_block_size, num_heads_k, head_size_k],
-                k_l.dims()
+                k_c_k_pe_cache_l.dims()
             );
         }
         if v_l.dims() != &[num_blocks, page_block_size, num_heads_k, head_size_v] {
@@ -198,7 +198,7 @@ impl FlashAttn {
             scale_softmax: self.softmax_scale,
             scale_softmax_log2: self.softmax_scale * f32::consts::LOG2_E,
             q_ptr: (*q.device_ptr()) as *mut core::ffi::c_void,
-            k_ptr: (*k.device_ptr()) as *mut core::ffi::c_void,
+            k_ptr: (*k_c_k_pe_cache.device_ptr()) as *mut core::ffi::c_void,
             v_ptr: (*v.device_ptr()) as *mut core::ffi::c_void,
             o_ptr: (*dst.device_ptr()) as *mut core::ffi::c_void,
             softmax_lse_ptr: (*softmax_lse.device_ptr()) as *mut core::ffi::c_void,
@@ -232,7 +232,7 @@ impl FlashAttn {
     }
 }
 
-impl candle::CustomOp3 for FlashAttn {
+impl candle::CustomOp2 for FlashAttn {
     fn name(&self) -> &'static str {
         "flash-attn"
     }
@@ -243,8 +243,6 @@ impl candle::CustomOp3 for FlashAttn {
         _: &Layout,
         _: &CpuStorage,
         _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
     ) -> Result<(CpuStorage, Shape)> {
         candle::bail!("no cpu support for flash-attn")
     }
@@ -253,13 +251,11 @@ impl candle::CustomOp3 for FlashAttn {
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
-        k: &candle::CudaStorage,
-        k_l: &Layout,
-        v: &candle::CudaStorage,
-        v_l: &Layout,
+        k_c_k_pe_cache: &candle::CudaStorage,
+        k_c_k_pe_cache_l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         match q.dtype() {
-            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k_c_k_pe_cache, k_c_k_pe_cache_l,),
             dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
         }
     }
@@ -275,30 +271,30 @@ impl candle::CustomOp3 for FlashAttn {
 ///
 /// * `q: (batch_size, seq_len_q, num_heads_q, head_dim).
 /// * `k_cache`: (num_blocks, page_block_size, num_heads_k, head_dim).
-/// * `v_cache`: (num_blocks, page_block_size, num_heads_k, head_dim_v).
 /// * `block_table`: (batch_size, max_num_blocks_per_seq), i32.
 /// * `cache_seqlens`: (batch_size), i32
 /// * `softmax_scale: The scale of QK^T before applying softmax.
+/// * `head_size_v`: v_head_dim in the config
 ///
 /// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size_v)`.
 pub fn flash_attn_mla(
     q: &Tensor,
-    k_cache: &Tensor,
-    v_cache: &Tensor,
+    k_c_k_pe_cache: &Tensor,
     block_table: Tensor,
     cache_seqlens: Tensor,
     softmax_scale: f32,
+    head_size_v: usize,
 ) -> Result<Tensor> {
     let op = FlashAttn {
         softmax_scale,
         block_table,
         cache_seqlens,
+        head_size_v
     };
 
     let (b_sz, seqlen_q_ori, num_heads, head_size) = q.shape().dims4()?;
-    let (_, _, _, head_size_v) = v_cache.shape().dims4()?;
 
-    let num_heads_k = k_cache.dim(2)?;
+    let num_heads_k = k_c_k_pe_cache.dim(2)?;
     let ngroups = num_heads / num_heads_k;
 
     let seqlen_q = seqlen_q_ori * ngroups;
@@ -308,7 +304,7 @@ pub fn flash_attn_mla(
         .transpose(2, 3)?
         .reshape((b_sz, seqlen_q, num_heads_k, head_size))?;
 
-    let out = q.apply_op3(k_cache, v_cache, op)?;
+    let out = q.apply_op2(k_c_k_pe_cache, op)?;
 
     out.reshape((b_sz, seqlen_q_ori, ngroups, num_heads_k, head_size_v))?
         .transpose(2, 3)?
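
For reference, a minimal sketch of how the updated single-cache entry point might be called after this change. The crate paths (candle for the core crate, candle_flash_mla for this one), the tensor sizes, and the use of DType::I32 for the index tensors are illustrative assumptions rather than details taken from the commit; block_table and cache_seqlens are placeholder zeros that would need real paged-cache metadata before use.

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;

    // Illustrative MLA-style sizes (assumed, not from the commit):
    // head_dim is the width of the concatenated k_c / k_pe entries,
    // head_size_v is the value head width reported back to the caller.
    let (b_sz, seq_len_q, num_heads_q) = (1, 1, 16);
    let (num_blocks, page_block_size, num_heads_k) = (4, 64, 1);
    let (head_dim, head_size_v) = (576, 512);

    // Query plus the single paged cache that now replaces the separate k/v caches.
    let q = Tensor::zeros((b_sz, seq_len_q, num_heads_q, head_dim), DType::BF16, &dev)?;
    let k_c_k_pe_cache = Tensor::zeros(
        (num_blocks, page_block_size, num_heads_k, head_dim),
        DType::BF16,
        &dev,
    )?;

    // Paged-cache bookkeeping. The doc comment above asks for i32 tensors, which
    // assumes the candle build in use exposes DType::I32.
    let block_table = Tensor::zeros((b_sz, num_blocks), DType::I32, &dev)?;
    let cache_seqlens = Tensor::zeros((b_sz,), DType::I32, &dev)?;

    let softmax_scale = 1.0 / (head_dim as f32).sqrt();
    let out = candle_flash_mla::flash_attn_mla(
        &q,
        &k_c_k_pe_cache,
        block_table,
        cache_seqlens,
        softmax_scale,
        head_size_v,
    )?;
    assert_eq!(out.dims(), &[b_sz, seq_len_q, num_heads_q, head_size_v]);
    Ok(())
}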