Commit 72cb652
Some fixes
EricLBuehler committed Mar 4, 2025
1 parent 66e7c17 commit 72cb652
Showing 4 changed files with 25 additions and 36 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -10,11 +10,11 @@ members = [
"candle-wasm-examples/*",
"candle-wasm-tests",
"tensor-tools",
- "candle-flash-mla",
]
exclude = [
"candle-flash-attn",
"candle-flash-attn-v3",
+ "candle-flash-mla",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
1 change: 0 additions & 1 deletion candle-examples/Cargo.toml
@@ -16,7 +16,6 @@ candle-datasets = { workspace = true, optional = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
- candle-flash-mla = { workspace = true }
candle-onnx = { workspace = true, optional = true }

csv = "1.3.0"
24 changes: 13 additions & 11 deletions candle-flash-mla/src/lib.rs
@@ -14,6 +14,8 @@ pub struct FlashAttn {
pub block_table: Tensor,
pub cache_seqlens: Tensor,
pub head_size_v: usize,
+ pub seqlen_q_ori: usize,
+ pub ngroups: usize,
}

impl FlashAttn {
@@ -102,10 +104,8 @@ impl FlashAttn {
.as_cuda_slice::<i32>()?
.slice(self.cache_seqlens.layout().start_offset()..);

- let is_causal = if seqlen_q == 1 { false } else { true };
+ let is_causal = if self.seqlen_q_ori == 1 { false } else { true };

- let ngroups = num_heads / num_heads_k;
- let seqlen_q = seqlen_q * ngroups;
let num_heads = num_heads_k;
let head_size_k = head_size_q;

@@ -192,7 +192,7 @@ impl FlashAttn {
cu_seqlens_k: (*cache_seqlens.device_ptr()) as *mut core::ffi::c_int,
h: num_heads as i32,
h_h_k_ratio: num_heads_per_head_k as i32,
- ngroups: ngroups as i32,
+ ngroups: self.ngroups as i32,
is_causal,
d: head_size_q as i32,
d_v: self.head_size_v as i32,
@@ -288,13 +288,6 @@ pub fn flash_attn_mla(
softmax_scale: f32,
head_size_v: usize,
) -> Result<Tensor> {
- let op = FlashAttn {
-     softmax_scale,
-     block_table,
-     cache_seqlens,
-     head_size_v,
- };
-
let (b_sz, seqlen_q_ori, num_heads, head_size) = q.shape().dims4()?;

let num_heads_k = k_c_k_pe_cache.dim(2)?;
@@ -307,6 +300,15 @@
.transpose(2, 3)?
.reshape((b_sz, seqlen_q, num_heads_k, head_size))?;

+ let op = FlashAttn {
+     softmax_scale,
+     block_table,
+     cache_seqlens,
+     head_size_v,
+     seqlen_q_ori,
+     ngroups
+ };

let out = q.apply_op2(k_c_k_pe_cache, op)?;

out.reshape((b_sz, seqlen_q_ori, ngroups, num_heads_k, head_size_v))?
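Note on the lib.rs hunks above: `flash_attn_mla` folds the query heads that share a KV head into the sequence dimension before launching the kernel (the `transpose(2, 3)?` / `reshape(...)` visible in the final hunk of this file), so inside the op `seqlen_q` is already `seqlen_q_ori * ngroups`. Carrying `seqlen_q_ori` and `ngroups` on `FlashAttn` lets the causality check and the kernel's `ngroups` parameter use the original, pre-folding values instead of re-deriving them after the reshape. Below is a minimal sketch of that folding step, assuming `candle` is the `candle-core` crate as aliased in this workspace; `fold_query_heads` is an illustrative helper, not an API of `candle-flash-mla`, and the exact intermediate shape used by the crate may differ.

```rust
use candle::{DType, Device, Result, Tensor};

/// Illustrative helper mirroring the head-folding idea in `flash_attn_mla`:
/// (b, s_q, h_q, d) -> (b, s_q * ngroups, h_kv, d), with ngroups = h_q / h_kv.
fn fold_query_heads(q: &Tensor, h_kv: usize) -> Result<(Tensor, usize)> {
    let (b, s_q, h_q, d) = q.shape().dims4()?;
    let ngroups = h_q / h_kv;
    let folded = q
        .reshape((b, s_q, h_kv, ngroups, d))? // split query heads per KV head
        .transpose(2, 3)? // (b, s_q, ngroups, h_kv, d)
        .reshape((b, s_q * ngroups, h_kv, d))?; // fold the groups into the sequence
    Ok((folded, ngroups))
}

fn main() -> Result<()> {
    // MLA-style decode shapes: 1 latent KV head, 16 query heads, head size 576.
    let q = Tensor::zeros((2, 1, 16, 576), DType::F32, &Device::Cpu)?;
    let (folded, ngroups) = fold_query_heads(&q, 1)?;
    assert_eq!(folded.dims(), &[2, 16, 1, 576]);
    assert_eq!(ngroups, 16);
    Ok(())
}
```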
34 changes: 11 additions & 23 deletions candle-flash-mla/tests/flash_mla_tests.rs
@@ -3,20 +3,6 @@ use candle::{DType, Device, IndexOp, Tensor, D};
use candle_flash_mla;
use rstest::rstest;

- fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
-     let b = 10f32.powi(digits);
-     let t = t.to_vec3::<f32>()?;
-     let t = t
-         .iter()
-         .map(|t| {
-             t.iter()
-                 .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
-                 .collect()
-         })
-         .collect();
-     Ok(t)
- }
-
fn sdpa(
q: &Tensor,
k: &Tensor,
@@ -47,6 +33,7 @@ fn sdpa(
s_q => [1, 2], // MTP = 1, 2
)]
fn flash_mla_param(b: usize, s_k: usize, h_q: usize, s_q: usize) -> Result<()> {
+ dbg!(b, s_k, h_q, s_q);
let device = Device::new_cuda(0)?;

let h_kv = 1;
@@ -55,7 +42,6 @@ fn flash_mla_param(b: usize, s_k: usize, h_q: usize, s_q: usize) -> Result<()> {

let cache_seqlens_vec = vec![s_k as i32; b];
let cache_seqlens = Tensor::new(cache_seqlens_vec.clone(), &device)?;
- let total_seqlens = cache_seqlens.sum_all()?.to_scalar::<i32>()? as usize;
let max_seqlen = cache_seqlens.max(0)?.to_scalar::<i32>()? as usize;
let max_seqlen_pad = max_seqlen.div_ceil(256) * 256;

@@ -107,14 +93,16 @@ fn flash_mla_param(b: usize, s_k: usize, h_q: usize, s_q: usize) -> Result<()> {
};

assert_eq!(out_flash.dims(), truth.dims());
- println!(
-     "MLA {}; TRUTH {}",
-     out_flash
-         .to_dtype(DType::F32)?
-         .mean_all()?
-         .to_scalar::<f32>()?,
-     truth.to_dtype(DType::F32)?.mean_all()?.to_scalar::<f32>()?
- );
+
+ let cos_diff = 1.
+     - 2. * (out_flash.to_dtype(DType::F32)? * truth.to_dtype(DType::F32)?)?
+         .sum_all()?
+         .to_scalar::<f32>()?
+         / (out_flash.sqr()?.to_dtype(DType::F32)? + truth.sqr()?.to_dtype(DType::F32)?)?
+             .sum_all()?
+             .to_scalar::<f32>()?
+             .max(1e-12);
+ assert!(cos_diff < 1e-5, "{cos_diff}");

Ok(())
}
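Note on the test change above: the printed means are replaced by a single scalar check, `cos_diff = 1 - 2 * Σ(out·truth) / max(Σ out² + Σ truth², 1e-12)`, which is 0 for identical tensors and stays near 0 when the flash output closely matches the SDPA reference. A standalone sketch of the same formula on plain `f32` slices (the free function and the sample values are illustrative, not part of the test crate):

```rust
/// Same scalar distance as the updated test assertion, on plain slices:
/// 1 - 2 * sum(x * y) / max(sum(x^2) + sum(y^2), 1e-12); equals 0.0 when x == y.
fn cos_diff(x: &[f32], y: &[f32]) -> f32 {
    let dot: f32 = x.iter().zip(y).map(|(a, b)| a * b).sum();
    let norms = x.iter().map(|a| a * a).sum::<f32>() + y.iter().map(|b| b * b).sum::<f32>();
    1.0 - 2.0 * dot / norms.max(1e-12)
}

fn main() {
    let truth = [0.25_f32, -1.0, 0.5, 2.0];
    let approx = [0.250001_f32, -1.0, 0.5, 2.0];
    assert!(cos_diff(&truth, &truth).abs() < 1e-7); // identical inputs -> ~0
    assert!(cos_diff(&truth, &approx) < 1e-5); // a small perturbation stays below the test threshold
    println!("cos_diff = {:e}", cos_diff(&truth, &approx));
}
```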
