Skip to content

Commit

Permalink
out-accum should be f32
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Feb 26, 2025
1 parent d8101c2 commit 294dcfd
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions candle-flash-mla/src/lib.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,8 +32,8 @@ impl FlashAttn {
let out_shape = Shape::from_dims(&[b_sz, seqlen_q, num_heads, self.head_size_v]);
let out_l = Layout::contiguous(&out_shape);

- let q = q.as_cuda_slice::<bf16>()?;
- let k_c_k_pe_cache = k_c_k_pe_cache.as_cuda_slice::<bf16>()?;
+ let q = q.as_cuda_slice::<T>()?;
+ let k_c_k_pe_cache = k_c_k_pe_cache.as_cuda_slice::<T>()?;
let q = q.slice(q_l.start_offset()..);
let k_c_k_pe_cache = k_c_k_pe_cache.slice(k_c_k_pe_cache_l.start_offset()..);

Expand Down Expand Up @@ -165,11 +165,11 @@ impl FlashAttn {
}

let dst =
- unsafe { dev.alloc::<bf16>(b_sz * seqlen_q * num_heads * self.head_size_v) }.w()?;
+ unsafe { dev.alloc::<T>(b_sz * seqlen_q * num_heads * self.head_size_v) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;

let dst_accum = unsafe {
- dev.alloc::<bf16>((b_sz + num_sm_parts) * seqlen_q * num_heads * self.head_size_v)
+ dev.alloc::<f32>((b_sz + num_sm_parts) * seqlen_q * num_heads * self.head_size_v)
}
.w()?;
let softmax_lse_accum = dev
Expand Down Expand Up @@ -260,7 +260,7 @@ impl candle::CustomOp2 for FlashAttn {
candle::DType::BF16 => {
self.cuda_fwd_t::<bf16>(q, q_l, k_c_k_pe_cache, k_c_k_pe_cache_l)
}
- dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+ dt => candle::bail!("flash-mla is only supported for bf16 ({dt:?})"),
}
}
}
Expand Down

0 comments on commit 294dcfd

Please sign in to comment.