Automatic computation of MLA metadata
EricLBuehler committed Feb 25, 2025
1 parent d16c27f commit ca880a0
Showing 3 changed files with 67 additions and 48 deletions.
6 changes: 3 additions & 3 deletions candle-flash-mla/hkernel/flash_api.cu
@@ -8,18 +8,18 @@
 #include <cuda.h>
 #include <cuda_bf16.h>
 
 #define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))
 
 void get_mla_metadata(
     int32_t* seqlens_k_ptr,
     int32_t* tile_scheduler_metadata_ptr, // [num_sm_parts, TileSchedulerMetaDataSize]
     int32_t* num_splits_ptr, // [batch_size + 1]
-    const int num_heads_per_head_k,
-    const int num_heads_k,
     const int batch_size,
     const int num_sm_parts,
     const cudaStream_t stream
 ) {
     // This should match the logic in the MLA kernel.
-    static constexpr int block_size_m = 64;
+    // static constexpr int block_size_m = 64; MOVED TO lib.rs
     static constexpr int block_size_n = 64;
     static constexpr int fixed_overhead_num_blocks = 5;
2 changes: 0 additions & 2 deletions candle-flash-mla/src/ffi.rs
@@ -54,8 +54,6 @@ extern "C" {
         seqlens_k_ptr: *mut c_int,
         tile_scheduler_metadata_ptr: *mut c_int,
         num_splits_ptr: *mut c_int,
-        num_heads_per_head_k: c_int,
-        num_heads_k: c_int,
         batch_size: c_int,
         num_sm_parts: c_int,
         stream: CUstream,
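With the two head-count parameters gone, the Rust declaration mirrors the trimmed C signature. A minimal sketch of the resulting binding follows; the item placement and the CUstream import path are assumptions, not taken from ffi.rs:

// Sketch only: import paths are assumed; the actual ffi.rs may differ.
use core::ffi::c_int;
use candle::cuda_backend::cudarc::driver::sys::CUstream; // assumed import path for the stream handle

extern "C" {
    pub fn get_mla_metadata(
        seqlens_k_ptr: *mut c_int,
        tile_scheduler_metadata_ptr: *mut c_int, // [num_sm_parts, TileSchedulerMetaDataSize]
        num_splits_ptr: *mut c_int,              // [batch_size + 1]
        batch_size: c_int,
        num_sm_parts: c_int,
        stream: CUstream,
    );
}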
107 changes: 64 additions & 43 deletions candle-flash-mla/src/lib.rs
@@ -3,6 +3,7 @@ mod ffi;
 use std::f32;
 
 use candle::backend::BackendStorage;
+use candle::cuda::cudarc;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
@@ -11,9 +12,7 @@ use half::bf16;
 pub struct FlashAttn {
     pub softmax_scale: f32,
     pub block_table: Tensor,
-    pub tile_scheduler_metadata: Tensor,
-    pub num_splits: Tensor,
-    pub seqlens_k: Tensor,
+    pub cache_seqlens: Tensor,
 }
 
 impl FlashAttn {
@@ -89,42 +88,20 @@ impl FlashAttn {
             candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }
 
-        if self.tile_scheduler_metadata.dim(1)? != ffi::TILE_SCHEDULER_METADATA_SIZE {
-            candle::bail!(
-                "Expected tile scheduler metadata to match {}: {}",
-                ffi::TILE_SCHEDULER_METADATA_SIZE,
-                self.tile_scheduler_metadata.dim(1)?
-            )
-        }
-        let candle::Storage::Cuda(tile_scheduler_metadata) =
-            &*self.tile_scheduler_metadata.storage_and_layout().0
-        else {
-            candle::bail!("tile_scheduler_metadata must be CUDA")
-        };
-        let tile_scheduler_metadata = tile_scheduler_metadata
-            .as_cuda_slice::<i32>()?
-            .slice(self.tile_scheduler_metadata.layout().start_offset()..);
-
-        let candle::Storage::Cuda(num_splits) = &*self.num_splits.storage_and_layout().0 else {
-            candle::bail!("num_splits must be CUDA")
-        };
-        let num_splits = num_splits
-            .as_cuda_slice::<i32>()?
-            .slice(self.num_splits.layout().start_offset()..);
-
         let candle::Storage::Cuda(block_table) = &*self.block_table.storage_and_layout().0 else {
             candle::bail!("block_table must be CUDA")
         };
         let block_table = block_table
             .as_cuda_slice::<i32>()?
             .slice(self.block_table.layout().start_offset()..);
 
-        let candle::Storage::Cuda(seqlens_k) = &*self.seqlens_k.storage_and_layout().0 else {
-            candle::bail!("seqlens_k must be CUDA")
+        let candle::Storage::Cuda(cache_seqlens) = &*self.cache_seqlens.storage_and_layout().0
+        else {
+            candle::bail!("cache_seqlens must be CUDA")
         };
-        let seqlens_k = seqlens_k
+        let cache_seqlens = cache_seqlens
             .as_cuda_slice::<i32>()?
-            .slice(self.seqlens_k.layout().start_offset()..);
+            .slice(self.cache_seqlens.layout().start_offset()..);
 
         let is_causal = if seqlen_q == 1 { false } else { true };
 
Expand All @@ -133,17 +110,48 @@ impl FlashAttn {
let num_heads = num_heads_k;
let head_size_k = head_size_q;

let dst = unsafe { dev.alloc::<bf16>(b_sz * seqlen_q * num_heads * head_size_v) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let num_heads_per_head_k = num_heads / num_heads_k;

// This should match the logic in the MLA kernel.
let block_size_m = 64usize;
let sm_count = dev
.attribute(
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
)
.w()? as usize;
let num_sm_parts = sm_count / num_heads_k / num_heads_per_head_k.div_ceil(block_size_m);

let tile_scheduler_metadata =
unsafe { dev.alloc::<i32>(num_sm_parts * ffi::TILE_SCHEDULER_METADATA_SIZE) }.w()?;
let num_splits = dev.alloc_zeros::<i32>(b_sz + 1).w()?;

unsafe {
ffi::get_mla_metadata(
(*cache_seqlens.device_ptr()) as *mut core::ffi::c_int,
(*tile_scheduler_metadata.device_ptr()) as *mut core::ffi::c_int,
(*num_splits.device_ptr()) as *mut core::ffi::c_int,
b_sz as i32,
num_sm_parts as i32,
*dev.cu_stream(),
);
}

let dst = unsafe {
dev.alloc::<bf16>((b_sz + num_sm_parts) * seqlen_q * num_heads * head_size_v)
}
.w()?;
let softmax_lse = dev
.alloc_zeros::<f32>((b_sz + num_sm_parts) * num_heads * seqlen_q)
.w()?;

assert_eq!(head_size_q, 576);

let params = ffi::FlashFwdMlaParams {
b: b_sz as i32,
seqlen_q: seqlen_q as i32,
cu_seqlens_k: (*seqlens_k.device_ptr()) as *mut core::ffi::c_int,
cu_seqlens_k: (*cache_seqlens.device_ptr()) as *mut core::ffi::c_int,
h: num_heads as i32,
h_h_k_ratio: num_heads as i32 / num_heads_k as i32,
h_h_k_ratio: num_heads_per_head_k as i32,
ngroups: ngroups as i32,
is_causal,
d: head_size_q as i32,
@@ -172,7 +180,7 @@
             page_block_size: page_block_size as i32,
             tile_scheduler_metadata_ptr: (*tile_scheduler_metadata.device_ptr())
                 as *mut core::ffi::c_int,
-            num_sm_parts: self.tile_scheduler_metadata.dim(0)? as i32,
+            num_sm_parts: num_sm_parts as i32,
             num_splits_ptr: (*num_splits.device_ptr()) as *mut core::ffi::c_int,
             oaccum_ptr: (*dst.device_ptr()) as *mut core::ffi::c_void,
             softmax_lseaccum_ptr: (*softmax_lse.device_ptr()) as *mut core::ffi::c_void,
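To make the new splitting arithmetic concrete, here is a small illustrative check; the SM count and head counts are made-up example values, not taken from the commit. With 132 SMs, 128 query heads, and a single K/V head, num_heads_per_head_k is 128, div_ceil(128, 64) is 2, and num_sm_parts comes out to 132 / 1 / 2 = 66.

// Illustrative numbers only: 132 SMs, 128 query heads, 1 K/V head.
fn main() {
    let (sm_count, num_heads, num_heads_k) = (132usize, 128usize, 1usize);
    let block_size_m = 64usize; // must match the MLA kernel, as the comment above notes
    let num_heads_per_head_k = num_heads / num_heads_k; // 128
    let num_sm_parts = sm_count / num_heads_k / num_heads_per_head_k.div_ceil(block_size_m);
    assert_eq!(num_sm_parts, 66); // 132 / 1 / 2
}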
@@ -228,10 +236,9 @@ impl candle::CustomOp3 for FlashAttn {
 ///
 /// * `q`: (batch_size, seq_len_q, num_heads_q, head_dim).
 /// * `k_cache`: (num_blocks, page_block_size, num_heads_k, head_dim).
 /// * `v_cache`: (num_blocks, page_block_size, num_heads_k, head_dim_v).
 /// * `block_table`: (batch_size, max_num_blocks_per_seq), i32.
 /// * `cache_seqlens`: (batch_size), i32.
-/// * `tile_scheduler_metadata`: (num_sm_parts, TileSchedulerMetaDataSize), i32, returned by get_mla_metadata.
-/// * `num_splits`: (batch_size + 1), i32, returned by get_mla_metadata.
 /// * `softmax_scale`: The scale of QK^T before applying softmax.
 ///
 /// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size_v)`.
@@ -241,16 +248,30 @@ pub fn flash_attn_mla(
     v_cache: &Tensor,
     block_table: Tensor,
     cache_seqlens: Tensor,
-    tile_scheduler_metadata: Tensor,
-    num_splits: Tensor,
     softmax_scale: f32,
 ) -> Result<Tensor> {
     let op = FlashAttn {
         softmax_scale,
         block_table,
-        tile_scheduler_metadata,
-        num_splits,
-        seqlens_k: cache_seqlens,
+        cache_seqlens,
     };
-    q.apply_op3(k_cache, v_cache, op)
+
+    let (b_sz, seqlen_q_ori, num_heads, head_size) = q.shape().dims4()?;
+    let (_, _, _, head_size_v) = v_cache.shape().dims4()?;
+
+    let num_heads_k = k_cache.dim(2)?;
+    let ngroups = num_heads / num_heads_k;
+
+    let seqlen_q = seqlen_q_ori * ngroups;
+
+    let q = q
+        .reshape((b_sz, seqlen_q_ori, num_heads_k, ngroups, head_size))?
+        .transpose(2, 3)?
+        .reshape((b_sz, seqlen_q, num_heads, head_size))?;
+
+    let out = q.apply_op3(k_cache, v_cache, op)?;
+
+    out.reshape((b_sz, seqlen_q_ori, ngroups, num_heads_k, head_size_v))?
+        .transpose(2, 3)?
+        .reshape((b_sz, seqlen_q_ori, num_heads, head_size_v))
 }
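With tile_scheduler_metadata and num_splits now derived inside the op, a caller only supplies the caches, block table, and per-sequence cache lengths. A minimal usage sketch follows; the shapes, head dimensions (576 for Q/K, 512 for V), fill values, and scale are illustrative assumptions, not taken from the commit:

use candle::{DType, Device, Result, Tensor};

fn mla_example(dev: &Device) -> Result<Tensor> {
    // Illustrative sizes: 1 sequence, 1 query token, 128 query heads, 1 K/V head.
    let (b_sz, seqlen_q, num_heads, num_heads_k) = (1usize, 1usize, 128usize, 1usize);
    let (num_blocks, page_block_size) = (4usize, 64usize);

    let q = Tensor::zeros((b_sz, seqlen_q, num_heads, 576), DType::BF16, dev)?;
    let k_cache = Tensor::zeros((num_blocks, page_block_size, num_heads_k, 576), DType::BF16, dev)?;
    let v_cache = Tensor::zeros((num_blocks, page_block_size, num_heads_k, 512), DType::BF16, dev)?;

    // One row of block indices per sequence; i32 tensors as the docs above require
    // (assumes the candle build used by this crate supports the i32 dtype).
    let block_table = Tensor::from_vec(vec![0i32, 1, 2, 3], (b_sz, num_blocks), dev)?;
    // Number of cached tokens per sequence.
    let cache_seqlens = Tensor::from_vec(vec![64i32], b_sz, dev)?;

    let softmax_scale = 1f32 / (576f32).sqrt(); // illustrative scale
    candle_flash_mla::flash_attn_mla(&q, &k_cache, &v_cache, block_table, cache_seqlens, softmax_scale)
}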
