diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 7231640081..706e07c1e4 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -17,11 +17,7 @@ const CONV: &str = include_str!("conv.metal");
 const FILL: &str = include_str!("fill.metal");
 const INDEXING: &str = include_str!("indexing.metal");
 // Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
-#[cfg(not(target_os = "ios"))]
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
-// Current source: https://github.com/philipturner/metal-flash-attention/releases/tag/v1.0.1
-#[cfg(target_os = "ios")]
-const MFA: &[u8] = include_bytes!("libMetalFlashAttention.ios.metallib");
 const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
 const QUANTIZED: &str = include_str!("quantized.metal");
 const RANDOM: &str = include_str!("random.metal");
@@ -2017,12 +2013,7 @@ pub fn call_sdpa_vector(
         alpha
     };
 
-    let constants = Some(ConstantValues::new(vec![(
-        20,
-        Value::Bool(/* sdpa_vector_has_mask */ false),
-    )]));
-
-    let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, &name, constants)?;
+    let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -2134,13 +2125,7 @@ pub fn call_sdpa_vector_2pass(
         alpha
     };
 
-    let constants = Some(ConstantValues::new(vec![(
-        20,
-        Value::Bool(/* sdpa_vector_has_mask */ false),
-    )]));
-
-    let pipeline =
-        kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
+    let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass1)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal
index a8873ad681..6afad8126b 100644
--- a/candle-metal-kernels/src/scaled_dot_product_attention.metal
+++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal
@@ -299,8 +299,6 @@ struct MLXScaledDotProductAttentionParams {
 
 // ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector"
 
-constant bool sdpa_vector_has_mask [[function_constant(20)]];
-
 template <typename T, int D>
 [[kernel]] void sdpa_vector(
     const device T* queries [[buffer(0)]],
@@ -313,16 +311,14 @@ template <typename T, int D>
     const constant size_t& v_stride,
     const constant float& scale,
     const constant float& softcapping,
-    const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
-    const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
-    const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 32;
   constexpr int BD = 32;
   constexpr int elem_per_thread = D / BD;
-  constexpr int stride = BN * D;
+
+  const int stride = BN * D;
 
   typedef float U;
 
@@ -340,9 +336,6 @@ template <typename T, int D>
   queries += head_idx * D + simd_lid * elem_per_thread;
   keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
   values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
-  if (sdpa_vector_has_mask) {
-    mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
-  }
   out += head_idx * D + simd_gid * elem_per_thread;
 
   // Read the query and 0 the output accumulator
@@ -358,43 +351,38 @@ template <typename T, int D>
 
   // For each key
   for (int i = simd_gid; i < N; i += BN) {
-    if (!sdpa_vector_has_mask || mask[0]) {
-      // Read the key
-      for (int j = 0; j < elem_per_thread; j++) {
-        k[j] = keys[j];
-      }
+    // Read the key
+    for (int i = 0; i < elem_per_thread; i++) {
+      k[i] = keys[i];
+    }
 
-      // Compute the i-th score
-      U score = 0;
-      for (int j = 0; j < elem_per_thread; j++) {
-        score += q[j] * k[j];
-      }
-      score = simd_sum(score);
-      if (softcapping != 1.) {
-        score = precise::tanh(score);
-        score = score * softcapping;
-      }
+    // Compute the i-th score
+    U score = 0;
+    for (int i = 0; i < elem_per_thread; i++) {
+      score += q[i] * k[i];
+    }
+    score = simd_sum(score);
+    if (softcapping != 1.) {
+      score = precise::tanh(score);
+      score = score * softcapping;
+    }
 
-      // Update the accumulators
-      U new_max = max(max_score, score);
-      U factor = fast::exp(max_score - new_max);
-      U exp_score = fast::exp(score - new_max);
+    // Update the accumulators
+    U new_max = max(max_score, score);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
 
-      max_score = new_max;
-      sum_exp_score = sum_exp_score * factor + exp_score;
+    max_score = new_max;
+    sum_exp_score = sum_exp_score * factor + exp_score;
 
-      // Update the output accumulator
-      for (int j = 0; j < elem_per_thread; j++) {
-        o[j] = o[j] * factor + exp_score * values[j];
-      }
+    // Update the output accumulator
+    for (int i = 0; i < elem_per_thread; i++) {
+      o[i] = o[i] * factor + exp_score * values[i];
     }
 
     // Move the pointers to the next kv
     keys += stride;
     values += stride;
-    if (sdpa_vector_has_mask) {
-      mask += BN * mask_seq_stride;
-    }
   }
 
   // Each thread has a partial part of the output so we need to combine them.
@@ -440,9 +428,6 @@ template <typename T, int D>
     const constant size_t& v_stride,
     const constant float& scale,
     const constant float& softcapping,
-    const device bool* mask [[function_constant(sdpa_vector_has_mask)]],
-    const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]],
-    const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -472,10 +457,6 @@ template <typename T, int D>
   values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
       simd_lid * elem_per_thread;
   out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
-  if (sdpa_vector_has_mask) {
-    mask += head_idx * mask_head_stride +
-        (block_idx * BN + simd_gid) * mask_seq_stride;
-  }
   sums += head_idx * blocks + block_idx;
   maxs += head_idx * blocks + block_idx;
 
@@ -492,43 +473,75 @@ template <typename T, int D>
 
   // For each key
   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
-    if (!sdpa_vector_has_mask || mask[0]) {
-      // Read the key
-      for (int i = 0; i < elem_per_thread; i++) {
-        k[i] = keys[i];
-      }
+    // Read the key
+    for (int i = 0; i < elem_per_thread; i++) {
+      k[i] = keys[i];
+    }
 
-      // Compute the i-th score
-      U score = 0;
-      for (int i = 0; i < elem_per_thread; i++) {
-        score += q[i] * k[i];
-      }
-      score = simd_sum(score);
-      if (softcapping != 1.) {
-        score = precise::tanh(score);
-        score = score * softcapping;
-      }
+    // Compute the i-th score
+    U score = 0;
+    for (int i = 0; i < elem_per_thread; i++) {
+      score += q[i] * k[i];
+    }
+    score = simd_sum(score);
+    if (softcapping != 1.) {
+      score = precise::tanh(score);
+      score = score * softcapping;
+    }
 
-      // Update the accumulators
-      U new_max = max(max_score, score);
-      U factor = fast::exp(max_score - new_max);
-      U exp_score = fast::exp(score - new_max);
+    // Update the accumulators
+    U new_max = max(max_score, score);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
 
-      max_score = new_max;
-      sum_exp_score = sum_exp_score * factor + exp_score;
+    max_score = new_max;
+    sum_exp_score = sum_exp_score * factor + exp_score;
 
-      // Update the output accumulator
-      for (int i = 0; i < elem_per_thread; i++) {
-        o[i] = o[i] * factor + exp_score * values[i];
-      }
+    // Update the output accumulator
+    for (int i = 0; i < elem_per_thread; i++) {
+      o[i] = o[i] * factor + exp_score * values[i];
     }
 
     // Move the pointers to the next kv
     keys += blocks * stride;
     values += blocks * stride;
-    if (sdpa_vector_has_mask) {
-      mask += BN * blocks * mask_seq_stride;
+  }
+
+  // Each thread has a partial part of the output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  if (simd_lid == 0) {
+    max_scores[simd_gid] = max_score;
+    sum_exp_scores[simd_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
+  sum_exp_score = simd_sum(sum_exp_score * factor);
+
+  // Write the sum and new max
+  if (simd_gid == 0) {
+    sums[0] = sum_exp_score;
+    maxs[0] = new_max;
+  }
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < elem_per_thread; i++) {
+    outputs[simd_lid * BN + simd_gid] =
+        o[i] * fast::exp(max_scores[simd_gid] - new_max);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // And write the output
+    if (simd_gid == 0) {
+      U output = outputs[simd_lid * BN];
+      for (int j = 1; j < BN; j++) {
+        output += outputs[simd_lid * BN + j];
+      }
+      out[i] = static_cast<T>(output);
     }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
   }
 }
 
@@ -1656,9 +1669,6 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
      const constant size_t& v_stride, \
      const constant float& scale, \
      const constant float& softcapping, \
-     const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
-     const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
-     const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]); \
@@ -1676,9 +1686,6 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
      const constant size_t& v_stride, \
      const constant float& scale, \
      const constant float& softcapping, \
-     const device bool* mask [[function_constant(sdpa_vector_has_mask)]], \
-     const constant int& mask_seq_stride [[function_constant(sdpa_vector_has_mask)]], \
-     const constant int& mask_head_stride [[function_constant(sdpa_vector_has_mask)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]); \
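The per-key loop that both `sdpa_vector` kernels keep after this change is a streaming (online) softmax: each new score rescales the running maximum, the running sum of exponentials, and the output accumulator, so no per-key intermediate has to be materialized. Below is a minimal Rust sketch of that update for reference only; it is not part of the patch, and the `OnlineSoftmax` name and the `finish` normalization step are assumptions added for illustration.

```rust
/// Streaming-softmax accumulator mirroring the per-key update in the
/// sdpa_vector kernels above (max_score, sum_exp_score, o[...]).
/// Illustrative sketch only; the type name and `finish` step are assumptions.
struct OnlineSoftmax {
    max_score: f32,
    sum_exp_score: f32,
    o: Vec<f32>, // output accumulator, one entry per value element
}

impl OnlineSoftmax {
    fn new(dim: usize) -> Self {
        Self {
            max_score: f32::NEG_INFINITY, // the kernels use -1e9 as the sentinel
            sum_exp_score: 0.0,
            o: vec![0.0; dim],
        }
    }

    /// Fold in one (score, value row) pair: rescale the previous state by
    /// exp(old_max - new_max), then add the freshly exponentiated score,
    /// as in the kernels' "Update the accumulators" block.
    fn update(&mut self, score: f32, values: &[f32]) {
        let new_max = self.max_score.max(score);
        let factor = (self.max_score - new_max).exp();
        let exp_score = (score - new_max).exp();
        self.max_score = new_max;
        self.sum_exp_score = self.sum_exp_score * factor + exp_score;
        for (o, v) in self.o.iter_mut().zip(values) {
            *o = *o * factor + exp_score * v;
        }
    }

    /// Normalize by the accumulated sum of exponentials to obtain the
    /// softmax-weighted output.
    fn finish(mut self) -> Vec<f32> {
        for o in &mut self.o {
            *o /= self.sum_exp_score;
        }
        self.o
    }
}

fn main() {
    let mut acc = OnlineSoftmax::new(2);
    acc.update(0.5, &[1.0, 2.0]);
    acc.update(1.5, &[3.0, 4.0]);
    println!("{:?}", acc.finish());
}
```

On the GPU the same update runs per simdgroup lane with `simd_sum`/`simd_max` reductions, and with the `sdpa_vector_has_mask` function constant removed the Rust callers can load the pipelines with plain `load_pipeline` instead of `load_pipeline_with_constants`.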