diff --git a/Cargo.toml b/Cargo.toml
index 40bbeb7ed6..1dfda2ed5e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,11 +10,11 @@ members = [
     "candle-wasm-examples/*",
     "candle-wasm-tests",
     "tensor-tools",
+    "candle-flash-mla",
 ]
 exclude = [
     "candle-flash-attn",
     "candle-flash-attn-v3",
-    "candle-flash-mla",
     "candle-kernels",
     "candle-metal-kernels",
     "candle-onnx",
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index c9f00c0d34..0c1219d760 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -16,7 +16,6 @@ candle-datasets = { workspace = true, optional = true }
 candle-nn = { workspace = true }
 candle-transformers = { workspace = true }
 candle-flash-attn = { workspace = true, optional = true }
-candle-flash-mla = { workspace = true }
 candle-onnx = { workspace = true, optional = true }
 
 csv = "1.3.0"
diff --git a/candle-flash-mla/src/lib.rs b/candle-flash-mla/src/lib.rs
index 68aac30593..f558f717fc 100644
--- a/candle-flash-mla/src/lib.rs
+++ b/candle-flash-mla/src/lib.rs
@@ -14,6 +14,8 @@ pub struct FlashAttn {
     pub block_table: Tensor,
     pub cache_seqlens: Tensor,
     pub head_size_v: usize,
+    pub seqlen_q_ori: usize,
+    pub ngroups: usize,
 }
 
 impl FlashAttn {
@@ -102,10 +104,8 @@ impl FlashAttn {
             .as_cuda_slice::<i32>()?
             .slice(self.cache_seqlens.layout().start_offset()..);
 
-        let is_causal = if seqlen_q == 1 { false } else { true };
+        let is_causal = if self.seqlen_q_ori == 1 { false } else { true };
 
-        let ngroups = num_heads / num_heads_k;
-        let seqlen_q = seqlen_q * ngroups;
         let num_heads = num_heads_k;
 
         let head_size_k = head_size_q;
@@ -192,7 +192,7 @@
             cu_seqlens_k: (*cache_seqlens.device_ptr()) as *mut core::ffi::c_int,
             h: num_heads as i32,
             h_h_k_ratio: num_heads_per_head_k as i32,
-            ngroups: ngroups as i32,
+            ngroups: self.ngroups as i32,
             is_causal,
             d: head_size_q as i32,
             d_v: self.head_size_v as i32,
@@ -288,13 +288,6 @@ pub fn flash_attn_mla(
     softmax_scale: f32,
     head_size_v: usize,
 ) -> Result<Tensor> {
-    let op = FlashAttn {
-        softmax_scale,
-        block_table,
-        cache_seqlens,
-        head_size_v,
-    };
-
     let (b_sz, seqlen_q_ori, num_heads, head_size) = q.shape().dims4()?;
     let num_heads_k = k_c_k_pe_cache.dim(2)?;
 
@@ -307,6 +300,15 @@ pub fn flash_attn_mla(
         .transpose(2, 3)?
         .reshape((b_sz, seqlen_q, num_heads_k, head_size))?;
 
+    let op = FlashAttn {
+        softmax_scale,
+        block_table,
+        cache_seqlens,
+        head_size_v,
+        seqlen_q_ori,
+        ngroups
+    };
+
     let out = q.apply_op2(k_c_k_pe_cache, op)?;
 
     out.reshape((b_sz, seqlen_q_ori, ngroups, num_heads_k, head_size_v))?
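Not part of the patch, added as a reading aid: after this change, flash_attn_mla supplies seqlen_q_ori and ngroups to the op instead of letting the op re-derive them from its already reshaped inputs, and causality is keyed off the original query length. The sketch below walks through the shape bookkeeping with hypothetical dimensions, under the assumption (not visible in the hunks) that the wrapper computes ngroups as num_heads / num_heads_k before the q reshape.

fn main() {
    // Hypothetical query layout: (batch, original query length, query heads, head dim).
    let (b_sz, seqlen_q_ori, num_heads, head_size) = (2usize, 2usize, 16usize, 576usize);
    let num_heads_k = 1usize; // single latent KV head, as in the MLA tests (h_kv = 1)

    // Assumed wrapper-side bookkeeping: fold the query-head groups into the
    // sequence dimension before handing q to the kernel.
    let ngroups = num_heads / num_heads_k;
    let seqlen_q = seqlen_q_ori * ngroups;

    // Shape passed via q.reshape((b_sz, seqlen_q, num_heads_k, head_size)) in the wrapper.
    assert_eq!((b_sz, seqlen_q, num_heads_k, head_size), (2, 32, 1, 576));

    // With the new struct fields, the op no longer infers these from the folded
    // tensor: is_causal is decided from the original, un-folded query length.
    let is_causal = seqlen_q_ori != 1;
    assert!(is_causal);

    // The output is un-folded again at the end, matching
    // out.reshape((b_sz, seqlen_q_ori, ngroups, num_heads_k, head_size_v)).
}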
diff --git a/candle-flash-mla/tests/flash_mla_tests.rs b/candle-flash-mla/tests/flash_mla_tests.rs
index fd62177415..ba3e24fb57 100644
--- a/candle-flash-mla/tests/flash_mla_tests.rs
+++ b/candle-flash-mla/tests/flash_mla_tests.rs
@@ -3,20 +3,6 @@ use candle::{DType, Device, IndexOp, Tensor, D};
 use candle_flash_mla;
 use rstest::rstest;
 
-fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
-    let b = 10f32.powi(digits);
-    let t = t.to_vec3::<f32>()?;
-    let t = t
-        .iter()
-        .map(|t| {
-            t.iter()
-                .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
-                .collect()
-        })
-        .collect();
-    Ok(t)
-}
-
 fn sdpa(
     q: &Tensor,
     k: &Tensor,
@@ -47,6 +33,7 @@
     s_q => [1, 2], // MTP = 1, 2
 )]
 fn flash_mla_param(b: usize, s_k: usize, h_q: usize, s_q: usize) -> Result<()> {
+    dbg!(b, s_k, h_q, s_q);
     let device = Device::new_cuda(0)?;
 
     let h_kv = 1;
@@ -55,7 +42,6 @@ fn flash_mla_param(b: usize, s_k: usize, h_q: usize, s_q: usize) -> Result<()> {
 
     let cache_seqlens_vec = vec![s_k as i32; b];
     let cache_seqlens = Tensor::new(cache_seqlens_vec.clone(), &device)?;
-    let total_seqlens = cache_seqlens.sum_all()?.to_scalar::<i32>()? as usize;
     let max_seqlen = cache_seqlens.max(0)?.to_scalar::<i32>()? as usize;
     let max_seqlen_pad = max_seqlen.div_ceil(256) * 256;
 
@@ -107,14 +93,16 @@ fn flash_mla_param(b: usize, s_k: usize, h_q: usize, s_q: usize) -> Result<()> {
     };
 
     assert_eq!(out_flash.dims(), truth.dims());
-    println!(
-        "MLA {}; TRUTH {}",
-        out_flash
-            .to_dtype(DType::F32)?
-            .mean_all()?
-            .to_scalar::<f32>()?,
-        truth.to_dtype(DType::F32)?.mean_all()?.to_scalar::<f32>()?
-    );
+
+    let cos_diff = 1.
+        - 2. * (out_flash.to_dtype(DType::F32)? * truth.to_dtype(DType::F32)?)?
+            .sum_all()?
+            .to_scalar::<f32>()?
+            / (out_flash.sqr()?.to_dtype(DType::F32)? + truth.sqr()?.to_dtype(DType::F32)?)?
+                .sum_all()?
+                .to_scalar::<f32>()?
+                .max(1e-12);
+    assert!(cos_diff < 1e-5, "{cos_diff}");
 
     Ok(())
 }
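Again not part of the patch: the new assertion replaces the old mean printout with a cosine-style difference, 1 - 2*dot(a, b) / (|a|^2 + |b|^2), which is zero when the flash output equals the reference and stays well-defined thanks to the 1e-12 floor. Below is a small self-contained sketch of the same metric over plain f32 slices; the function name and the data are illustrative, not taken from the test.

fn cos_diff(a: &[f32], b: &[f32]) -> f32 {
    // 1 - 2*<a, b> / (|a|^2 + |b|^2) equals |a - b|^2 / (|a|^2 + |b|^2),
    // so it is zero exactly when a == b; the 1e-12 floor mirrors the test's
    // `.max(1e-12)` guard against dividing by zero.
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norms: f32 = a.iter().map(|x| x * x).sum::<f32>() + b.iter().map(|y| y * y).sum::<f32>();
    1. - 2. * dot / norms.max(1e-12)
}

fn main() {
    // `a` plays the role of out_flash, `b` the role of truth.
    let a = [0.25_f32, -1.5, 3.0, 0.0];
    let b = [0.25_f32, -1.5, 3.0, 0.0];
    assert!(cos_diff(&a, &b) < 1e-5);

    // A clearly perturbed output fails the same 1e-5 tolerance.
    let b = [0.25_f32, -1.5, 3.0, 0.5];
    assert!(cos_diff(&a, &b) > 1e-5);
}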