Fix passing head_size_v
EricLBuehler committed Feb 26, 2025
1 parent 782f066 commit 3b3cd2f
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions candle-flash-mla/src/lib.rs
@@ -71,7 +71,6 @@ impl FlashAttn {
}

let (b_sz, seqlen_q, num_heads, head_size_q) = q_l.shape().dims4()?;
- let (_, _, _, head_size_v) = v_l.shape().dims4()?;

let max_num_blocks_per_seq = self.block_table.dim(1)?;
let num_blocks = k_c_k_pe_cache_l.dim(0)?;
@@ -81,7 +80,7 @@ impl FlashAttn {
if head_size_q % 8 != 0 {
candle::bail!("only supports q/k head sizes that are a multiple of 8")
}
- if head_size_v % 32 != 0 {
+ if self.head_size_v % 32 != 0 {
candle::bail!("only supports v head sizes that are a multiple of 32")
}
if num_heads % num_heads_k != 0 {
@@ -124,10 +123,10 @@ impl FlashAttn {
k_c_k_pe_cache_l.dims()
);
}
- if v_l.dims() != &[num_blocks, page_block_size, num_heads_k, head_size_v] {
+ if v_l.dims() != &[num_blocks, page_block_size, num_heads_k, self.head_size_v] {
candle::bail!(
"Expected k shape {:?}, got {:?} instead.",
- [num_blocks, page_block_size, num_heads_k, head_size_v],
+ [num_blocks, page_block_size, num_heads_k, self.head_size_v],
v_l.dims()
);
}
@@ -172,18 +171,28 @@ impl FlashAttn {
);
}

- let dst = unsafe { dev.alloc::<bf16>(b_sz * seqlen_q * num_heads * head_size_v) }.w()?;
+ let dst =
+     unsafe { dev.alloc::<bf16>(b_sz * seqlen_q * num_heads * self.head_size_v) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;

let dst_accum = unsafe {
- dev.alloc::<bf16>((b_sz + num_sm_parts) * seqlen_q * num_heads * head_size_v)
+ dev.alloc::<bf16>((b_sz + num_sm_parts) * seqlen_q * num_heads * self.head_size_v)
}
.w()?;
let softmax_lse_accum = dev
.alloc_zeros::<f32>((b_sz + num_sm_parts) * num_heads * seqlen_q)
.w()?;

- assert_eq!(head_size_q, 576);
+ // Expect:
+ if head_size_q != 576 {
+     candle::bail!("Expected head_size_q to be 576, got {head_size_q}");
+ }
+ if self.head_size_v != 512 {
+     candle::bail!("Expected head_size_v to be 512, got {}", self.head_size_v);
+ }
+ if page_block_size != 64 {
+     candle::bail!("Expected page_block_size to be 64, got {page_block_size}");
+ }

let params = ffi::FlashFwdMlaParams {
b: b_sz as i32,
@@ -194,7 +203,7 @@
ngroups: ngroups as i32,
is_causal,
d: head_size_q as i32,
- d_v: head_size_v as i32,
+ d_v: self.head_size_v as i32,
scale_softmax: self.softmax_scale,
scale_softmax_log2: self.softmax_scale * f32::consts::LOG2_E,
q_ptr: (*q.device_ptr()) as *mut core::ffi::c_void,
@@ -255,7 +264,7 @@ impl candle::CustomOp2 for FlashAttn {
k_c_k_pe_cache_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
match q.dtype() {
- candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k_c_k_pe_cache, k_c_k_pe_cache_l,),
+ candle::DType::BF16 => {
+     self.cuda_fwd_t::<bf16>(q, q_l, k_c_k_pe_cache, k_c_k_pe_cache_l)
+ }
dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
}
}
@@ -289,7 +300,7 @@ pub fn flash_attn_mla(
softmax_scale,
block_table,
cache_seqlens,
- head_size_v
+ head_size_v,
};

let (b_sz, seqlen_q_ori, num_heads, head_size) = q.shape().dims4()?;
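Net effect of the diff: head_size_v is no longer re-derived from a v tensor inside cuda_fwd_t; it is carried as a field on FlashAttn, threaded through to ffi::FlashFwdMlaParams as d_v, and the old assert_eq! is replaced with explicit bails. A minimal standalone sketch of those hard-coded geometry checks follows; the function name and plain String error are illustrative only, not part of the crate, while the 576/512/64 limits come straight from the diff above:

    // Sketch of the dimension validation this commit introduces.
    // The concrete limits (576 / 512 / 64) mirror the bails in the diff;
    // candle::bail! is replaced by a plain Result so the snippet compiles alone.
    fn check_mla_dims(
        head_size_q: usize,
        head_size_v: usize,
        page_block_size: usize,
    ) -> Result<(), String> {
        if head_size_q % 8 != 0 {
            return Err("only supports q/k head sizes that are a multiple of 8".into());
        }
        if head_size_v % 32 != 0 {
            return Err("only supports v head sizes that are a multiple of 32".into());
        }
        // The MLA kernel is compiled for one fixed geometry.
        if head_size_q != 576 {
            return Err(format!("Expected head_size_q to be 576, got {head_size_q}"));
        }
        if head_size_v != 512 {
            return Err(format!("Expected head_size_v to be 512, got {head_size_v}"));
        }
        if page_block_size != 64 {
            return Err(format!("Expected page_block_size to be 64, got {page_block_size}"));
        }
        Ok(())
    }

    fn main() {
        // DeepSeek-style MLA geometry accepted by the kernel.
        assert!(check_mla_dims(576, 512, 64).is_ok());
        // A v head size that is not a multiple of 32 is rejected early.
        assert!(check_mla_dims(576, 500, 64).is_err());
    }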
