Maybe some progress...
EricLBuehler committed Mar 4, 2025
1 parent bb391ee commit 3a683a5
Showing 2 changed files with 22 additions and 15 deletions.
23 changes: 13 additions & 10 deletions candle-flash-mla/src/lib.rs
@@ -16,6 +16,7 @@ pub struct FlashAttn {
pub head_size_v: usize,
pub seqlen_q_ori: usize,
pub ngroups: usize,
+pub num_heads_per_head_k: usize,
}

impl FlashAttn {
@@ -53,7 +54,7 @@ impl FlashAttn {

if q_rank != 4 || k_rank != 4 || v_rank != 4 {
candle::bail!(
"flash-attn expects input tensors of rank 4 (q: {q_rank}, k: {k_rank}, v: {v_rank}"
"flash-attn expects input tensors of rank 4 (q: {q_rank}, k: {k_rank}, v: {v_rank})"
)
}
if q_stride[q_rank - 1] != 1 {
@@ -138,20 +139,20 @@ impl FlashAttn {
);
}

-let num_heads_per_head_k = num_heads / num_heads_k;
-
// This should match the logic in the MLA kernel.
let block_size_m = 64usize;
let sm_count = dev
.attribute(
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
)
.w()? as usize;
-let num_sm_parts = sm_count / num_heads_k / num_heads_per_head_k.div_ceil(block_size_m);
+let num_sm_parts = sm_count
+    / num_heads_k
+    / (self.seqlen_q_ori * self.num_heads_per_head_k).div_ceil(block_size_m);

let tile_scheduler_metadata =
unsafe { dev.alloc::<i32>(num_sm_parts * ffi::TILE_SCHEDULER_METADATA_SIZE) }.w()?;
-let num_splits = dev.alloc_zeros::<i32>(b_sz + 1).w()?;
+let num_splits = unsafe { dev.alloc::<i32>(b_sz + 1) }.w()?;

unsafe {
ffi::get_mla_metadata(
@@ -165,15 +166,14 @@
}

let dst = unsafe { dev.alloc::<T>(b_sz * seqlen_q * num_heads * self.head_size_v) }.w()?;
-let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+let softmax_lse = unsafe { dev.alloc::<f32>(b_sz * num_heads * seqlen_q) }.w()?;

let dst_accum = unsafe {
dev.alloc::<f32>((b_sz + num_sm_parts) * seqlen_q * num_heads * self.head_size_v)
}
.w()?;
-let softmax_lse_accum = dev
-    .alloc_zeros::<f32>((b_sz + num_sm_parts) * num_heads * seqlen_q)
-    .w()?;
+let softmax_lse_accum =
+    unsafe { dev.alloc::<f32>((b_sz + num_sm_parts) * num_heads * seqlen_q) }.w()?;

// Expect:
if head_size_q != 576 {
@@ -191,7 +191,7 @@
seqlen_q: seqlen_q as i32,
cu_seqlens_k: (*cache_seqlens.device_ptr()) as *mut core::ffi::c_int,
h: num_heads as i32,
-h_h_k_ratio: num_heads_per_head_k as i32,
+h_h_k_ratio: (num_heads / num_heads_k) as i32,
ngroups: self.ngroups as i32,
is_causal,
d: head_size_q as i32,
@@ -294,6 +294,8 @@ pub fn flash_attn_mla(
let ngroups = num_heads / num_heads_k;

let seqlen_q = seqlen_q_ori * ngroups;
+let num_heads_per_head_k = num_heads / num_heads_k;
+dbg!(num_heads_per_head_k);

let q = q
.reshape((b_sz, seqlen_q_ori, num_heads_k, ngroups, head_size))?
@@ -307,6 +309,7 @@ pub fn flash_attn_mla(
head_size_v,
seqlen_q_ori,
ngroups,
+num_heads_per_head_k,
};

let out = q.apply_op2(k_c_k_pe_cache, op)?;
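For intuition, the revised num_sm_parts arithmetic above can be read in isolation. The following is a minimal standalone sketch under assumed values: an sm_count of 132 (an H100 figure used purely for illustration; the real value comes from the CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT query) and a typical MLA decode shape with one K/V head, 128 query heads per K head, and s_q = 1.

// Standalone sketch of the new SM-partitioning arithmetic (assumed inputs).
fn num_sm_parts(
    sm_count: usize,
    num_heads_k: usize,
    seqlen_q_ori: usize,
    num_heads_per_head_k: usize,
    block_size_m: usize,
) -> usize {
    sm_count / num_heads_k / (seqlen_q_ori * num_heads_per_head_k).div_ceil(block_size_m)
}

fn main() {
    // (1 * 128).div_ceil(64) = 2, so 132 / 1 / 2 = 66 SM partitions.
    assert_eq!(num_sm_parts(132, 1, 1, 128, 64), 66);
}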
14 changes: 9 additions & 5 deletions candle-flash-mla/tests/flash_mla_tests.rs
@@ -44,10 +44,14 @@ fn sdpa(
}

#[rstest(
-b => [128],
-s_k => [4096, 8192],
-h_q => [16, 32, 64, 128], // TP = 8, 4, 2, 1
-s_q => [1, 2], // MTP = 1, 2
+// b => [128],
+b => [1],
+// s_k => [4096, 8192],
+s_k => [4096],
+// h_q => [16, 32, 64, 128], // TP = 8, 4, 2, 1
+h_q => [128],
+// s_q => [1, 2], // MTP = 1, 2
+s_q => [1], // MTP = 1, 2
)]
fn flash_mla_param(b: usize, s_k: usize, h_q: usize, s_q: usize) -> Result<()> {
dbg!(b, s_k, h_q, s_q);
@@ -119,7 +123,7 @@ fn flash_mla_param(b: usize, s_k: usize, h_q: usize, s_q: usize) -> Result<()> {
.sum_all()?
.to_scalar::<f32>()?
.max(1e-12);
-assert!(cos_diff < 1e-5, "{cos_diff}");
+assert!(cos_diff < 1e-5, "{cos_diff} > {}", 1e-5);

Ok(())
}
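For reference, the cos_diff quantity asserted above is a cosine-style difference between the flash output and the SDPA reference computed earlier in the test. A minimal sketch of the metric, assuming it takes the usual FlashMLA form 1 - 2 * dot(x, y) / (|x|^2 + |y|^2) with the same 1e-12 floor visible above (the actual candle implementation sits above this hunk):

// Hedged sketch of the assumed error metric; identical tensors give 0.0.
fn cos_diff(x: &[f32], y: &[f32]) -> f32 {
    let dot: f32 = x.iter().zip(y).map(|(a, b)| a * b).sum();
    let norms: f32 = x.iter().map(|a| a * a).sum::<f32>() + y.iter().map(|b| b * b).sum::<f32>();
    1.0 - 2.0 * dot / norms.max(1e-12)
}

With the rstest parameters pinned to a single case, the suite can be run on a CUDA-capable machine with cargo test -p candle-flash-mla --test flash_mla_tests -- --nocapture so the added dbg! output is visible (the command assumes the workspace layout shown in the file paths above).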
