Fix 2pass sdpa kernel #73

Merged 3 commits on Mar 1, 2025
2 changes: 2 additions & 0 deletions candle-kernels/build.rs
@@ -1,4 +1,6 @@
fn main() {
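    // _USE_MATH_DEFINES asks math.h to expose constants such as M_PI, which
    // MSVC-based toolchains hide by default; without it the CUDA kernels can
    // fail to compile on Windows.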
    std::env::set_var("NVCC_PREPEND_FLAGS", "-D_USE_MATH_DEFINES");

    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-changed=src/compatibility.cuh");
    println!("cargo:rerun-if-changed=src/cuda_utils.cuh");
37 changes: 37 additions & 0 deletions candle-metal-kernels/src/scaled_dot_product_attention.metal
@@ -530,6 +530,43 @@ template <typename T, int D>
      mask += BN * blocks * mask_seq_stride;
    }
  }

  // Each thread has a partial output, so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // Write the sum and new max
  if (simd_gid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Now we need to aggregate all the outputs
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BN + simd_gid] =
        o[i] * fast::exp(max_scores[simd_gid] - new_max);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // And write the output
    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int j = 1; j < BN; j++) {
        output += outputs[simd_lid * BN + j];
      }
      out[i] = static_cast<T>(output);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}

template <typename T, int D>
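For context, the combine step above is the standard online-softmax merge: given per-simdgroup partials (m_i, s_i) with s_i being the sum of exp(x - m_i) over that simdgroup's slice of keys, the global max is m = max_i m_i, the global sum is the rescaled total of s_i * exp(m_i - m), and each partial output is rescaled by the same factor before being summed. Below is a minimal CPU sketch in Rust of that merge, illustrative only and not part of the PR; the function name merge_partials is made up.

// Merge per-simdgroup softmax partials (local max, local sum of exponentials,
// local partial output) into global values, mirroring the kernel's combine step.
// The division by the total sum of exponentials happens later, in the second pass.
fn merge_partials(partials: &[(f32, f32, f32)]) -> (f32, f32, f32) {
    // Global max across all partials (the simd_max step).
    let new_max = partials.iter().map(|p| p.0).fold(f32::NEG_INFINITY, f32::max);
    let mut sum_exp = 0.0f32;
    let mut output = 0.0f32;
    for &(max_score, sum_exp_score, o) in partials {
        // factor = fast::exp(max_score - new_max) in the kernel.
        let factor = (max_score - new_max).exp();
        sum_exp += sum_exp_score * factor;
        output += o * factor;
    }
    (new_max, sum_exp, output)
}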
51 changes: 51 additions & 0 deletions candle-nn/tests/sdpa.rs
@@ -390,4 +390,55 @@ mod metal_sdpa_tests {

        Ok(())
    }

    #[test]
    fn sdpa_vector_gqa_2pass_no_mask() -> candle::Result<()> {
        use candle::{DType, Device, Tensor};
        // GQA, with seq_len raised to 1024 so that the 2-pass code branch is covered.

        /// Repeats a key or value tensor for grouped query attention.
        /// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`.
        fn repeat_kv(xs: Tensor, n_rep: usize) -> candle::Result<Tensor> {
            if n_rep == 1 {
                Ok(xs)
            } else {
                let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
                // Using cat is faster than a broadcast as it avoids going through a potentially
                // strided copy.
                // https://github.com/huggingface/candle/pull/2043
                Tensor::cat(&vec![&xs; n_rep], 2)?
                    .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
            }
        }

        const BS: usize = 4;
        const R: usize = 1;
        const L: usize = 1024;
        const DK: usize = 128;
        const HQ: usize = 28;
        const HKV: usize = 4;

        let scale: f64 = f64::from(DK as u32).sqrt().recip();
        let device = Device::new_metal(0)?;
        let q = Tensor::randn(0f32, 1f32, (BS, HQ, R, DK), &device)?;
        let k = Tensor::randn(0f32, 1f32, (BS, HKV, L, DK), &device)?;
        let v = Tensor::randn(0f32, 1f32, (BS, HKV, L, DK), &device)?;

        let k_aligned = repeat_kv(k.copy()?, HQ / HKV)?;
        let v_aligned = repeat_kv(v.copy()?, HQ / HKV)?;

        let ground_truth = {
            let att = (q.clone() * scale)?.matmul(&k_aligned.clone().t()?)?;
            let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
                .to_dtype(q.dtype())?;
            att.matmul(&v_aligned.clone())?
        };
        let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
        assert_eq!(ground_truth.shape(), sdpa_output.shape());
        let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
            .sum_all()?
            .to_scalar()?;
        println!("{error}");
        assert!(error <= 0.06, "{}", error);
        Ok(())
    }
}
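To run this test locally, something along the following lines should work, assuming the crate's metal feature gates this module:

cargo test -p candle-nn --features metal sdpa_vector_gqa_2pass_no_mask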