Fix reshape
EricLBuehler committed Feb 25, 2025
1 parent 45c9163 commit 117286a
Showing 2 changed files with 9 additions and 9 deletions.
candle-flash-mla/build.rs (1 addition, 4 deletions)
```diff
@@ -6,10 +6,7 @@ use std::str::FromStr;
 
 const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS");
 
-const KERNEL_FILES: &[&str] = &[
-    "flash_api.cu",
-    "flash_fwd_mla_bf16_sm90.cu",
-];
+const KERNEL_FILES: &[&str] = &["flash_api.cu", "flash_fwd_mla_bf16_sm90.cu"];
 
 fn main() -> Result<()> {
     // Use RAYON_NUM_THREADS or else default to the number of physical CPUs
```
candle-flash-mla/src/lib.rs (8 additions, 5 deletions)
```diff
@@ -172,11 +172,14 @@ impl FlashAttn {
             );
         }
 
-        let dst = unsafe {
+        let dst = unsafe { dev.alloc::<bf16>(b_sz * seqlen_q * num_heads * head_size_v) }.w()?;
+        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+
+        let dst_accum = unsafe {
             dev.alloc::<bf16>((b_sz + num_sm_parts) * seqlen_q * num_heads * head_size_v)
         }
         .w()?;
-        let softmax_lse = dev
+        let softmax_lse_accum = dev
             .alloc_zeros::<f32>((b_sz + num_sm_parts) * num_heads * seqlen_q)
             .w()?;
 
```
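
For context on the allocation change above: `dst` and `softmax_lse` now hold only the final output, while the new `dst_accum` and `softmax_lse_accum` buffers (which, judging by the `oaccum`/`lseaccum` parameter names below, receive per-SM-partition partial results) keep the larger `(b_sz + num_sm_parts)`-row sizing. A minimal standalone sketch of the two sizing rules; the helper `mla_buffer_elems` and all dimensions are made up for illustration and are not part of the crate:

```rust
// Hypothetical helper, not part of candle-flash-mla: it only restates the
// two element counts used in the allocations above.
fn mla_buffer_elems(
    b_sz: usize,
    seqlen_q: usize,
    num_heads: usize,
    head_size_v: usize,
    num_sm_parts: usize,
) -> (usize, usize) {
    // Final output: one row per (batch, query position, head).
    let dst = b_sz * seqlen_q * num_heads * head_size_v;
    // Accumulator: extra rows proportional to the number of SM partitions.
    let dst_accum = (b_sz + num_sm_parts) * seqlen_q * num_heads * head_size_v;
    (dst, dst_accum)
}

fn main() {
    // Made-up dimensions, purely for illustration.
    let (dst, dst_accum) = mla_buffer_elems(2, 64, 16, 512, 4);
    assert!(dst_accum > dst);
    println!("dst: {dst} elements, dst_accum: {dst_accum} elements");
}
```
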
```diff
@@ -218,8 +221,8 @@ impl FlashAttn {
                 as *mut core::ffi::c_int,
             num_sm_parts: num_sm_parts as i32,
             num_splits_ptr: (*num_splits.device_ptr()) as *mut core::ffi::c_int,
-            oaccum_ptr: (*dst.device_ptr()) as *mut core::ffi::c_void,
-            softmax_lseaccum_ptr: (*softmax_lse.device_ptr()) as *mut core::ffi::c_void,
+            oaccum_ptr: (*dst_accum.device_ptr()) as *mut core::ffi::c_void,
+            softmax_lseaccum_ptr: (*softmax_lse_accum.device_ptr()) as *mut core::ffi::c_void,
         };
 
         unsafe { ffi::mha_fwd_kvcache_mla(params, *dev.cu_stream()) }
```
```diff
@@ -303,7 +306,7 @@ pub fn flash_attn_mla(
     let q = q
         .reshape((b_sz, seqlen_q_ori, num_heads_k, ngroups, head_size))?
         .transpose(2, 3)?
-        .reshape((b_sz, seqlen_q, num_heads, head_size))?;
+        .reshape((b_sz, seqlen_q, num_heads_k, head_size))?;
 
     let out = q.apply_op3(k_cache, v_cache, op)?;
 
```
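
The reshape fix that names the commit is a shape-bookkeeping correction: once the query-head groups have been folded into the sequence dimension (`seqlen_q = seqlen_q_ori * ngroups`), the head axis carries `num_heads_k` heads, so reshaping to `num_heads` cannot match the element count. A standalone sketch of that arithmetic, assuming `num_heads = num_heads_k * ngroups` as the surrounding code implies (all numbers are made up):

```rust
fn main() {
    // Illustrative MLA-style dimensions only.
    let (b_sz, seqlen_q_ori, num_heads_k, ngroups, head_size) = (1, 8, 1, 16, 576);
    let num_heads = num_heads_k * ngroups; // assumption: query heads split into ngroups per KV head
    let seqlen_q = seqlen_q_ori * ngroups; // groups folded into the sequence dimension

    // Elements in q before the reshape chain.
    let total = b_sz * seqlen_q_ori * num_heads * head_size;

    // Old target (b_sz, seqlen_q, num_heads, head_size): over-counts by a factor of ngroups.
    let old_target = b_sz * seqlen_q * num_heads * head_size;
    // Fixed target (b_sz, seqlen_q, num_heads_k, head_size): matches the element count.
    let new_target = b_sz * seqlen_q * num_heads_k * head_size;

    assert_ne!(old_target, total); // the old reshape could not succeed
    assert_eq!(new_target, total); // the fixed reshape does
}
```
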
