Use cute::bfloat16_t
EricLBuehler committed Feb 25, 2025
1 parent 8b2151b commit dd81c1a
Showing 2 changed files with 2 additions and 2 deletions.
candle-flash-mla/build.rs (1 change: 0 additions & 1 deletion)
@@ -8,7 +8,6 @@ const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS");

 const KERNEL_FILES: &[&str] = &[
     "flash_api.cu",
-    "flash_fwd_mla_kernel.h",
     "flash_fwd_mla_bf16_sm90.cu",
 ];
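
Context for the build.rs change (not part of the diff itself): the entries in KERNEL_FILES are the translation units handed to nvcc, and flash_fwd_mla_kernel.h is a header-only template file; it only generates code once a .cu file includes it and instantiates the templates, so compiling it on its own was never needed. A rough sketch of that pattern follows, where the explicit-instantiation line is an assumption about what flash_fwd_mla_bf16_sm90.cu might contain, not a copy of it:

// flash_fwd_mla_bf16_sm90.cu (illustrative sketch): including the header in a
// .cu translation unit is what makes nvcc compile the templated kernel, so the
// header itself does not need to be listed in KERNEL_FILES.
#include "flash_fwd_mla_kernel.h"

// Hypothetical explicit instantiation for the bf16, head-dim-576 variant.
template void run_mha_fwd_splitkv_mla<cute::bfloat16_t, 576>(
    Flash_fwd_mla_params &params, cudaStream_t stream);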

candle-flash-mla/hkernel/flash_api.cu (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,6 @@
 // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
 
+#include "flash_fwd_mla_kernel.h"
 #include "flash_mla.h"
 #include "static_switch.h"

@@ -41,5 +42,5 @@ extern "C" void mha_fwd_kvcache_mla(
     const cudaStream_t stream
 ) {
     assert(params.d == 576);
-    run_mha_fwd_splitkv_mla<__nv_bfloat16, 576>(params, stream);
+    run_mha_fwd_splitkv_mla<cute::bfloat16_t, 576>(params, stream);
 }
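
Context on the type swap: cute::bfloat16_t is the CuTe/CUTLASS bfloat16 value type that the templated FlashMLA kernel is written against, while __nv_bfloat16 is CUDA's raw intrinsic type. The two are layout-compatible 2-byte types, so only the template argument changes; the device buffers described by the params struct stay the same. Below is a minimal compile-time sketch of that compatibility, assuming the CuTe numeric-types header path (in the real file these definitions arrive through flash_fwd_mla_kernel.h):

// Illustrative check only, not code from the commit. cute::bfloat16_t
// (an alias of cutlass::bfloat16_t) stores one 16-bit value, matching
// __nv_bfloat16 in size and alignment, so the same memory can back either
// type when the kernel template is instantiated with the CuTe type.
#include <cuda_bf16.h>                     // __nv_bfloat16
#include <cute/numeric/numeric_types.hpp>  // cute::bfloat16_t (assumed header path)

static_assert(sizeof(cute::bfloat16_t) == sizeof(__nv_bfloat16),
              "bf16 storage size should match");
static_assert(alignof(cute::bfloat16_t) == alignof(__nv_bfloat16),
              "bf16 alignment should match");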
