diff --git a/candle-flash-mla/build.rs b/candle-flash-mla/build.rs
index 5bd9bb944..f27733893 100644
--- a/candle-flash-mla/build.rs
+++ b/candle-flash-mla/build.rs
@@ -8,7 +8,6 @@ const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS");
 
 const KERNEL_FILES: &[&str] = &[
     "flash_api.cu",
-    "flash_fwd_mla_kernel.h",
     "flash_fwd_mla_bf16_sm90.cu",
 ];
 
diff --git a/candle-flash-mla/hkernel/flash_api.cu b/candle-flash-mla/hkernel/flash_api.cu
index 133c3c8ae..54c8051a4 100644
--- a/candle-flash-mla/hkernel/flash_api.cu
+++ b/candle-flash-mla/hkernel/flash_api.cu
@@ -1,5 +1,6 @@
 // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
 
+#include "flash_fwd_mla_kernel.h"
 #include "flash_mla.h"
 #include "static_switch.h"
 
@@ -41,5 +42,5 @@ extern "C" void mha_fwd_kvcache_mla(
     const cudaStream_t stream
 ) {
     assert(params.d == 576);
-    run_mha_fwd_splitkv_mla<__nv_bfloat16, 576>(params, stream);
+    run_mha_fwd_splitkv_mla(params, stream);
 }
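
The two changes go together: with flash_fwd_mla_kernel.h removed from KERNEL_FILES, the header is no longer handed to nvcc as if it were a standalone translation unit; instead it is textually #include'd into flash_api.cu and compiled in place. That lets the header expose an entry point with the bf16 / head-dim-576 specialization already pinned (matching the assert on params.d), so the call site drops its explicit template arguments. A minimal sketch of the shape the header plausibly takes after this change, assuming it wraps a templated launcher; Flash_fwd_mla_params comes from flash_mla.h, while run_flash_splitkv_fwd_mla_impl is an illustrative name, not taken from the source:

    // Hypothetical excerpt of flash_fwd_mla_kernel.h, sketching how including
    // it in flash_api.cu can make the non-template call below resolve.
    #pragma once

    #include <cuda_bf16.h>
    #include "flash_mla.h"  // declares Flash_fwd_mla_params

    // Illustrative name: the templated launcher, fully defined in the real
    // header so that #include'ing it is enough to instantiate it.
    template <typename T, int Headdim>
    void run_flash_splitkv_fwd_mla_impl(Flash_fwd_mla_params &params, const cudaStream_t stream);

    // Non-template entry point matching the new call in flash_api.cu; the
    // only supported configuration (bf16, head dim 576) is fixed here.
    inline void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, const cudaStream_t stream) {
        run_flash_splitkv_fwd_mla_impl<__nv_bfloat16, 576>(params, stream);
    }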