diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3a28abe47..49b6eb94e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -44,6 +44,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)
 include(CheckCXXCompilerFlag)
 include(CheckLanguage)
+include(CMakeDependentOption)
 
 set(_ORTX_STANDALONE_PROJECT OFF)
 if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
@@ -299,6 +300,7 @@ endmacro()
 
 if(OCOS_USE_CUDA)
   include(ext_cuda)
+  include(cutlass)
 endif()
 
 #######################################################################################################################
@@ -363,12 +365,10 @@ if(OCOS_ENABLE_MATH)
   list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE})
 endif()
 
-file(GLOB TARGET_SRC_CONTRIB "operators/contrib/*.cc" "operators/contrib/*.h*")
 if (OCOS_USE_CUDA)
-  file(GLOB TARGET_SRC_CONTRIB_CUDA "operators/contrib/cuda/*.*")
-  list(APPEND TARGET_SRC_CONTRIB ${TARGET_SRC_CONTRIB_CUDA})
+  file(GLOB_RECURSE TARGET_SRC_CUDA "operators/cuda/*.*")
+  list(APPEND TARGET_SRC ${TARGET_SRC_CUDA})
 endif()
-list(APPEND TARGET_SRC ${TARGET_SRC_CONTRIB})
 
 # enable the opencv dependency if we have ops that require it
 if(OCOS_ENABLE_CV2 OR OCOS_ENABLE_VISION)
@@ -583,6 +583,10 @@ target_include_directories(ocos_operators PUBLIC
   ${PROJECT_SOURCE_DIR}/base
   ${PROJECT_SOURCE_DIR}/operators)
 
+if (OCOS_USE_CUDA)
+  target_include_directories(ocos_operators PUBLIC ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
+endif()
+
 set(ocos_libraries)
 set(OCOS_COMPILE_DEFINITIONS)
 
diff --git a/cmake/ext_cuda.cmake b/cmake/ext_cuda.cmake
index 15e66ff99..aa7d3282c 100644
--- a/cmake/ext_cuda.cmake
+++ b/cmake/ext_cuda.cmake
@@ -6,6 +6,13 @@ enable_language(CUDA)
 
 set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
 set(CMAKE_CUDA_STANDARD 17)
+cmake_dependent_option(OCOS_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32" OFF)
+option(OCOS_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
+if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
+  message(STATUS "Turn off flash attention and memory efficient attention since CUDA compiler version < 11.6")
+  set(OCOS_USE_FLASH_ATTENTION OFF)
+  set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF)
+endif()
 
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
@@ -22,3 +29,14 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=unsigned_co
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no_effect\"")
 
 add_compile_definitions(USE_CUDA)
+
+set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF)  # both forced off for now to keep build time down; re-enable once these kernels are actually used
+set(OCOS_USE_FLASH_ATTENTION OFF)
+if (OCOS_USE_FLASH_ATTENTION)
+  message(STATUS "Enable flash attention")
+  add_compile_definitions(OCOS_USE_FLASH_ATTENTION)
+endif()
+if (OCOS_USE_MEMORY_EFFICIENT_ATTENTION)
+  message(STATUS "Enable memory efficient attention")
+  add_compile_definitions(OCOS_USE_MEMORY_EFFICIENT_ATTENTION)
+endif()
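These switches only take effect through the compile definitions they emit; the sources added below test the macros directly. A minimal sketch of that gating, under the assumption that the extension was configured with OCOS_USE_CUDA (the helper name is illustrative, not part of this patch):

// Sketch only: how the OCOS_USE_* definitions are meant to gate kernel selection.
#include <stdexcept>

void launch_fused_attention_or_fail() {  // hypothetical helper, not in this patch
#if OCOS_USE_FLASH_ATTENTION
  // flash::mha_fwd(...) would run here (non-Windows, CUDA >= 11.6, SM80/SM90).
#elif OCOS_USE_MEMORY_EFFICIENT_ATTENTION
  // ort_extensions::cuda::run_memory_efficient_attention(...) is the cutlass fallback (SM50+).
#else
  throw std::runtime_error("extension was built without a fused attention kernel");
#endif
}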
diff --git a/cmake/externals/cutlass.cmake b/cmake/externals/cutlass.cmake
new file mode 100644
index 000000000..e36d42c89
--- /dev/null
+++ b/cmake/externals/cutlass.cmake
@@ -0,0 +1,10 @@
+FetchContent_Declare(
+  cutlass
+  GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
+  GIT_TAG v3.1.0
+)
+
+FetchContent_GetProperties(cutlass)
+if(NOT cutlass_POPULATED)
+  FetchContent_Populate(cutlass)
+endif()
diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h b/operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h
new file mode 100644
index 000000000..7a120ff16
--- /dev/null
+++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h
@@ -0,0 +1,276 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#endif
+
+#include "memory_efficient_attention.h"
+#include "41_fused_multi_head_attention/kernel_forward.h"
+
+namespace ort_extensions {
+namespace cuda {
+
+template <typename AttentionKernel, int kQueriesPerBlock>
+struct RightPaddingBatchHook {
+  using scalar_t = typename AttentionKernel::scalar_t;
+  using accum_t = typename AttentionKernel::accum_t;
+  using lse_scalar_t = typename AttentionKernel::lse_scalar_t;
+  using output_t = typename AttentionKernel::output_t;
+  using output_accum_t = typename AttentionKernel::output_accum_t;
+
+  static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout;
+  static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias;
+  static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock;
+  static constexpr bool kIsAligned = AttentionKernel::kIsAligned;
+  static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration;
+  static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE;  // block size of backward
+  static constexpr bool kPreloadV = AttentionKernel::kPreloadV;
+  static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF;
+  static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer;
+
+  template <typename Params>
+  static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) {
+    auto batch_id = blockIdx.z;
+    auto head_id = blockIdx.y;
+    auto query_start = blockIdx.x * kQueriesPerBlock;
+
+    auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE;
+
+    // Advance to current batch - in case of different sequence lengths
+    if (p.seqlen_k_ptr) {
+      p.num_keys = p.seqlen_k_ptr[batch_id];
+    }
+
+    if (query_start >= p.num_queries) {
+      return false;
+    }
+
+    // Advance to the current batch / head / query_start
+    p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH;
+    p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH;
+    p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH;
+    p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value;
+
+    if (kSupportsBias && p.attn_bias_ptr != nullptr) {
+      p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH);
+    }
+    if (p.output_accum_ptr != nullptr) {
+      p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) +
+                            int64_t(query_start) * (p.head_dim_value * p.num_heads) +
+                            head_id * p.head_dim_value;
+    } else {
+      // Accumulate directly in the destination buffer (eg for f32)
+      p.output_accum_ptr = (accum_t*)(p.output_ptr);
+    }
+
+    if (p.logsumexp_ptr != nullptr) {
+      // lse[batch_id, head_id, query_start]
+      p.logsumexp_ptr +=
+          batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start;
+    }
+
+    // Custom masking
+    if (p.causal_diagonal_ptr) {
+      p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id];
+    }
+    if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
+      p.causal_diagonal_offset += p.num_keys - p.num_queries;
+    }
+    if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft ||
+        p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
+      // The bottom row of the current block is query_start + kQueriesPerBlock,
+      // so the last active key is query_start + causal_diagonal_offset +
+      // kQueriesPerBlock. num_keys is therefore the min between the actual
+      // num_keys and this value, to avoid extra computations.
+      p.num_keys = cutlass::fast_min(
+          int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock),
+          p.num_keys);
+    }
+
+    p.num_queries -= query_start;
+    p.num_batches = 0;  // no longer used after
+
+    // If num_queries == 1 and there is only one key head, we're wasting
+    // 15/16th of the tensor core compute. In that case:
+    //  - we only launch kernels for head_id % kQueriesPerBlock == 0
+    //  - we iterate over heads instead of queries (strideM = strideH)
+    if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) {
+      if (head_id % kQueriesPerBlock != 0)
+        return false;
+      p.q_strideM = p.q_strideH;
+      p.num_queries = p.num_heads;
+      p.num_heads = 1;  // unused but here for intent
+      // remove causal since n_query = 1
+      // otherwise, offset would change with head !
+      p.custom_mask_type = AttentionKernel::NoCustomMask;
+      p.o_strideM = p.head_dim_value;
+    }
+
+    // Make sure the compiler knows these variables are the same on all
+    // the threads of the warp.
+    p.query_ptr = warp_uniform(p.query_ptr);
+    p.key_ptr = warp_uniform(p.key_ptr);
+    p.value_ptr = warp_uniform(p.value_ptr);
+    if (kSupportsBias) {
+      p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr);
+    }
+    p.output_ptr = warp_uniform(p.output_ptr);
+    p.output_accum_ptr = warp_uniform(p.output_accum_ptr);
+    p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr);
+    p.num_queries = warp_uniform(p.num_queries);
+    p.num_keys = warp_uniform(p.num_keys);
+    p.num_heads = warp_uniform(p.num_heads);
+    p.head_dim = warp_uniform(p.head_dim);
+    p.head_dim_value = warp_uniform(p.head_dim_value);
+    p.o_strideM = warp_uniform(p.o_strideM);
+    p.custom_mask_type = warp_uniform(p.custom_mask_type);
+    return true;
+  }
+};
+
+template <typename AK, int kQueriesPerBlock>
+__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
+    attention_kernel_batched_impl_right_padding(typename AK::Params p) {
+  if (!RightPaddingBatchHook<AK, kQueriesPerBlock>::AdvanceToBlockForGQA(p)) {
+    return;
+  }
+  AK::attention_kernel(p);
+}
+
+template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, bool single_value_iteration>
+void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
+  using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, single_value_iteration>;
+  typename Attention::Params p;
+  {  // set parameters
+    p.query_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.query));
+    p.key_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.key));
+    p.value_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.value));
+    p.attn_bias_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.attn_bias));
+    p.seqstart_q_ptr = params.seqstart_q_ptr;
+    p.seqstart_k_ptr = params.seqstart_k_ptr;
+    p.seqlen_k_ptr = params.seqlen_k_ptr;
+
+    p.logsumexp_ptr = nullptr;  // [num_heads, num_queries] for backward or nullptr for forward
+    p.output_ptr = reinterpret_cast<T*>(params.output);
+    if (Attention::kNeedsOutputAccumulatorBuffer) {
+      using Acc = typename Attention::accum_t;
+      // workspace size: batch_size * sequence_length * num_heads * v_head_size * sizeof(float)
+      // TODO: ORT_ENFORCE(params.workspace != nullptr, "Need output accumulator buffer but no workspace provided");
+      p.output_accum_ptr = reinterpret_cast<Acc*>(params.workspace);
+    } else {
+      p.output_accum_ptr = nullptr;
+    }
+    p.num_heads = params.num_heads;
+    p.num_batches = params.batch_size;
+    p.head_dim = params.qk_head_size;
+    p.head_dim_value = params.v_head_size;
+
+    p.scale = params.scale;
+
+    // When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel
+    p.num_queries = params.sequence_length;
+    p.num_keys = params.kv_sequence_length;
+
+    if (params.causal) {
+      p.custom_mask_type = Attention::CausalFromBottomRight;
+    }
+
+    // We use max_sequence_length to calculate KV stride
+    if (params.is_kv_bsnh) {
+      // Input Q, K, V format is BxSxNxH, output is BxSxNxH
+      p.q_strideH = params.qk_head_size;
+      p.k_strideH = params.qk_head_size;
+      p.v_strideH = params.v_head_size;
+      p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
+
+      p.q_strideM = params.num_heads * params.qk_head_size;
+      p.k_strideM = params.num_heads * params.qk_head_size;
+      p.v_strideM = params.num_heads * params.v_head_size;
+      p.o_strideM = params.num_heads * params.v_head_size;
+      p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
+
+      p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
+      p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.max_sequence_length;
+      p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.max_sequence_length;
+      p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
+    } else {
+      // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
+      p.q_strideH = params.qk_head_size;
+      p.k_strideH = params.max_sequence_length * params.qk_head_size;
+      p.v_strideH = params.max_sequence_length * params.v_head_size;
+      p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
+
+      p.q_strideM = params.num_heads * params.qk_head_size;
+      p.k_strideM = params.qk_head_size;
+      p.v_strideM = params.v_head_size;
+      p.o_strideM = params.num_heads * params.v_head_size;
+      p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
+
+      p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length;
+      p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length;
+      p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
+      p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
+    }
+  }
+
+  auto kernel_fn = attention_kernel_batched_impl<Attention>;
+  if (params.has_custom_right_padding) {
+    kernel_fn = attention_kernel_batched_impl_right_padding<Attention, queries_per_block>;
+  }
+
+  int smem_bytes = sizeof(typename Attention::SharedStorage);
+  if (smem_bytes > 0xc000) {
+    // TODO: ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!");
+    static bool once = [&]() {
+      cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+      return true;
+    }();
+  }
+
+  // TODO: ORT_ENFORCE(Attention::check_supported(p));
+  kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, params.stream>>>(p);
+}
+
+template <typename T, typename ArchTag, int queries_per_block, int keys_per_block, bool single_value_iteration>
+void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
+  using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, single_value_iteration>;
+#if defined(_MSC_VER) && !defined(__clang__)
+#pragma warning(push)
+#pragma warning(disable : 6287)
+#endif
+  // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned.
+  bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 &&
+                    params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
+                    params.v_head_size % AlignedAK::kAlignmentV == 0;
+#if defined(_MSC_VER) && !defined(__clang__)
+#pragma warning(pop)
+#endif
+  DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
+                  LaunchCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, single_value_iteration>(params);
+                }));
+}
+
+template <typename T, typename ArchTag>
+void DispatchBlockSize(const MemoryEfficientAttentionParams& params) {
+  if (params.v_head_size <= 64) {
+    DispatchIsAligned<T, ArchTag, 64, 64, true>(params);
+  } else if (params.v_head_size <= 128) {
+    DispatchIsAligned<T, ArchTag, 32, 128, true>(params);
+  } else {
+    DispatchIsAligned<T, ArchTag, 32, 128, false>(params);
+  }
+}
+
+}  // namespace cuda
+}  // namespace ort_extensions
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
+
+#endif  // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu
new file mode 100644
index 000000000..669164ae7
--- /dev/null
+++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION
+
+#include "fmha_launch_template.h"
+
+namespace ort_extensions {
+namespace cuda {
+
+void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params) {
+  if (params.is_half) {
+    DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm50>(params);
+  } else {
+    DispatchBlockSize<float, cutlass::arch::Sm50>(params);
+  }
+}
+
+}  // namespace cuda
+}  // namespace ort_extensions
+
+#endif  // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu
new file mode 100644
index 000000000..f561b8fa6
--- /dev/null
+++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION
+
+#include "fmha_launch_template.h"
+
+namespace ort_extensions {
+namespace cuda {
+
+void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params) {
+  if (params.is_half) {
+    DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm70>(params);
+  } else {
+    DispatchBlockSize<float, cutlass::arch::Sm70>(params);
+  }
+}
+
+}  // namespace cuda
+}  // namespace ort_extensions
+
+#endif  // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu
new file mode 100644
index 000000000..e55fe4269
--- /dev/null
+++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION
+
+#include "fmha_launch_template.h"
+
+namespace ort_extensions {
+namespace cuda {
+
+void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params) {
+  if (params.is_half) {
+    DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm75>(params);
+  } else {
+    DispatchBlockSize<float, cutlass::arch::Sm75>(params);
+  }
+}
+
+}  // namespace cuda
+}  // namespace ort_extensions
+
+#endif  // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu
new file mode 100644
index 000000000..76d12de99
--- /dev/null
+++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION
+
+#include "fmha_launch_template.h"
+
+namespace ort_extensions {
+namespace cuda {
+
+void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params) {
+  if (params.is_half) {
+    DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm80>(params);
+  } else {
+    DispatchBlockSize<float, cutlass::arch::Sm80>(params);
+  }
+}
+
+}  // namespace cuda
+}  // namespace ort_extensions
+
+#endif  // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu
new file mode 100644
index 000000000..c50a5b46d
--- /dev/null
+++ b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION
+
+#include "memory_efficient_attention.h"
+#include <cassert>
+
+namespace ort_extensions {
+namespace cuda {
+
+void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params) {
+  const int32_t& sm = params.sm;
+  if (sm >= 80) {
+    run_memory_efficient_attention_sm80(params);
+  } else if (sm >= 75) {
+    run_memory_efficient_attention_sm75(params);
+  } else if (sm >= 70) {
+    run_memory_efficient_attention_sm70(params);
+  } else if (sm >= 50) {
+    run_memory_efficient_attention_sm50(params);
+  } else {
+    assert(false);  // shall not reach here.
+  }
+}
+
+}  // namespace cuda
+}  // namespace ort_extensions
+
+#endif  // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h
new file mode 100644
index 000000000..082740ccb
--- /dev/null
+++ b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h
@@ -0,0 +1,59 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION
+#include <cstdint>
+
+namespace ort_extensions {
+namespace cuda {
+
+struct MemoryEfficientAttentionParams {
+  int32_t sm;
+  bool is_half;
+  bool is_kv_bsnh = true;
+  int32_t batch_size;
+  int32_t num_heads;
+  int32_t sequence_length;
+  int32_t kv_sequence_length;
+  int32_t max_sequence_length;
+  int32_t qk_head_size;
+  int32_t v_head_size;
+  bool causal;
+  // The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models.
+  bool is_attn_bias_batched;
+
+  float scale;
+
+  int32_t* seqstart_q_ptr;
+  int32_t* seqstart_k_ptr;
+  int32_t* seqlen_k_ptr;
+
+  const void* query;      // [B, S, N, H]
+  const void* key;        // [B, L, N, H], where L is kv_sequence_length
+  const void* value;      // [B, L, N, H_v]
+  const void* attn_bias;  // [N, S, S*] or null
+  void* output;           // [B, S, N, H_v]
+  void* workspace;        // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
+  cudaStream_t stream;
+
+  static bool need_workspace(size_t v_head_size, bool is_float) {
+    return (v_head_size > 128 && !is_float);
+  }
+
+  bool has_custom_right_padding = false;
+};
+
+void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);
+
+inline bool has_memory_efficient_attention(int32_t sm, bool is_half) {
+  return sm >= (is_half ? 53 : 50);
+}
+
+void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params);
+void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params);
+void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params);
+void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params);
+
+}  // namespace cuda
+}  // namespace ort_extensions
+#endif  // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
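The struct and free functions above are the whole public surface of the cutlass path. A minimal usage sketch, assuming a build with OCOS_USE_MEMORY_EFFICIENT_ATTENTION, fp16 BSNH device buffers that already exist, and an SM value queried elsewhere; shapes and the wrapper name are illustrative only:

// Sketch, not part of this patch: fill MemoryEfficientAttentionParams and dispatch.
#include <cuda_runtime.h>
#include "memory_efficient_attention.h"

void run_fmha_example(int sm, const void* q, const void* k, const void* v,
                      void* out, void* workspace, cudaStream_t stream) {
  using ort_extensions::cuda::MemoryEfficientAttentionParams;
  if (!ort_extensions::cuda::has_memory_efficient_attention(sm, /*is_half=*/true)) return;

  MemoryEfficientAttentionParams p{};
  p.sm = sm;
  p.is_half = true;            // fp16 inputs
  p.batch_size = 2;
  p.num_heads = 8;
  p.sequence_length = 128;     // S
  p.kv_sequence_length = 128;  // L
  p.max_sequence_length = 128;
  p.qk_head_size = 64;
  p.v_head_size = 64;
  p.causal = false;
  p.is_attn_bias_batched = false;
  p.scale = 1.0f / 8.0f;       // 1 / sqrt(qk_head_size)
  p.seqstart_q_ptr = nullptr;
  p.seqstart_k_ptr = nullptr;
  p.seqlen_k_ptr = nullptr;
  p.query = q;
  p.key = k;
  p.value = v;
  p.attn_bias = nullptr;
  p.output = out;
  // Workspace is only required for fp16 with v_head_size > 128 (see need_workspace).
  p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, !p.is_half)
                    ? workspace : nullptr;
  p.stream = stream;
  ort_extensions::cuda::run_memory_efficient_attention(p);
}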
diff --git a/operators/cuda/attention_lib/flash_attention/block_info.h b/operators/cuda/attention_lib/flash_attention/block_info.h
new file mode 100644
index 000000000..1ec632658
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/block_info.h
@@ -0,0 +1,44 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+namespace flash {
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Varlen = true>
+struct BlockInfo {
+  template <typename Params>
+  __device__ BlockInfo(const Params& params, const int bidb)
+      : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
+        sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]),
+        actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
+        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+        ,
+        seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])),
+        actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) {
+  }
+
+  template <typename index_t>
+  inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
+  }
+
+  template <typename index_t>
+  inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+  }
+
+  const int sum_s_q;
+  const int sum_s_k;
+  const int actual_seqlen_q;
+  // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+  const int seqlen_k_cache;
+  const int actual_seqlen_k;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+}  // namespace flash
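The varlen bookkeeping in BlockInfo is easiest to check with concrete numbers; the snippet below redoes the same arithmetic with plain host integers (example values are the editor's; the real struct is device-only and reads these fields from Params):

// Host-side illustration of BlockInfo's cumulative-length offset math.
#include <cassert>
#include <cstdint>

int main() {
  // cu_seqlens_q for a batch of 2 packed sequences with lengths 5 and 7.
  const int cu_seqlens_q[] = {0, 5, 12};
  const int bidb = 1;                                            // second batch entry
  const int sum_s_q = cu_seqlens_q[bidb];                        // 5
  const int actual_seqlen_q = cu_seqlens_q[bidb + 1] - sum_s_q;  // 7
  // With varlen inputs the rows of all sequences are packed, so q_offset ignores
  // batch_stride and is sum_s_q * row_stride (row_stride = num_heads * head_size).
  const int64_t row_stride = 8 * 64;
  const int64_t q_offset = int64_t(uint32_t(sum_s_q)) * row_stride;  // 5 * 512 = 2560
  assert(actual_seqlen_q == 7 && q_offset == 2560);
  return 0;
}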
diff --git a/operators/cuda/attention_lib/flash_attention/flash.h b/operators/cuda/attention_lib/flash_attention/flash.h
new file mode 100644
index 000000000..603a6e068
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash.h
@@ -0,0 +1,114 @@
+#pragma once
+#include <cuda.h>
+
+namespace flash {
+struct Qkv_params {
+  using index_t = uint32_t;
+  // The QKV matrices.
+  void* __restrict__ q_ptr = nullptr;
+  void* __restrict__ k_ptr = nullptr;
+  void* __restrict__ v_ptr = nullptr;
+
+  // The stride between rows of the Q, K and V matrices.
+  index_t q_batch_stride = 0;
+  index_t k_batch_stride = 0;
+  index_t v_batch_stride = 0;
+  index_t q_row_stride = 0;
+  index_t k_row_stride = 0;
+  index_t v_row_stride = 0;
+  index_t q_head_stride = 0;
+  index_t k_head_stride = 0;
+  index_t v_head_stride = 0;
+
+  // The number of heads.
+  int h = 0;
+  int h_k = 0;
+  // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
+  // different from nheads (query).
+  int h_h_k_ratio = 0;  // precomputed h / h_k
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Flash_fwd_params : public Qkv_params {
+  // The O matrix (output).
+  void* __restrict__ o_ptr = nullptr;
+  void* __restrict__ oaccum_ptr = nullptr;
+
+  // The stride between rows of O.
+  index_t o_batch_stride = 0;
+  index_t o_row_stride = 0;
+  index_t o_head_stride = 0;
+
+  // The pointer to the P matrix.
+  void* __restrict__ p_ptr = nullptr;
+
+  // The pointer to the softmax sum.
+  void* __restrict__ softmax_lse_ptr = nullptr;
+  void* __restrict__ softmax_lseaccum_ptr = nullptr;
+
+  // The dimensions.
+  int b = 0;
+  int seqlen_q = 0;
+  int seqlen_k = 0;
+  int seqlen_knew = 0;
+  int d = 0;
+  int seqlen_q_rounded = 0;
+  int seqlen_k_rounded = 0;
+  int d_rounded = 0;
+  int rotary_dim = 0;
+
+  // The scaling factors for the kernel.
+  float scale_softmax = 0.0;
+  float scale_softmax_log2 = 0.0;
+
+  // array of length b+1 holding starting offset of each sequence.
+  int* __restrict__ cu_seqlens_q = nullptr;
+  int* __restrict__ cu_seqlens_k = nullptr;
+
+  int* __restrict__ blockmask = nullptr;
+
+  // The K_new and V_new matrices.
+  void* __restrict__ knew_ptr = nullptr;
+  void* __restrict__ vnew_ptr = nullptr;
+
+  // The stride between rows of the Q, K and V matrices.
+  index_t knew_batch_stride = 0;
+  index_t vnew_batch_stride = 0;
+  index_t knew_row_stride = 0;
+  index_t vnew_row_stride = 0;
+  index_t knew_head_stride = 0;
+  index_t vnew_head_stride = 0;
+
+  // The cos and sin matrices for rotary embedding.
+  void* __restrict__ rotary_cos_ptr = nullptr;
+  void* __restrict__ rotary_sin_ptr = nullptr;
+
+  // The indices to index into the KV cache.
+  int* __restrict__ cache_batch_idx = nullptr;
+
+  // Local window size
+  int window_size_left = -1;
+  int window_size_right = -1;
+
+  bool is_bf16 = false;
+  bool is_causal = false;
+
+  // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+  // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+  bool is_seqlens_k_cumulative = true;
+
+  bool is_rotary_interleaved = false;
+
+  int num_splits = 0;  // For split-KV version
+
+  const cudaDeviceProp* dprops = nullptr;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T, int Headdim>
+void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
+template <typename T, int Headdim>
+void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream);
+}  // namespace flash
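All strides in Flash_fwd_params are element counts, not bytes. For the common BSNH layout they follow directly from the tensor shape, matching what set_params_fprop in flash_api.cc fills in; a tiny illustration with example sizes chosen by the editor:

// Element strides of a [batch, seqlen, num_heads, head_size] (BSNH) tensor.
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t seqlen = 128, num_heads = 8, head_size = 64;
  const uint32_t head_stride = head_size;             // next head, same token: 64
  const uint32_t row_stride = num_heads * head_size;  // next token, same batch: 512
  const uint32_t batch_stride = seqlen * row_stride;  // next batch entry: 65536
  std::printf("head=%u row=%u batch=%u\n", head_stride, row_stride, batch_stride);
  return 0;
}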
diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.cc b/operators/cuda/attention_lib/flash_attention/flash_api.cc
new file mode 100644
index 000000000..46812b560
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_api.cc
@@ -0,0 +1,469 @@
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_api.h"
+#include "flash.h"
+#include "static_switch.h"
+#include <cutlass/numeric_types.h>
+
+namespace flash {
+
+void set_params_fprop(Flash_fwd_params& params,
+                      // sizes
+                      size_t batch_size,
+                      size_t seqlen_q,
+                      size_t seqlen_k,
+                      size_t seqlen_q_rounded,
+                      size_t seqlen_k_rounded,
+                      size_t num_heads,
+                      size_t num_heads_k,
+                      size_t head_size,
+                      size_t head_size_rounded,
+                      // device pointers
+                      void* q,
+                      void* k,
+                      void* v,
+                      void* out,
+                      void* cu_seqlens_q_d,
+                      void* cu_seqlens_k_d,
+                      void* p_d,
+                      void* softmax_lse_d,
+                      float softmax_scale,
+                      bool is_causal,
+                      bool is_bf16,
+                      bool kv_bsnh = true,
+                      int window_size_left = -1,
+                      int window_size_right = -1) {
+  // Set the pointers and strides.
+  params.q_ptr = q;
+  params.k_ptr = k;
+  params.v_ptr = v;
+  params.o_ptr = out;
+
+  params.is_bf16 = is_bf16;
+
+  // All stride are in elements, not bytes.
+  if (kv_bsnh) {
+    params.q_row_stride = num_heads * head_size;
+    params.k_row_stride = num_heads_k * head_size;
+    params.v_row_stride = num_heads_k * head_size;
+    params.q_head_stride = head_size;
+    params.k_head_stride = head_size;
+    params.v_head_stride = head_size;
+    params.o_row_stride = num_heads * head_size;
+    params.o_head_stride = head_size;
+  } else {
+    params.q_row_stride = num_heads * head_size;
+    params.k_row_stride = head_size;
+    params.v_row_stride = head_size;
+    params.q_head_stride = head_size;
+    params.k_head_stride = seqlen_k * head_size;
+    params.v_head_stride = seqlen_k * head_size;
+    params.o_row_stride = num_heads * head_size;
+    params.o_head_stride = head_size;
+  }
+
+  if (cu_seqlens_q_d == nullptr) {
+    params.q_batch_stride = seqlen_q * num_heads * head_size;    // stride(0)
+    params.k_batch_stride = seqlen_k * num_heads_k * head_size;  // stride(0)
+    params.v_batch_stride = seqlen_k * num_heads_k * head_size;  // stride(0)
+    params.o_batch_stride = seqlen_q * num_heads * head_size;    // stride(0)
+  } else {
+    params.q_batch_stride = 0;
+    params.k_batch_stride = 0;
+    params.v_batch_stride = 0;
+    params.o_batch_stride = 0;
+  }
+
+  params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q_d);
+  params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k_d);
+
+  // P = softmax(QK^T)
+  params.p_ptr = p_d;
+
+  // Softmax sum
+  params.softmax_lse_ptr = softmax_lse_d;
+
+  // Set the dimensions.
+  params.b = batch_size;
+  params.h = num_heads;
+  params.h_k = num_heads_k;
+  params.h_h_k_ratio = num_heads / num_heads_k;
+  params.seqlen_q = seqlen_q;
+  params.seqlen_k = seqlen_k;
+  params.seqlen_q_rounded = seqlen_q_rounded;
+  params.seqlen_k_rounded = seqlen_k_rounded;
+  params.d = head_size;
+  params.d_rounded = head_size_rounded;
+
+  // Set the different scale values.
+  params.scale_softmax = softmax_scale;
+  params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+
+  // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates
+  // local and causal, so when a local window size is set, causality is expressed through the window sizes instead.
+  params.is_causal = is_causal;
+  if (is_causal && (window_size_left >= 0 || window_size_right != 0)) {
+    params.is_causal = false;
+  }
+  if (window_size_left < 0 && window_size_right >= 0) {
+    window_size_left = seqlen_k;
+  }
+  if (window_size_left >= 0 && window_size_right < 0) {
+    window_size_right = seqlen_k;
+  }
+  params.window_size_left = window_size_left;
+  params.window_size_right = window_size_right;
+
+  params.is_seqlens_k_cumulative = true;
+}
+
+size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) {
+  size_t bytes = sizeof(float) * batch_size * num_heads * seqlen;
+  return bytes;
+}
+
+size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q) {
+  size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads;
+  return bytes;
+}
+
+size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded) {
+  size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded;
+  return bytes;
+}
+
+void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) {
+  FP16_SWITCH(!params.is_bf16, [&] {
+    FWD_HEADDIM_SWITCH(params.d, [&] {
+      if (params.num_splits <= 1 && !force_split_kernel) {  // num_splits == 0 means it was not set by the caller
+        run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+      } else {
+        run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
+      }
+    });
+  });
+}
+
+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// So we find the best efficiency, then find the smallest number of splits that gets 85%
+// of the best efficiency.
+int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs,
+                         int max_splits) {
+  // This needs to match with run_mha_fwd_splitkv_dispatch
+  const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+  const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
+  // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+  // In any case we don't expect seqlen_q to be larger than 64 for inference.
+  const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
+  int batch_nheads_mblocks = batch_size * num_heads * num_m_blocks;
+  // If we have enough to almost fill the SMs, then just use 1 split
+  if (batch_nheads_mblocks >= 0.8f * num_SMs) {
+    return 1;
+  }
+  max_splits = std::min({max_splits, num_SMs, num_n_blocks});
+  float max_efficiency = 0.f;
+  std::vector<float> efficiency;
+  efficiency.reserve(max_splits);
+  auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+  // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
+  // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
+  // (i.e. it's 11 splits anyway).
+  // So we check if the number of blocks per split is the same as the previous num_splits.
+  auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
+    return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
+  };
+  for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+    if (!is_split_eligible(num_splits)) {
+      efficiency.push_back(0.f);
+    } else {
+      float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
+      float eff = n_waves / ceil(n_waves);
+      // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+      if (eff > max_efficiency) {
+        max_efficiency = eff;
+      }
+      efficiency.push_back(eff);
+    }
+  }
+  for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+    if (!is_split_eligible(num_splits)) {
+      continue;
+    }
+    if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
+      // printf("num_splits chosen = %d\n", num_splits);
+      return num_splits;
+    }
+  }
+  return 1;
+}
+
+// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes)
+std::tuple<int, int, int> get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads,
+                                                          int head_size, int num_SMs) {
+  int max_splits = 128;
+  // split kv buffers
+  int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size,
+                                        num_SMs, max_splits);
+  if (num_splits > 1) {
+    // softmax_lse_accum buffer
+    int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q);
+    // out_accum buffer
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size_rounded = round_multiple(head_size, 32);
+    int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded);
+    return {num_splits, softmax_lse_accum_bytes, out_accum_bytes};
+  } else {
+    return {0, 0, 0};
+  }
+}
+
+OrtStatusPtr mha_fwd(const cudaDeviceProp& dprops,
+               cudaStream_t stream,
+               void* q,            // batch_size x seqlen_q x num_heads x head_size
+               void* k,            // batch_size x seqlen_k x num_heads_k x head_size
+               void* v,            // batch_size x seqlen_k x num_heads_k x head_size
+               void* out,          // batch_size x seqlen_q x num_heads x head_size
+               void* softmax_lse,  // batch_size x num_heads x seqlen_q
+               int batch_size,
+               int num_heads,
+               int num_heads_k,
+               int head_size,
+               int seqlen_q,
+               int seqlen_k,
+               float softmax_scale,
+               bool is_causal,
+               bool is_bf16,
+               int num_splits,
+               void* softmax_lse_accum,  // num_splits x batch_size x seqlen_q x num_heads
+               void* out_accum,          // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+               bool kv_bsnh,
+               int local_window_size) {
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  const int head_size_rounded = round_multiple(head_size, 32);
+  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+  Flash_fwd_params params;
+  set_params_fprop(params,
+                   batch_size,
+                   seqlen_q, seqlen_k,
+                   seqlen_q_rounded, seqlen_k_rounded,
+                   num_heads, num_heads_k,
+                   head_size, head_size_rounded,
+                   q, k, v, out,
+                   /*cu_seqlens_q*/ nullptr,
+                   /*cu_seqlens_k*/ nullptr,
+                   nullptr,
+                   softmax_lse,
+                   softmax_scale,
+                   is_causal,
+                   is_bf16,
+                   kv_bsnh,
+                   local_window_size,
+                   is_causal ? 0 : -1);
+  params.dprops = &dprops;
+  params.knew_ptr = nullptr;
+  params.vnew_ptr = nullptr;
+  params.knew_batch_stride = 0;
+  params.vnew_batch_stride = 0;
+  params.knew_row_stride = 0;
+  params.vnew_row_stride = 0;
+  params.knew_head_stride = 0;
+  params.vnew_head_stride = 0;
+
+  params.num_splits = num_splits;
+  if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
+    params.softmax_lseaccum_ptr = softmax_lse_accum;
+    params.oaccum_ptr = out_accum;
+  } else {
+    params.softmax_lseaccum_ptr = nullptr;
+    params.oaccum_ptr = nullptr;
+  }
+
+  run_mha_fwd(params, stream);
+  return nullptr;
+}
+
+OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops,
+                      cudaStream_t stream,
+                      void* q,            // half (total_q, num_heads, head_size)
+                      void* k,            // half (total_k, num_heads, head_size)
+                      void* v,            // half (total_k, num_heads, head_size)
+                      void* out,          // half (total_q, num_heads, head_size)
+                      int* cu_seqlens_q,  // int (batch_size + 1)
+                      int* cu_seqlens_k,  // int (batch_size + 1)
+                      void* softmax_lse,  // float (batch_size, num_heads, max_seqlen_q)
+                      int batch_size,
+                      int num_heads,
+                      int num_heads_k,
+                      int head_size,
+                      int max_seqlen_q,
+                      int max_seqlen_k,
+                      float softmax_scale,
+                      bool is_causal,
+                      bool is_bf16) {
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  const int head_size_rounded = round_multiple(head_size, 32);
+  const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
+  const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
+
+  Flash_fwd_params params;
+  set_params_fprop(params,
+                   batch_size,
+                   max_seqlen_q, max_seqlen_k,
+                   seqlen_q_rounded, seqlen_k_rounded,
+                   num_heads, num_heads_k,
+                   head_size, head_size_rounded,
+                   q, k, v, out,
+                   cu_seqlens_q,
+                   cu_seqlens_k,
+                   nullptr,
+                   softmax_lse,
+                   softmax_scale,
+                   is_causal,
+                   is_bf16,
+                   true,
+                   -1,
+                   is_causal ? 0 : -1);
+  params.dprops = &dprops;
+  params.num_splits = 0;
+  params.softmax_lseaccum_ptr = nullptr;
+  params.oaccum_ptr = nullptr;
+  params.knew_ptr = nullptr;
+  params.vnew_ptr = nullptr;
+  run_mha_fwd(params, stream);
+  return nullptr;
+}
+
+bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k) {
+  bool is_sm8x = dprops.major == 8 && dprops.minor >= 0;
+  bool is_sm90 = dprops.major == 9 && dprops.minor == 0;
+  return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0);
+}
+
+// This API is used when past key and value are present... since cached, these are assumed to have sequence length
+// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_.
+OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
+                       cudaStream_t stream,
+                       void* q,            // batch_size x seqlen_q x num_heads x head_size
+                       void* kcache,       // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
+                       void* vcache,       // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
+                       void* k_new,        // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+                       void* v_new,        // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+                       void* out,          // batch_size x seqlen_q x num_heads x head_size
+                       void* softmax_lse,  // batch_size x num_heads x seqlen_q
+                       void* seqlens_k_,   // batch_size
+                       void* rotary_cos,   // seqlen_ro x (rotary_dim / 2)
+                       void* rotary_sin,   // seqlen_ro x (rotary_dim / 2)
+                       int batch_size,
+                       int num_heads,
+                       int num_heads_k,
+                       int head_size,
+                       int seqlen_q,
+                       int seqlen_k,
+                       int seqlen_k_new,
+                       const float softmax_scale,
+                       bool is_causal,
+                       bool is_bf16,
+                       bool past_bsnh,  // otherwise bnsh
+                       int num_splits,
+                       void* softmax_lse_accum,  // num_splits x batch_size x seqlen_q x num_heads
+                       void* out_accum,          // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+                       int local_window_size,
+                       bool is_rotary_interleaved,
+                       bool is_packed_qkv) {
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  const int head_size_rounded = round_multiple(head_size, 32);
+  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+  // In the kv-cache case, seqlen_k is the maximum cache sequence length and is used as the KV sequence length here.
+  Flash_fwd_params params;
+  set_params_fprop(params,
+                   batch_size,
+                   seqlen_q, seqlen_k,
+                   seqlen_q_rounded, seqlen_k_rounded,
+                   num_heads, num_heads_k,
+                   head_size, head_size_rounded,
+                   q, kcache, vcache, out,
+                   /*cu_seqlens_q_d=*/nullptr,
+                   /*cu_seqlens_k_d=*/nullptr,
+                   /*p_ptr=*/nullptr,
+                   softmax_lse,
+                   softmax_scale,
+                   is_causal,
+                   is_bf16,
+                   past_bsnh,
+                   local_window_size,
+                   is_causal ? 0 : -1);
+  params.dprops = &dprops;
+
+  if (k_new != nullptr && v_new != nullptr) {
+    params.seqlen_knew = seqlen_k_new;
+    params.knew_ptr = k_new;
+    params.vnew_ptr = v_new;
+    // All stride are in elements, not bytes.
+    if (is_packed_qkv) {
+      params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
+      params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
+      params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
+      params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
+      params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
+      params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
+    } else {
+      params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
+      params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
+      params.knew_row_stride = num_heads_k * head_size;
+      params.vnew_row_stride = num_heads_k * head_size;
+    }
+    params.knew_head_stride = head_size;
+    params.vnew_head_stride = head_size;
+  } else {
+    params.seqlen_knew = 0;
+    params.knew_ptr = nullptr;
+    params.vnew_ptr = nullptr;
+    params.knew_batch_stride = 0;
+    params.vnew_batch_stride = 0;
+    params.knew_row_stride = 0;
+    params.vnew_row_stride = 0;
+    params.knew_head_stride = 0;
+    params.vnew_head_stride = 0;
+  }
+
+  params.is_seqlens_k_cumulative = seqlens_k_ == nullptr;
+  if (seqlens_k_ != nullptr) {
+    params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
+  }
+
+  if (rotary_cos != nullptr) {
+    params.rotary_cos_ptr = rotary_cos;
+    params.rotary_sin_ptr = rotary_sin;
+    params.is_rotary_interleaved = is_rotary_interleaved;
+    params.rotary_dim = (head_size / 16) * 16;
+  }
+
+  params.num_splits = num_splits;
+  if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
+    params.softmax_lseaccum_ptr = softmax_lse_accum;
+    params.oaccum_ptr = out_accum;
+  } else {
+    params.softmax_lseaccum_ptr = nullptr;
+    params.oaccum_ptr = nullptr;
+  }
+
+  // Only split kernel supports appending to KV cache
+  run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);
+
+  return nullptr;
+}
+
+}  // namespace flash
+
+#endif  // OCOS_USE_FLASH_ATTENTION
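The comment above num_splits_heuristic quotes efficiencies of 0.89 for 2 splits and 0.67 for 3 splits in the 48-block / 108-SM case; both follow from the wave efficiency n_waves / ceil(n_waves) computed in the loop. A small host-side check of those two numbers (the helper is the editor's restatement, not code from this patch):

// Reproduces the efficiency figures cited in the num_splits_heuristic comment.
#include <cmath>
#include <cstdio>

static float wave_efficiency(int batch_nheads_mblocks, int num_splits, int num_SMs) {
  float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
  return n_waves / std::ceil(n_waves);
}

int main() {
  const int blocks = 48, sms = 108;
  std::printf("2 splits: %.2f\n", wave_efficiency(blocks, 2, sms));  // 96/108  -> 0.89
  std::printf("3 splits: %.2f\n", wave_efficiency(blocks, 3, sms));  // 144/216 -> 0.67
  return 0;
}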
diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.h b/operators/cuda/attention_lib/flash_attention/flash_api.h
new file mode 100644
index 000000000..4ad1b76e1
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_api.h
@@ -0,0 +1,92 @@
+#pragma once
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include <tuple>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include "onnxruntime_c_api.h"
+
+namespace flash {
+
+OrtStatusPtr mha_fwd(const cudaDeviceProp& dprops,
+               cudaStream_t stream,
+               void* q,            // batch_size x seqlen_q x num_heads x head_size
+               void* k,            // batch_size x seqlen_k x num_heads_k x head_size
+               void* v,            // batch_size x seqlen_k x num_heads_k x head_size
+               void* out,          // batch_size x seqlen_q x num_heads x head_size
+               void* softmax_lse,  // batch_size x num_heads x seqlen_q
+               int batch_size,
+               int num_heads,
+               int num_heads_k,
+               int head_size,
+               int seqlen_q,
+               int seqlen_k,
+               float softmax_scale,
+               bool is_causal,
+               bool is_bf16,
+               int num_splits = 0,
+               void* softmax_lse_accum = nullptr,  // num_splits x batch_size x seqlen_q x num_heads
+               void* out_accum = nullptr,          // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+               bool kv_bsnh = true,
+               int local_window_size = -1);
+
+OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops,
+                      cudaStream_t stream,
+                      void* q,            // half (total_q, num_heads, head_size)
+                      void* k,            // half (total_k, num_heads, head_size)
+                      void* v,            // half (total_k, num_heads, v_head_size)
+                      void* out,          // half (total_q, num_heads, v_head_size)
+                      int* cu_seqlens_q,  // int (batch_size + 1)
+                      int* cu_seqlens_k,  // int (batch_size + 1)
+                      void* softmax_lse,  // float (batch_size, num_heads, max_seqlen_q)
+                      int batch_size,
+                      int num_heads,
+                      int num_heads_k,
+                      int head_size,
+                      int max_seqlen_q,
+                      int max_seqlen_k,
+                      float softmax_scale,
+                      bool is_causal,
+                      bool is_bf16);
+
+OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
+                       cudaStream_t stream,
+                       void* q,            // batch_size x seqlen_q x num_heads x head_size
+                       void* kcache,       // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
+                       void* vcache,       // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
+                       void* k,            // batch_size x seqlen_k_new x num_heads_k x head_size
+                       void* v,            // batch_size x seqlen_k_new x num_heads_k x head_size
+                       void* out,          // batch_size x seqlen_q x num_heads x head_size
+                       void* softmax_lse,  // batch_size x num_heads x seqlen_q
+                       void* seqlens_k_,   // batch_size
+                       void* rotary_cos,   // seqlen_ro x (rotary_dim / 2)
+                       void* rotary_sin,   // seqlen_ro x (rotary_dim / 2)
+                       int batch_size,
+                       int num_heads,
+                       int num_heads_k,
+                       int head_size,
+                       int seqlen_q,
+                       int seqlen_k,
+                       int seqlen_k_new,
+                       const float softmax_scale,
+                       bool is_causal,
+                       bool is_bf16,
+                       bool past_bsnh,  // otherwise bnsh
+                       int num_splits = 0,
+                       void* softmax_lse_accum = nullptr,  // num_splits x batch_size x seqlen_q x num_heads
+                       void* out_accum = nullptr,          // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+                       int local_window_size = -1,
+                       bool is_rotary_interleaved = false,
+                       bool is_packed_qkv = false);
+
+size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
+
+std::tuple<int, int, int> get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads,
+                                                          int head_size, int num_SMs);
+
+bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k);
+
+}  // namespace flash
+
+#endif  //  OCOS_USE_FLASH_ATTENTION
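A hedged sketch of driving the fixed-length entry point declared above, assuming a build with OCOS_USE_FLASH_ATTENTION, fp16 BSNH device buffers allocated elsewhere, and a softmax_lse buffer of get_softmax_lse_size() bytes; shapes and the wrapper name are illustrative:

// Sketch, not part of this patch: query device properties and call flash::mha_fwd.
#include <cmath>
#include <cuda_runtime.h>
#include "flash_api.h"

OrtStatusPtr run_flash_example(void* q, void* k, void* v, void* out,
                               void* softmax_lse, cudaStream_t stream) {
  cudaDeviceProp props{};
  cudaGetDeviceProperties(&props, /*device=*/0);

  const int batch_size = 2, num_heads = 8, num_heads_k = 8, head_size = 64;
  const int seqlen_q = 128, seqlen_k = 128;
  if (!flash::is_supported(props, head_size, num_heads, num_heads_k)) {
    return nullptr;  // fall back to another attention path
  }

  const float softmax_scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  // num_splits, accumulator buffers, kv_bsnh and local_window_size keep their defaults.
  return flash::mha_fwd(props, stream, q, k, v, out, softmax_lse,
                        batch_size, num_heads, num_heads_k, head_size,
                        seqlen_q, seqlen_k, softmax_scale,
                        /*is_causal=*/true, /*is_bf16=*/false);
}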
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu
new file mode 100644
index 000000000..8a9b32ee8
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params& params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu
new file mode 100644
index 000000000..e643d97ec
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu
new file mode 100644
index 000000000..bd716155d
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params& params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu
new file mode 100644
index 000000000..2b61a318a
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu
new file mode 100644
index 000000000..a08bc39d0
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params& params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu
new file mode 100644
index 000000000..d9fab2d60
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu
new file mode 100644
index 000000000..7d6d69378
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params& params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu
new file mode 100644
index 000000000..83b77523b
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu
new file mode 100644
index 000000000..954d752fb
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params& params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu
new file mode 100644
index 000000000..80045c2f9
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu
new file mode 100644
index 000000000..e27f7907e
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params& params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu
new file mode 100644
index 000000000..187db09a6
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu
new file mode 100644
index 000000000..34728f342
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params& params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu
new file mode 100644
index 000000000..c62f1342d
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu
new file mode 100644
index 000000000..6a7fb413e
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params& params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu
new file mode 100644
index 000000000..68b751b87
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+}
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h
new file mode 100644
index 000000000..c44a470f6
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h
@@ -0,0 +1,1259 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-variable"
+#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
+#endif
+
+#include <cmath>
+#include <cute/algorithm/copy.hpp>
+#include <cute/algorithm/gemm.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
+#include <cutlass/numeric_conversion.h>
+
+#include "block_info.h"
+#include "kernel_traits.h"
+#include "utils.h"
+#include "softmax.h"
+
+namespace flash {
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_first, bool Check_inf = false, typename Tensor0, typename Tensor1, typename Tensor2>
+inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum,
+                                         Tensor2& acc_o, float softmax_scale_log2) {
+  if (Is_first) {
+    flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
+    flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
+    flash::reduce_sum(scores, scores_sum);
+  } else {
+    cute::Tensor scores_max_prev = make_fragment_like(scores_max);
+    cute::copy(scores_max, scores_max_prev);
+    flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
+    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+    cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+#pragma unroll
+    for (int mi = 0; mi < cute::size(scores_max); ++mi) {
+      float scores_max_cur = !Check_inf
+                                 ? scores_max(mi)
+                                 : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
+      float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
+      scores_sum(mi) *= scores_scale;
+#pragma unroll
+      for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) {
+        acc_o_rowcol(mi, ni) *= scores_scale;
+      }
+    }
+    flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
+    cute::Tensor scores_sum_cur = make_fragment_like(scores_sum);
+    flash::reduce_sum(scores, scores_sum_cur);
+#pragma unroll
+    for (int mi = 0; mi < cute::size(scores_sum); ++mi) {
+      scores_sum(mi) += scores_sum_cur(mi);
+    }
+  }
+};
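+// The else branch of softmax_rescale_o above is the standard online-softmax recurrence; for
+// reference (our notation, not upstream's), with m = running row max (scores_max), l = running
+// row sum (scores_sum), and softmax_scale_log2 typically softmax_scale * log2(e):
+//   m_new = max(m_old, rowmax(scores))
+//   P     = exp2((scores - m_new) * softmax_scale_log2)   // scale_apply_exp2
+//   corr  = exp2((m_old  - m_new) * softmax_scale_log2)   // scores_scale
+//   l_new = corr * l_old + rowsum(P)
+//   acc_o = corr * acc_o                                   // rescale the partial output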
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
+inline __device__ void write_softmax_to_gmem(
+    cute::Tensor<Engine0, Layout0> const& tOrP, cute::Tensor<Engine1, Layout1>& tPgP, TiledCopy gmem_tiled_copy_P) {
+  // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
+  cute::Layout l = tOrP.layout();
+  cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
+  CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{});
+  CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP));
+#pragma unroll
+  for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) {
+    cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) {
+  using Element = typename Kernel_traits::Element;
+  using ElementAccum = typename Kernel_traits::ElementAccum;
+  using index_t = typename Kernel_traits::index_t;
+
+  // Shared memory.
+  extern __shared__ char smem_[];
+
+  // The thread index.
+  const int tidx = threadIdx.x;
+
+  constexpr int kBlockM = Kernel_traits::kBlockM;
+  constexpr int kBlockN = Kernel_traits::kBlockN;
+  constexpr int kHeadDim = Kernel_traits::kHeadDim;
+  constexpr int kNWarps = Kernel_traits::kNWarps;
+  constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
+
+  const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+  if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+
+  const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
+  int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
+  if (Is_causal || Is_local) {
+    n_block_max = std::min(n_block_max,
+                           cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
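+    // Example (assumed values, for intuition only): with seqlen_q == seqlen_k, kBlockM == kBlockN
+    // and window_size_right == 0 (pure causal), this clamp gives n_block_max == m_block + 1,
+    // i.e. query block m only visits key blocks 0..m.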
+    // We exit early and write 0 to gO and +inf to gLSE.
+    // Otherwise we might read OOB elements from gK and gV.
+    if (n_block_max <= n_block_min) {
+      const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+      const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+      Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr) + row_offset_o),
+                              Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                              make_stride(params.o_row_stride, _1{}));
+      Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse),
+                                Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+      typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+      auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+      Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+      Tensor tOrO = make_tensor<Element>(shape(tOgO));
+      clear(tOrO);
+      // Construct identity layout for sO
+      Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+      // Repeat the partitioning with identity layouts
+      Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+      Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+      if (!Is_even_K) {
+#pragma unroll
+        for (int k = 0; k < size(tOpO); ++k) {
+          tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+        }
+      }
+      // Clear_OOB_K must be false since we don't want to write zeros to gmem
+      flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+          gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
+#pragma unroll
+      for (int m = 0; m < size<1>(tOgO); ++m) {
+        const int row = get<0>(tOcO(0, m, 0));
+        if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) {
+          gLSE(row) = INFINITY;
+        }
+      }
+      return;
+    }
+  }
+
+  // We iterate over the blocks in reverse order. This is because the last block is the only one
+  // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
+  // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
+
+  const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
+  // We move K and V to the last block.
+  const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
+  const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+  const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
+  cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + row_offset_q),
+                                cute::Shape<cute::Int<kBlockM>, cute::Int<kHeadDim>>{},
+                                make_stride(params.q_row_stride, _1{}));
+  cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k),
+                                cute::Shape<cute::Int<kBlockN>, cute::Int<kHeadDim>>{},
+                                make_stride(params.k_row_stride, _1{}));
+  cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v),
+                                cute::Shape<cute::Int<kBlockN>, cute::Int<kHeadDim>>{},
+                                make_stride(params.v_row_stride, _1{}));
+  cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.p_ptr) + row_offset_p),
+                                cute::Shape<cute::Int<kBlockM>, cute::Int<kBlockN>>{},
+                                make_stride(params.seqlen_k_rounded, _1{}));
+
+  cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)),
+                                typename Kernel_traits::SmemLayoutQ{});
+  // Careful: we're using the same smem for sQ and for sK | sV if Share_Q_K_smem.
+  cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : cute::size(sQ)),
+                                typename Kernel_traits::SmemLayoutKV{});
+  cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{});
+  cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
+  cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+
+  typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+  auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+  typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
+  auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
+
+  cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
+  cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
+  cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+  cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
+  cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+  cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
+  cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
+
+  typename Kernel_traits::TiledMma tiled_mma;
+  auto thr_mma = tiled_mma.get_thread_slice(tidx);
+  cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ);             // (MMA,MMA_M,MMA_K)
+  cute::Tensor tSrK = thr_mma.partition_fragment_B(sK);             // (MMA,MMA_N,MMA_K)
+  cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K,MMA_N)
+
+  cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape<cute::Int<kBlockM>, cute::Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
+
+  //
+  // Copy Atom retiling
+  //
+
+  auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+  cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
+
+  auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+  auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+  cute::Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+
+  auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+  auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+  cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
+
+  // TODO: this might need to change if we change the mma instruction in SM70
+  cute::Tensor scores_max = make_tensor<ElementAccum>(cute::Shape<cute::Int<2 * cute::size<1>(acc_o)>>{});
+  cute::Tensor scores_sum = make_fragment_like(scores_max);
+
+  //
+  // PREDICATES
+  //
+
+  // Construct identity layout for sQ and sK
+  cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ)));   // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
+
+  // Repeat the partitioning with identity layouts
+  cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);     // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+
+  // Allocate predicate tensors for k
+  cute::Tensor tQpQ = make_tensor<bool>(make_shape(cute::size<2>(tQsQ)));
+  cute::Tensor tKVpKV = make_tensor<bool>(make_shape(cute::size<2>(tKsK)));
+
+  // Set predicates for k bounds
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < cute::size(tQpQ); ++k) {
+      tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
+    }
+#pragma unroll
+    for (int k = 0; k < cute::size(tKVpKV); ++k) {
+      tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
+    }
+  }
+
+  // Prologue
+
+  cute::Tensor tQrQ = make_fragment_like(tQgQ);
+  // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+  flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+                                     binfo.actual_seqlen_q - m_block * kBlockM);
+  if (Kernel_traits::Is_Q_in_regs) {
+    cute::cp_async_fence();
+  }
+
+  if (Kernel_traits::Share_Q_K_smem) {
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
+    CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view));  // M
+    cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
+    __syncthreads();
+  }
+
+  int n_block = n_block_max - 1;
+  // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
+  flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+                                     binfo.actual_seqlen_k - n_block * kBlockN);
+  cute::cp_async_fence();
+
+  if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
+    flash::cp_async_wait<1>();
+    __syncthreads();
+    cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
+    CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view));  // M
+    cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
+  }
+
+  clear(acc_o);
+
+  // For performance reasons, we separate out two kinds of iterations:
+  // those that need masking on S, and those that don't.
+  // We need masking on S for the very last block when the K and V length is not a multiple of kBlockN.
+  // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
+  // We will have at least 1 "masking" iteration.
+
+  // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+  // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+  constexpr int n_masking_steps = (!Is_causal && !Is_local)
+                                      ? 1
+                                      : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
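+  // For intuition (block sizes assumed, not fixed by this kernel): with kBlockM = 128 and
+  // kBlockN = 64, a causal launch with even lengths runs ceil_div(128, 64) = 2 masking
+  // iterations; if seqlen_k is not a multiple of kBlockN, one more masking iteration is added.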
+#pragma unroll
+  for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
+    cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape<cute::Int<kBlockM>, cute::Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+
+    // Advance gV
+    if (masking_step > 0) {
+      tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+    } else {
+      // Clear the smem tiles to account for predicated off loads
+      flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+          gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN);
+    }
+    cute::cp_async_fence();
+
+    flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        smem_thr_copy_Q, smem_thr_copy_K);
+    // if (cute::thread0()) { print(acc_s); }
+
+    // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+
+    // We don't put the masking before the matmul S = Q K^T because we don't clear sK
+    // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
+    // can produce Inf / NaN.
+    if (!Is_causal && !Is_local) {
+      if (!Is_even_MN) {
+        flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN);
+      }
+    } else {
+      // I can't get the stride from idx_row
+      flash::apply_mask_local</*HasWSLeft=*/Is_local>(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                                      // m_block * kBlockM + get<0>(idx_row(0)),
+                                                      m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                                      binfo.actual_seqlen_q, kNWarps * 16,
+                                                      params.window_size_left, params.window_size_right);
+    }
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    if (n_block > n_block_min) {
+      // Advance gK
+      tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the synchronization
+      // isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    // TODO: when we have key_padding_mask we'll need to Check_inf
+    masking_step == 0
+        ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+        : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+    // Convert scores from fp32 to fp16/bf16
+    cute::Tensor rP = flash::convert_type<Element>(scores);
+    // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+    cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    // if (Return_softmax) {
+    //   cute::Tensor tOrP_copy = make_fragment_like(tOrP);
+    //   copy(tOrP, tOrP_copy);
+    //   flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+    //   tPgP.data() = tPgP.data() + (-kBlockN);
+    // }
+
+    flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+
+    // This check is at the end of the loop since we always have at least 1 iteration
+    if (n_masking_steps > 1 && n_block <= n_block_min) {
+      --n_block;
+      break;
+    }
+  }
+
+  // These are the iterations where we don't need masking on S
+  for (; n_block >= n_block_min; --n_block) {
+    cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape<cute::Int<kBlockM>, cute::Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    // Advance gV
+    tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+    flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+    cute::cp_async_fence();
+
+    flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        smem_thr_copy_Q, smem_thr_copy_K);
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    if (n_block > n_block_min) {
+      // Advance gK
+      tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the synchronization
+      // isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+    if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+      flash::apply_mask_local(
+          scores, n_block * kBlockN, binfo.actual_seqlen_k,
+          m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+          binfo.actual_seqlen_q, kNWarps * 16,
+          params.window_size_left, params.window_size_right);
+    }
+    softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+    cute::Tensor rP = flash::convert_type<Element>(scores);
+    // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+    cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    // if (Return_softmax) {
+    //   cute::Tensor tOrP_copy = make_fragment_like(tOrP);
+    //   copy(tOrP, tOrP_copy);
+    //   flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+    //   tPgP.data() = tPgP.data() + (-kBlockN);
+    // }
+
+    flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+  }
+
+  // Epilogue
+
+  // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+  cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+  cute::Tensor lse = make_fragment_like(scores_sum);
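+  // Epilogue math, for reference (our notation): each row of acc_o currently holds
+  // sum_j exp(scale_softmax * (s_j - m)) * v_j, so dividing by sum = l (the row sum of P)
+  // yields softmax(scale_softmax * s) @ V, and lse = m * scale_softmax + log(sum) is the
+  // row's log-sum-exp, written to gLSE for downstream consumers.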
+#pragma unroll
+  for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) {
+    float sum = scores_sum(mi);
+    float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+    lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
+    float scale = inv_sum;
+#pragma unroll
+    for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) {
+      acc_o_rowcol(mi, ni) *= scale;
+    }
+  }
+
+  // Convert acc_o from fp32 to fp16/bf16
+  cute::Tensor rO = flash::convert_type<Element>(acc_o);
+  cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
+  // Partition sO to match the accumulator partitioning
+  auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+  auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);  // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
+  cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO);              // ((Atom,AtomNum), MMA_M, MMA_N)
+  cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO);           // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+  // sO has the same size as sQ, so we don't need to sync here.
+  if (Kernel_traits::Share_Q_K_smem) {
+    __syncthreads();
+  }
+
+  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
+
+  const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+  const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+  cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr) + row_offset_o),
+                                cute::Shape<cute::Int<kBlockM>, cute::Int<kHeadDim>>{},
+                                make_stride(params.o_row_stride, _1{}));
+  cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse),
+                                  cute::Shape<cute::Int<kBlockM>>{}, cute::Stride<_1>{});
+
+  typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+  auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+  cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
+  cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+
+  __syncthreads();
+
+  cute::Tensor tOrO = make_tensor<Element>(cute::shape(tOgO));
+  cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
+
+  cute::Tensor caccO = make_identity_tensor(cute::Shape<cute::Int<kBlockM>, cute::Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  cute::Tensor taccOcO = thr_mma.partition_C(caccO);                                                  // (MMA,MMA_M,MMA_K)
+  static_assert(decltype(cute::size<0>(taccOcO))::value == 4);
+  // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+  cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0);
+  CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row));  // MMA_M
+  if (get<1>(taccOcO_row(0)) == 0) {
+#pragma unroll
+    for (int mi = 0; mi < cute::size(lse); ++mi) {
+      const int row = get<0>(taccOcO_row(mi));
+      if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
+        gLSE(row) = lse(mi);
+      }
+    }
+  }
+
+  // Construct identity layout for sO
+  cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  // Repeat the partitioning with identity layouts
+  cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  cute::Tensor tOpO = make_tensor<bool>(make_shape(cute::size<2>(tOgO)));
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < cute::size(tOpO); ++k) {
+      tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+    }
+  }
+  // Clear_OOB_K must be false since we don't want to write zeros to gmem
+  flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+      gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
+  using Element = typename Kernel_traits::Element;
+  using ElementAccum = typename Kernel_traits::ElementAccum;
+  using index_t = typename Kernel_traits::index_t;
+
+  // Shared memory.
+  extern __shared__ char smem_[];
+
+  // The thread index.
+  const int tidx = threadIdx.x;
+
+  constexpr int kBlockM = Kernel_traits::kBlockM;
+  constexpr int kBlockN = Kernel_traits::kBlockN;
+  constexpr int kHeadDim = Kernel_traits::kHeadDim;
+  constexpr int kNWarps = Kernel_traits::kNWarps;
+
+  using GmemTiledCopyO = std::conditional_t<
+      !Split,
+      typename Kernel_traits::GmemTiledCopyOaccum,
+      typename Kernel_traits::GmemTiledCopyO>;
+  using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
+
+  const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+  // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
+  // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
+  if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+
+  const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
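+  // Example (numbers assumed): seqlen_k = 1024 with kBlockN = 64 gives 16 key blocks; with
+  // num_n_splits = 3 each split owns ceil(16 / 3) = 6 blocks, i.e. blocks [0,6), [6,12) and
+  // [12,16), so the last split only processes 4 blocks.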
+  const int n_block_min = !Is_local
+                              ? n_split_idx * n_blocks_per_split
+                              : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
+  int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
+  if (Is_causal || Is_local) {
+    n_block_max = std::min(n_block_max,
+                           cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
+  }
+  if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
+    // We exit early and write 0 to gOaccum and -inf to gLSEaccum (+inf when not splitting).
+    // Otherwise we might read OOB elements from gK and gV,
+    // or get wrong results when we combine gOaccum from different blocks.
+    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+    const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded;
+    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO*>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                 make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
+                                   Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+    GmemTiledCopyO gmem_tiled_copy_Oaccum;
+    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
+    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
+    clear(tOrOaccum);
+    // Construct identity layout for sO
+    Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    // Repeat the partitioning with identity layouts
+    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
+    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+    if (!Is_even_K) {
+#pragma unroll
+      for (int k = 0; k < size(tOpO); ++k) {
+        tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+      }
+    }
+    // Clear_OOB_K must be false since we don't want to write zeros to gmem
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
+#pragma unroll
+    for (int m = 0; m < size<1>(tOgOaccum); ++m) {
+      const int row = get<0>(tOcO(0, m, 0));
+      if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) {
+        gLSEaccum(row) = Split ? -INFINITY : INFINITY;
+      }
+    }
+    return;
+  }
+
+  // We iterate over the blocks in reverse order. This is because the last block is the only one
+  // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
+  // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
+
+  const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
+  // We move K and V to the last block.
+  const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
+  const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
+  const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+
+  Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + row_offset_q),
+                          Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                          make_stride(params.q_row_stride, _1{}));
+  Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k),
+                          Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                          make_stride(params.k_row_stride, _1{}));
+  // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
+  Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v),
+                          Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                          make_stride(params.v_row_stride, _1{}));
+
+  Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)),
+                          typename Kernel_traits::SmemLayoutQ{});
+  Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
+  Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
+  Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
+  Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+
+  typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+  auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+
+  Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
+  Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
+  Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+  Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
+  Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+  Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
+
+  typename Kernel_traits::TiledMma tiled_mma;
+  auto thr_mma = tiled_mma.get_thread_slice(tidx);
+  Tensor tSrQ = thr_mma.partition_fragment_A(sQ);             // (MMA,MMA_M,MMA_K)
+  Tensor tSrK = thr_mma.partition_fragment_B(sK);             // (MMA,MMA_N,MMA_K)
+  Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K,MMA_N)
+
+  Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
+
+  //
+  // Copy Atom retiling
+  //
+
+  auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+  Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
+
+  auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+  auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+  Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+
+  auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+  auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+  Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
+
+  // TODO: this might need to change if we change the mma instruction in SM70
+  Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
+  Tensor scores_sum = make_fragment_like(scores_max);
+
+  //
+  // PREDICATES
+  //
+
+  // // Allocate predicate tensors for m and n
+  // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
+  // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
+
+  // Construct identity layout for sQ and sK
+  Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));   // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
+
+  // Repeat the partitioning with identity layouts
+  Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);     // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+
+  // Allocate predicate tensors for k
+  Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
+  Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
+
+  // Set predicates for k bounds
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tQpQ); ++k) {
+      tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
+    }
+#pragma unroll
+    for (int k = 0; k < size(tKVpKV); ++k) {
+      tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
+    }
+  }
+
+  // Prologue
+  // Copy from Knew to K, optionally apply rotary embedding.
+  typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+  auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+  typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
+  auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
+  if constexpr (Append_KV) {
+    // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
+    // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
+    // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
+    const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+    Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) + row_offset_cossin),
+                              Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+                              make_stride(params.rotary_dim / 2, _1{}));
+    Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) + row_offset_cossin),
+                              Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+                              make_stride(params.rotary_dim / 2, _1{}));
+    Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                  make_stride(params.rotary_dim / 2, _1{}));
+    Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                  make_stride(params.rotary_dim / 2, _1{}));
+    Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+    Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+    Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+    Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+    // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
+    // if (cute::thread(8, 0)) { print_tensor(gCos); }
+    // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
+
+    const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+    const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+    // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+    // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+    // This maps to accessing the first 64 rows of knew_ptr.
+    Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                               make_stride(params.knew_row_stride, _1{}));
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+    Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                               make_stride(params.vnew_row_stride, _1{}));
+    Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
+
+    const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
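+    // Example (values assumed): with seqlen_k_cache = 500 and kBlockN = 64, seqlen_k_cache / kBlockN
+    // is block 7 (rows 448..511), the first block containing position 500 where the appended tokens
+    // start; blocks below it are already fully populated in the cache, so the copy loop skips them.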
+    for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
+      flash::copy_w_min_idx<Is_even_K>(
+          tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+      tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+      tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+      if (params.rotary_dim == 0) {
+        flash::copy_w_min_idx<Is_even_K>(
+            tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+      } else {
+        if (params.is_rotary_interleaved) {
+          // Don't clear OOB_K because we're writing to global memory
+          flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
+              tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+              binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim);
+          tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
+          tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
+        } else {
+          // Don't clear OOB_K because we're writing to global memory
+          flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
+              tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+              binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim);
+          tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+          tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+        }
+      }
+      tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+    }
+    // Need this before we can read in K again, so that we'll see the updated K values.
+    __syncthreads();
+    if (n_block_max > n_block_copy_min) {
+      tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
+      tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
+    }
+  }
+
+  // Read Q from gmem to smem, optionally apply rotary embedding.
+  Tensor tQrQ = make_fragment_like(tQgQ);
+  if (!Append_KV || params.rotary_dim == 0) {
+    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+                                       binfo.actual_seqlen_q - m_block * kBlockM);
+  } else {
+    const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+    // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
+    // We do this by setting the row stride of gCos / gSin to 0.
+    Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) + row_offset_cossin),
+                              Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+                              make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+    Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) + row_offset_cossin),
+                              Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+                              make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+    Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+    Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+    Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+    Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+    Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+    Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+    if (params.is_rotary_interleaved) {
+      flash::copy_rotary_interleaved<Is_even_K>(
+          tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+          0, params.d, params.rotary_dim);
+    } else {
+      flash::copy_rotary_contiguous<Is_even_K>(
+          tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+          0, params.d, params.rotary_dim);
+    }
+  }
+
+  int n_block = n_block_max - 1;
+  // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
+  flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+                                     binfo.actual_seqlen_k - n_block * kBlockN);
+  cute::cp_async_fence();
+
+  // flash::cp_async_wait<0>();
+  // __syncthreads();
+  // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
+  // __syncthreads();
+
+  clear(acc_o);
+
+  // For performance reasons, we separate out two kinds of iterations:
+  // those that need masking on S, and those that don't.
+  // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
+  // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
+  // We will have at least 1 "masking" iteration.
+
+  // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+  // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
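+  // For instance, with kBlockM == kBlockN, a causal problem whose seqlen_k is a multiple of
+  // kBlockN needs ceil_div(kBlockM, kBlockN) == 1 masking iteration; an uneven tail (or Is_local)
+  // adds one more, as computed below.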
+  constexpr int n_masking_steps = (!Is_causal && !Is_local)
+                                      ? 1
+                                      : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+#pragma unroll
+  for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
+    Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+
+    // Advance gV
+    if (masking_step > 0) {
+      tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+    } else {
+      // Clear the smem tiles to account for predicated off loads
+      flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+          gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN);
+    }
+    cute::cp_async_fence();
+
+    flash::gemm(
+        acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        smem_thr_copy_Q, smem_thr_copy_K);
+    // if (cute::thread0()) { print(acc_s); }
+
+    // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+    // if (cute::thread0()) { print(scores); }
+    // We don't put the masking before the matmul S = Q K^T because we don't clear sK
+    // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
+    // can produce Inf / NaN.
+    if (!Is_causal && !Is_local) {
+      if (!Is_even_MN) {
+        flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN);
+      }
+    } else {
+      flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                              binfo.actual_seqlen_q, kNWarps * 16,
+                              params.window_size_left, params.window_size_right);
+    }
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
+    // __syncthreads();
+
+    if (n_block > n_block_min) {
+      // Advance gK
+      tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the synchronization
+      // isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    // We have key_padding_mask so we'll need to Check_inf
+    masking_step == 0
+        ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+        : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+    // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
+
+    // Convert scores from fp32 to fp16/bf16
+    Tensor rP = flash::convert_type<Element>(scores);
+    // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+
+    flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+    // if (cute::thread0()) { print(scores); }
+
+    // This check is at the end of the loop since we always have at least 1 iteration
+    if (n_masking_steps > 1 && n_block <= n_block_min) {
+      --n_block;
+      break;
+    }
+  }
+
+  // These are the iterations where we don't need masking on S
+  for (; n_block >= n_block_min; --n_block) {
+    Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    // Advance gV
+    tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+    flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+    cute::cp_async_fence();
+
+    flash::gemm(
+        acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+        smem_thr_copy_Q, smem_thr_copy_K);
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    if (n_block > n_block_min) {
+      // Advance gK
+      tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the synchronization
+      // isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+    if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+      flash::apply_mask_local(
+          scores, n_block * kBlockN, binfo.actual_seqlen_k,
+          m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+          binfo.actual_seqlen_q, kNWarps * 16,
+          params.window_size_left, params.window_size_right);
+    }
+    softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+    Tensor rP = flash::convert_type<Element>(scores);
+    // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+
+    flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+  }
+
+  // Epilogue
+
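+  // Finalize the online softmax: scale each output row by 1 / sum(exp) and record
+  // lse = row_max * scale_softmax + log(sum), the per-row log-sum-exp (consumed by the
+  // split-KV combine step when Split).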
+  // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+  Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+  // if (cute::thread0()) { print(acc_o_rowcol); }
+  Tensor lse = make_fragment_like(scores_sum);
+#pragma unroll
+  for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
+    float sum = scores_sum(mi);
+    float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+    lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
+    float scale = inv_sum;
+#pragma unroll
+    for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
+      acc_o_rowcol(mi, ni) *= scale;
+    }
+  }
+  // if (cute::thread0()) { print(lse); }
+  // if (cute::thread0()) { print(acc_o_rowcol); }
+
+  Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO*>(smem_)), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
+  // Partition sO to match the accumulator partitioning
+  using SmemTiledCopyO = std::conditional_t<
+      !Split,
+      typename Kernel_traits::SmemCopyAtomO,
+      typename Kernel_traits::SmemCopyAtomOaccum>;
+  auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
+  auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
+  Tensor rO = flash::convert_type<ElementO>(acc_o);
+  Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);          // ((Atom,AtomNum), MMA_M, MMA_N)
+  Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+  // sOaccum is larger than sQ, so we need a __syncthreads() here
+  // TODO: allocate enough smem for sOaccum
+  if constexpr (Split) {
+    __syncthreads();
+  }
+
+  cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
+
+  const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+  const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded;
+  const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
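+  // When Split, partial outputs are laid out as [num_splits, batch, heads, seqlen_q, d_rounded]
+  // and partial LSEs as [num_splits, batch, heads, seqlen_q], indexed here by n_split_idx.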
+
+  Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO*>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+                               Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                               make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+  Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
+                                 Shape<Int<kBlockM>>{}, Stride<_1>{});
+  // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
+
+  GmemTiledCopyO gmem_tiled_copy_Oaccum;
+  auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+  Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
+  Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
+
+  __syncthreads();
+
+  Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
+  cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
+
+  Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  Tensor taccOcO = thr_mma.partition_C(caccO);                                // (MMA,MMA_M,MMA_K)
+  static_assert(decltype(size<0>(taccOcO))::value == 4);
+  // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+  Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
+  CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
+  if (get<1>(taccOcO_row(0)) == 0) {
+#pragma unroll
+    for (int mi = 0; mi < size(lse); ++mi) {
+      const int row = get<0>(taccOcO_row(mi));
+      if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
+        gLSEaccum(row) = lse(mi);
+      }
+    }
+  }
+
+  // Construct identity layout for sO
+  Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  // Repeat the partitioning with identity layouts
+  Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tOpO); ++k) {
+      tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+    }
+  }
+  // Clear_OOB_K must be false since we don't want to write zeros to gmem
+  flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+      gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
+  // __syncthreads();
+  // if (cute::thread0()) { print(tOgOaccum); }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+inline __device__ void compute_attn(const Params& params) {
+  const int m_block = blockIdx.x;
+  // The block index for the batch.
+  const int bidb = blockIdx.y;
+  // The block index for the head.
+  const int bidh = blockIdx.z;
+
+  // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
+  // them to have the same number of threads or having to traverse the attention matrix
+  // in the same order.
+  // In the Philox RNG, we use the offset to store the batch, head, and the lane id
+  // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
+  // the attention matrix. This way, as long as we have the batch, head, and the location of
+  // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
+
+  flash::compute_attn_1rowblock<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+inline __device__ void compute_attn_splitkv(const Params& params) {
+  const int m_block = blockIdx.x;
+  // The block index for the batch.
+  const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
+  // The block index for the head.
+  const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
+  const int n_split_idx = Split ? blockIdx.y : 0;
+  const int num_n_splits = Split ? gridDim.y : 1;
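+  // With Split, the launch grid is (num_m_blocks, num_splits, batch * heads), so batch and head
+  // are recovered from blockIdx.z; otherwise the grid is (num_m_blocks, batch, heads).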
+  flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
+inline __device__ void combine_attn_seqk_parallel(const Params& params) {
+  using Element = typename Kernel_traits::Element;
+  using ElementAccum = typename Kernel_traits::ElementAccum;
+  using index_t = typename Kernel_traits::index_t;
+  constexpr int kMaxSplits = 1 << Log_max_splits;
+  constexpr int kHeadDim = Kernel_traits::kHeadDim;
+  constexpr int kNThreads = Kernel_traits::kNThreads;
+
+  static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
+  static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
+  static_assert(kNThreads == 128, "We assume that each block has 128 threads");
+
+  // Shared memory.
+  // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
+  __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];
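+  // The +1 pad makes the row stride odd, so the transposed (column-wise) reads below hit distinct banks.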
+
+  // The thread and block index.
+  const int tidx = threadIdx.x;
+  const int bidx = blockIdx.x;
+
+  const index_t row_offset_lse = bidx * kBlockM;
+  Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lseaccum_ptr) + row_offset_lse),
+                                 Shape<Int<kMaxSplits>, Int<kBlockM>>{},
+                                 make_stride(params.b * params.h * params.seqlen_q, _1{}));
+  Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse),
+                            Shape<Int<kBlockM>>{}, Stride<_1>{});
+  constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
+
+  // Read the LSE values from gmem and store them in shared memory, then transpose them.
+  constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
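+  // e.g. kMaxSplits = 128, kBlockM = 16, kNThreads = 128: each thread reads kNLsePerThread = 16
+  // LSE values, covering kRowsPerLoadLSE = 8 split-rows per pass.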
+#pragma unroll
+  for (int l = 0; l < kNLsePerThread; ++l) {
+    const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
+    const int col = tidx % kBlockM;
+    ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+    if (row < kMaxSplits) {
+      sLSE[row][col] = lse;
+    }
+    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
+  }
+  // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
+  __syncthreads();
+  Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
+  constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
+  // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
+  // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
+  // 16 rows, so each time we load we can load 8 rows).
+  // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
+  // static_assert(kThreadsPerSplit <= 32);
+  static_assert(kRowsPerLoadTranspose <= 32);
+  static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
+#pragma unroll
+  for (int l = 0; l < kNLsePerThread; ++l) {
+    const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+    const int col = tidx / kRowsPerLoadTranspose;
+    lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
+    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
+  }
+
+  // Compute the logsumexp of the LSE along the split dimension.
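+  // i.e. lse_logsum = log(sum_s exp(lse_s)), computed stably as lse_max + log(sum_s exp(lse_s - lse_max)).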
+  ElementAccum lse_max = lse_accum(0);
+#pragma unroll
+  for (int l = 1; l < kNLsePerThread; ++l) {
+    lse_max = max(lse_max, lse_accum(l));
+  }
+  MaxOp<float> max_op;
+  lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
+  lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
+  float lse_sum = expf(lse_accum(0) - lse_max);
+#pragma unroll
+  for (int l = 1; l < kNLsePerThread; ++l) {
+    lse_sum += expf(lse_accum(l) - lse_max);
+  }
+  SumOp<float> sum_op;
+  lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
+  // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
+  // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
+  ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
+  // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
+  if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+    gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+  }
+// Store the scales exp(lse - lse_logsum) in shared memory.
+#pragma unroll
+  for (int l = 0; l < kNLsePerThread; ++l) {
+    const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+    const int col = tidx / kRowsPerLoadTranspose;
+    if (row < params.num_splits && col < kBlockM) {
+      sLSE[row][col] = expf(lse_accum(l) - lse_logsum);
+    }
+  }
+  __syncthreads();
+
+  const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
+  Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.oaccum_ptr) + row_offset_oaccum),
+                               Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                               Stride<Int<kHeadDim>, _1>{});
+  constexpr int kBlockN = kNThreads / kBlockM;
+  using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
+  using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                                                       GmemLayoutAtomOaccum{},
+                                                       Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+  GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+  auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+  Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
+  Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
+  Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+  clear(tOrO);
+
+  // Predicates
+  Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
+  // Repeat the partitioning with identity layouts
+  Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
+  Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tOpOaccum); ++k) {
+      tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d;
+    }
+  }
+  // Load Oaccum in, then scale and accumulate into O
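+  // O = sum_s exp(lse_s - lse_logsum) * O_s, using the per-split scales already stored in sLSE.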
+  for (int split = 0; split < params.num_splits; ++split) {
+    flash::copy</*Is_even_MN=*/false, Is_even_K>(
+        gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM);
+#pragma unroll
+    for (int m = 0; m < size<1>(tOrOaccum); ++m) {
+      int row = get<0>(tOcOaccum(0, m, 0));
+      ElementAccum lse_scale = sLSE[split][row];
+#pragma unroll
+      for (int k = 0; k < size<2>(tOrOaccum); ++k) {
+#pragma unroll
+        for (int i = 0; i < size<0>(tOrOaccum); ++i) {
+          tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
+        }
+      }
+      // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); }
+    }
+    tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
+  }
+  // if (cute::thread0()) { print(tOrO); }
+
+  Tensor rO = flash::convert_type<Element>(tOrO);
+// Write to gO
+#pragma unroll
+  for (int m = 0; m < size<1>(rO); ++m) {
+    const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
+    if (idx < params.b * params.h * params.seqlen_q) {
+      const int batch_idx = idx / (params.h * params.seqlen_q);
+      const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
+      // The index to the rows of Q
+      const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
+      auto o_ptr = reinterpret_cast<Element*>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
+#pragma unroll
+      for (int k = 0; k < size<2>(rO); ++k) {
+        if (Is_even_K || tOpOaccum(k)) {
+          const int col = get<1>(tOcOaccum(0, m, k));
+          Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
+                                  Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
+          // TODO: Should check if this is using vectorized store, but it seems pretty fast
+          copy(rO(_, m, k), gO);
+          // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
+          // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace flash
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h
new file mode 100644
index 000000000..e2f2505a7
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h
@@ -0,0 +1,294 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#include "static_switch.h"
+#include "flash.h"
+#include "flash_fwd_kernel.h"
+
+namespace flash {
+
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
+__global__ void flash_fwd_kernel(Flash_fwd_params params) {
+  static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  flash::compute_attn<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params);
+#else
+  (void)params;
+#endif
+}
+
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
+__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+#else
+  (void)params;
+#endif
+}
+
+template <typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
+__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  static_assert(Log_max_splits >= 1);
+  flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
+#else
+  (void)params;
+#endif
+}
+
+template <typename Kernel_traits, bool Is_causal>
+void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr size_t smem_size = Kernel_traits::kSmemSize;
+
+  // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
+  // https://github.com/kokkos/kokkos-kernels/issues/349
+  // https://github.com/HazyResearch/flash-attention/issues/21
+
+  const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
+  dim3 grid(num_m_block, params.b, params.h);
+  const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
+  const bool is_even_K = params.d == Kernel_traits::kHeadDim;
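+  // One CTA covers kBlockM query rows of one (batch, head); the is_even_* flags pick template
+  // variants that can skip bounds checks.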
+  BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+      BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+        // We only return softmax if there is dropout, to reduce compilation time.
+        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce the number of templates.
+        // If head dim > 128, set IsEvenMNConst to false to reduce the number of templates.
+        // If Is_local, set Is_causal to false
+        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false>;
+        // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst>;
+        if (smem_size >= 48 * 1024) {
+          cudaFuncSetAttribute(
+              kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+          // ORT_ENFORCE(cudaFuncSetAttribute(
+          //     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        }
+        // int ctas_per_sm;
+        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+        //  printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+      });
+    });
+  });
+}
+
+template <typename Kernel_traits>
+void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {
+  static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
+  static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
+  constexpr size_t smem_size = Kernel_traits::kSmemSize;
+  const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
+  dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
+  const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
+  const bool is_even_K = params.d == Kernel_traits::kHeadDim;
+  BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+      BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+        BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+          BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+            BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+              // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+              // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
+              auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
+              // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+              // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+              if (smem_size >= 48 * 1024) {
+                cudaFuncSetAttribute(
+                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+              }
+              kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+            });
+          });
+        });
+      });
+    });
+  });
+  if (params.num_splits > 1) {
+    // We want kBlockM to be as small as possible for more parallelism.
+    // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
+    // If headdim is divisible by 64, then we set kBlockM = 8, etc.
+    constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+    dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
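+    // e.g. Headdim = 128 gives kBlockM = 4, so each combine block reduces 4 output rows across the
+    // splits; below, Log_max_splits is the smallest exponent with 2^Log_max_splits >= num_splits.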
+    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+      if (params.num_splits <= 2) {
+        flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+      } else if (params.num_splits <= 4) {
+        flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+      } else if (params.num_splits <= 8) {
+        flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+      } else if (params.num_splits <= 16) {
+        flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+      } else if (params.num_splits <= 32) {
+        flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+      } else if (params.num_splits <= 64) {
+        flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+      } else if (params.num_splits <= 128) {
+        flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+      }
+    });
+  }
+}
+
+template <typename T, int Headdim>
+void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr int kBlockM = 64;  // Fixed for all head dimensions
+  constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+  run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
+}
+
+template <typename T>
+void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr static int Headdim = 32;
+  BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
+  });
+}
+
+template <typename T>
+void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr static int Headdim = 64;
+  BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+    // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+    // Using block size (64 x 256) is 27% slower for seqlen=2k
+    // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
+  });
+}
+
+template <typename T>
+void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr int Headdim = 96;
+  const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
+  BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+    // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square).
+    if (is_sm8x) {
+      if constexpr (!Is_causal) {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
+      } else {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
+      }
+    } else {
+      run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
+    }
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
+    // These two are always slower
+    // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
+  });
+}
+
+template <typename T>
+void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr static int Headdim = 128;
+  bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
+  BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+    // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+    // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
+    if (is_sm8x) {
+      if constexpr (!Is_causal) {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_causal>(params, stream);
+      } else {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
+      }
+    } else {
+      run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
+    }
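+    // The 48 KB figure above is 2 * 128 * (128 + 2 * 32) bytes per CTA, so two CTAs fit in the
+    // ~100 KB of shared memory available per SM on sm86/sm89.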
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_causal>(params, stream);
+    // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
+    // 1st ones are good for H100, A100
+    // 2nd one is good for A6000 bc we get slightly better occupancy
+  });
+}
+
+template <typename T>
+void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr static int Headdim = 160;
+  bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
+  BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+    // For A100, H100, 128 x 32 is the fastest.
+    // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+    // and 128 x 64 with 8 warps is the fastest for non-causal.
+    if (is_sm8x) {
+      if constexpr (!Is_causal) {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
+      } else {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
+      }
+    } else {
+      run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_causal>(params, stream);
+    }
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
+  });
+}
+
+template <typename T>
+void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr int Headdim = 192;
+  BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
+  });
+}
+
+template <typename T>
+void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr static int Headdim = 224;
+  int max_smem_per_block = params.dprops->sharedMemPerBlockOptin;
+  //  printf("max_smem_per_block = %d\n", max_smem_per_block);
+  BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+    if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
+      run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
+    } else {
+      run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
+    }
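+    // i.e. sizeof(half) * Headdim * (kBlockM + 2 * kBlockN) = 2 * 224 * (128 + 128) bytes = 112 KB
+    // of shared memory for the 128 x 64 tile.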
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_causal>(params, stream);
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
+    // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
+    // If we have N = 32, there are only 1024 elements to load at once, where each load
+    // is 8 elements. This means we can only use 128 threads and not 256 threads.
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_causal>(params, stream);
+  });
+}
+
+template <typename T>
+void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) {
+  constexpr static int Headdim = 256;
+  size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor;
+  size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin;
+  //  printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
+  BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+    // For A100, we want to run with 128 x 64 (128KB smem).
+    // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
+    if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
+      run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
+    } else {
+      run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
+    }
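+    // 2 * 256 * 256 B = 128 KB is the smem of the 128 x 64 tile; 4 * 256 * 192 B = 192 KB is two
+    // resident 64 x 64 CTAs (2 x 96 KB), so the first branch is taken only when 2 CTAs/SM cannot fit.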
+    // 64 KB
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
+    // 96 KB
+    // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_causal>(params, stream);
+  });
+}
+
+}  // namespace flash
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu
new file mode 100644
index 000000000..f818e584d
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu
new file mode 100644
index 000000000..d7db5839d
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu
new file mode 100644
index 000000000..855863403
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu
new file mode 100644
index 000000000..86057f9a7
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu
new file mode 100644
index 000000000..b6aec378c
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu
new file mode 100644
index 000000000..0ba212ef7
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu
new file mode 100644
index 000000000..1fd816e2c
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu
new file mode 100644
index 000000000..98740bded
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu
new file mode 100644
index 000000000..8982c23ab
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu
new file mode 100644
index 000000000..ab1eec23d
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu
new file mode 100644
index 000000000..2cb3a4d3d
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu
new file mode 100644
index 000000000..e02735b11
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu
new file mode 100644
index 000000000..721fdb9f7
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu
new file mode 100644
index 000000000..37830f2f5
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu
new file mode 100644
index 000000000..e90f2540e
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu
new file mode 100644
index 000000000..394cca7c8
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if OCOS_USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/cuda/attention_lib/flash_attention/kernel_traits.h b/operators/cuda/attention_lib/flash_attention/kernel_traits.h
new file mode 100644
index 000000000..48e899c2a
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/kernel_traits.h
@@ -0,0 +1,367 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#include <cute/algorithm/copy.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/layout/layout.h>
+#include <cutlass/numeric_types.h>
+
+using namespace cute;
+
+namespace flash {
+
+template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::half_t>
+struct Flash_kernel_traits {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  using Element = elem_type;
+  static constexpr bool Has_cp_async = true;
+#else
+  using Element = cutlass::half_t;
+  static constexpr bool Has_cp_async = false;
+#endif
+
+  using ElementAccum = float;
+  using index_t = uint32_t;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  using MMA_Atom_Arch = std::conditional_t<
+      std::is_same_v<elem_type, cutlass::half_t>,
+      MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+      MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
+  using ValLayoutMNK = cute::Layout<cute::Shape<_1, _2, _1>>;
+#else
+  using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
+  using ValLayoutMNK = cute::Layout<cute::Shape<_1, _2, _2>>;
+#endif
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+  using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
+  using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
+#else
+  using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
+  using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
+#endif
+};
+
+// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
+template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
+          bool Is_Q_in_regs_ = false, bool Share_Q_K_smem_ = false, typename elem_type = cutlass::half_t,
+          typename Base = Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type>>
+struct Flash_fwd_kernel_traits : public Base {
+  using Element = typename Base::Element;
+  using ElementAccum = typename Base::ElementAccum;
+  using index_t = typename Base::index_t;
+  static constexpr bool Has_cp_async = Base::Has_cp_async;
+  using SmemCopyAtom = typename Base::SmemCopyAtom;
+  using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
+
+  static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
+  static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
+
+  // The number of threads.
+  static constexpr int kNWarps = kNWarps_;
+  static constexpr int kNThreads = kNWarps * 32;
+
+  static constexpr int kBlockM = kBlockM_;
+  static constexpr int kBlockN = kBlockN_;
+  static constexpr int kHeadDim = kHeadDim_;
+  static_assert(kHeadDim % 32 == 0);
+  static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
+  static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
+  static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+
+  using TiledMma = TiledMMA<
+      typename Base::MMA_Atom_Arch,
+      Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
+      typename Base::ValLayoutMNK>;         // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+  using SmemLayoutAtomQ = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                               // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
+                                               Layout<Shape<_8, Int<kBlockKSmem>>,
+                                                      Stride<Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutQ = decltype(tile_to_shape(
+      SmemLayoutAtomQ{},
+      Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+
+  using SmemLayoutKV = decltype(tile_to_shape(
+      SmemLayoutAtomQ{},
+      Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+
+  // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
+  using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+                                                    Stride<_1, Int<kBlockKSmem>>>;
+  using SmemLayoutAtomVtransposed = decltype(composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
+  using SmemLayoutVtransposed = decltype(tile_to_shape(
+      SmemLayoutAtomVtransposed{},
+      Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+  // Maybe the VtransposeNoSwizzle just needs to have the right shape
+  // And the strides don't matter?
+  using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
+      SmemLayoutAtomVtransposedNoSwizzle{},
+      Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+  // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+
+  using SmemLayoutAtomO = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                               Layout<Shape<Int<8>, Int<kBlockKSmem>>,
+                                                      Stride<Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutO = decltype(tile_to_shape(
+      SmemLayoutAtomO{},
+      Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+  using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+  using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
+
+  static constexpr int kSmemQCount = cute::size(SmemLayoutQ{});
+  static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2;
+  static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
+  static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+  static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
+
+  static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+  static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
+  // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
+  // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
+  // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
+  // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
+  // to the same banks.
+  static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
+  static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
+  using GmemLayoutAtom = cute::Layout<cute::Shape<cute::Int<kNThreads / kGmemThreadsPerRow>, cute::Int<kGmemThreadsPerRow>>,
+                                      cute::Stride<cute::Int<kGmemThreadsPerRow>, _1>>;
+
+  // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
+  // from the same address by the same threadblock. This is slightly faster.
+  using Gmem_copy_struct = std::conditional_t<
+      Has_cp_async,
+      SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
+      DefaultCopy>;
+  using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+                                                    GmemLayoutAtom{},
+                                                    cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+  using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                                                  GmemLayoutAtom{},
+                                                  cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+  static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
+  static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
+  using GmemLayoutAtomP = cute::Layout<cute::Shape<cute::Int<kNThreads / kGmemThreadsPerRowP>, cute::Int<kGmemThreadsPerRowP>>,
+                                       cute::Stride<cute::Int<kGmemThreadsPerRowP>, _1>>;
+
+  using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                                                  GmemLayoutAtomP{},
+                                                  cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+
+  using GmemLayoutAtomOaccum = std::conditional_t<
+      kBlockKSmem == 32,
+      cute::Layout<cute::Shape<_16, _8>,  // Thread layout, 8 threads per row
+                   cute::Stride<_8, _1>>,
+      cute::Layout<cute::Shape<_8, _16>,  // Thread layout, 16 threads per row
+                   cute::Stride<_16, _1>>>;
+  using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                                                       GmemLayoutAtomOaccum{},
+                                                       Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+  using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+  using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                                                          GmemLayoutAtomRotcossin{},
+                                                          Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per load
+  using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                                              GmemLayoutAtomRotcossin{},
+                                                              Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per load
+};
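+
+// Illustrative only (an assumption for documentation, not used by any kernel): a sample
+// instantiation spelling out the derived constants for a 64-dim head processed by
+// 4 warps on 128x64 tiles.
+namespace example_fwd_traits_check {
+using ExampleTraits = Flash_fwd_kernel_traits</*kHeadDim=*/64, /*kBlockM=*/128, /*kBlockN=*/64, /*kNWarps=*/4>;
+static_assert(ExampleTraits::kNThreads == 128, "4 warps -> 128 threads");
+static_assert(ExampleTraits::kBlockKSmem == 64 && ExampleTraits::kSwizzle == 3, "kHeadDim % 64 == 0 -> 64-wide smem blocks");
+static_assert(ExampleTraits::kGmemThreadsPerRow == 8, "128-bit loads of half_t -> 8 elements per load, 8 threads per row");
+}  // namespace example_fwd_traits_check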
+
+// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
+// No_double_buffer is another option to reduce smem usage, but will slow things down.
+template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
+          int AtomLayoutMSdP_ = 1, int AtomLayoutNdKV = 2, int AtomLayoutMdQ = 2,
+          bool Is_V_in_regs_ = false, bool No_double_buffer_ = false, typename elem_type = cutlass::half_t,
+          typename Base = Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type>>
+struct Flash_bwd_kernel_traits : public Base {
+  using Element = typename Base::Element;
+  using ElementAccum = typename Base::ElementAccum;
+  using index_t = typename Base::index_t;
+  static constexpr bool Has_cp_async = Base::Has_cp_async;
+  using SmemCopyAtom = typename Base::SmemCopyAtom;
+  using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
+
+  static constexpr bool Is_V_in_regs = Is_V_in_regs_;
+  static constexpr bool No_double_buffer = No_double_buffer_;
+
+  // The number of threads.
+  static constexpr int kNWarps = kNWarps_;
+  static constexpr int kNThreads = kNWarps * 32;
+
+  static constexpr int kBlockM = kBlockM_;
+  static constexpr int kBlockN = kBlockN_;
+  static constexpr int kHeadDim = kHeadDim_;
+  static_assert(kHeadDim % 32 == 0);
+  static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
+  static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
+  static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+
+  static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
+  static_assert(kNWarps % AtomLayoutMSdP == 0);
+  static_assert(kNWarps % AtomLayoutNdKV == 0);
+  static_assert(kNWarps % AtomLayoutMdQ == 0);
+
+  using TiledMmaSdP = TiledMMA<
+      typename Base::MMA_Atom_Arch,
+      cute::Layout<cute::Shape<cute::Int<AtomLayoutMSdP>, cute::Int<kNWarps / AtomLayoutMSdP>, _1>>,
+      typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+  using TiledMmadKV = TiledMMA<
+      typename Base::MMA_Atom_Arch,
+      cute::Layout<cute::Shape<cute::Int<AtomLayoutNdKV>, cute::Int<kNWarps / AtomLayoutNdKV>, _1>>,
+      typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+  using TiledMmadQ = TiledMMA<
+      typename Base::MMA_Atom_Arch,
+      cute::Layout<cute::Shape<cute::Int<AtomLayoutMdQ>, cute::Int<kNWarps / AtomLayoutMdQ>, _1>>,  // 2x4x1 or 4x2x1 thread group
+      typename Base::ValLayoutMNK>;                                                                 // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
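+
+  // Illustrative only: with kNWarps = 8 and the default atom layouts (AtomLayoutMSdP = 1,
+  // AtomLayoutNdKV = 2, AtomLayoutMdQ = 2), the three TiledMMAs above tile their warps
+  // as 1x8, 2x4 and 2x4 over the (M, N) atom grid respectively.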
+
+  using SmemLayoutAtomQdO = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                                 cute::Layout<cute::Shape<_8, cute::Int<kBlockKSmem>>,
+                                                              cute::Stride<cute::Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutQdO = decltype(tile_to_shape(
+      SmemLayoutAtomQdO{},
+      cute::make_shape(cute::Int<kBlockM>{}, cute::Int<kHeadDim>{})));
+
+  using SmemLayoutAtomKV = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                                cute::Layout<cute::Shape<cute::Int<kBlockM / kNWarps>, cute::Int<kBlockKSmem>>,
+                                                             cute::Stride<cute::Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutKV = decltype(tile_to_shape(
+      // SmemLayoutAtomQdO{},
+      SmemLayoutAtomKV{},
+      cute::make_shape(cute::Int<kBlockN>{}, cute::Int<kHeadDim>{})));
+
+  using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+                                                    Stride<_1, Int<kBlockKSmem>>>;
+  using SmemLayoutAtomKtransposed = decltype(composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
+  using SmemLayoutKtransposed = decltype(tile_to_shape(
+      SmemLayoutAtomKtransposed{},
+      make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
+  // Maybe the KtransposeNoSwizzle just needs to have the right shape
+  // And the strides don't matter?
+  using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
+      SmemLayoutAtomKtransposedNoSwizzle{},
+      make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
+  // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+
+  // TODO: generalize to other values of kBlockN
+  // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
+  // static constexpr int kPBlockN = kBlockN;
+  static_assert(kBlockN >= 64);
+  // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
+  static constexpr int kPBlockN = 64;
+  static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
+  // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
+  static constexpr int kSwizzlePdS = 3;
+  using SmemLayoutAtomPdS = decltype(composition(Swizzle<kSwizzlePdS, 3, 3>{},
+                                                 cute::Layout<cute::Shape<cute::Int<kBlockM>, cute::Int<kPBlockN>>,
+                                                              cute::Stride<cute::Int<kPBlockN>, _1>>{}));
+  using SmemLayoutPdS = decltype(tile_to_shape(
+      SmemLayoutAtomPdS{},
+      cute::make_shape(cute::Int<kBlockM>{}, cute::Int<kBlockN>{})));
+  using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
+                                                      Stride<_1, Int<kPBlockN>>>;
+  using SmemLayoutAtomPdStransposed = decltype(composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
+  using SmemLayoutPdStransposed = decltype(tile_to_shape(
+      SmemLayoutAtomPdStransposed{},
+      make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
+  using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
+      SmemLayoutAtomPdStransposedNoSwizzle{},
+      make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
+  // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+  using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+
+  using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
+                                                      Stride<_1, Int<kBlockKSmem>>>;
+  using SmemLayoutAtomQdOtransposed = decltype(composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
+  using SmemLayoutQdOtransposed = decltype(tile_to_shape(
+      SmemLayoutAtomQdOtransposed{},
+      make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
+  using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
+      SmemLayoutAtomQdOtransposedNoSwizzle{},
+      make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
+  // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+
+  using SmemLayoutAtomdKV = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                                 Layout<Shape<_8, Int<kBlockKSmem>>,
+                                                        Stride<Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutdKV = decltype(tile_to_shape(
+      SmemLayoutAtomdKV{},
+      make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
+  using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
+
+  using SmemLayoutAtomdQ = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                                Layout<Shape<_8, Int<kBlockKSmem>>,
+                                                       Stride<Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutdQ = decltype(tile_to_shape(
+      SmemLayoutAtomdQ{},
+      make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
+  using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
+
+  static constexpr int kSmemQdOCount = cute::size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3);  // Double buffer for sQ
+  static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2;
+  static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{});
+  static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{});
+  static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{});
+  //   static constexpr int kSmemdPsumCount = kBlockM;
+  static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
+  static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+  static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
+  static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
+  static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+  //   static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
+  static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs
+                                                       ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
+                                                       : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
+  static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs
+                                                                ? kSmemKVSize + kSmemdSSize + kSmemPSize
+                                                                : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
+  static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + kSmemdSSize + kSmemPSize;
+
+  static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+  static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
+  // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
+  // to affect speed in practice.
+  static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
+  static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
+  using GmemLayoutAtom = cute::Layout<cute::Shape<cute::Int<kNThreads / kGmemThreadsPerRow>, cute::Int<kGmemThreadsPerRow>>,
+                                      cute::Stride<cute::Int<kGmemThreadsPerRow>, _1>>;
+
+  // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
+  // from the same address by the same threadblock. This is slightly faster.
+  using Gmem_copy_struct = std::conditional_t<
+      Has_cp_async,
+      SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
+      DefaultCopy>;
+  using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+                                                    GmemLayoutAtom{},
+                                                    cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+  using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                                                   GmemLayoutAtom{},
+                                                   cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+  using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                                                    GmemLayoutAtom{},
+                                                    cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+  using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                                                   GmemLayoutAtom{},
+                                                   cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+  using GmemLayoutAtomdQaccum = std::conditional_t<
+      kBlockKSmem == 32,
+      cute::Layout<cute::Shape<_32, _8>,  // Thread layout, 8 threads per row
+                   cute::Stride<_8, _1>>,
+      cute::Layout<cute::Shape<_16, _16>,  // Thread layout, 16 threads per row
+                   cute::Stride<_16, _1>>>;
+  using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                                                        GmemLayoutAtomdQaccum{},
+                                                        cute::Layout<cute::Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+
+  using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                                                                 cute::Layout<cute::Shape<_8, _32>,  // Thread layout, 8 threads per row
+                                                                              cute::Stride<_32, _1>>{},
+                                                                 cute::Layout<cute::Shape<_1, _1>>{}));  // Val layout, 1 val per store
+};
+
+}  // namespace flash
diff --git a/operators/cuda/attention_lib/flash_attention/softmax.h b/operators/cuda/attention_lib/flash_attention/softmax.h
new file mode 100644
index 000000000..9c31336c9
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/softmax.h
@@ -0,0 +1,215 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#include <cmath>
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+
+#include "utils.h"
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
+__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op) {
+  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+  static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+  CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
+#pragma unroll
+  for (int mi = 0; mi < size<0>(tensor); mi++) {
+    summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
+#pragma unroll
+    for (int ni = 1; ni < size<1>(tensor); ni++) {
+      summary(mi) = op(summary(mi), tensor(mi, ni));
+    }
+  }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
+__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0>& dst, Tensor<Engine1, Layout1>& src, Operator& op) {
+  CUTE_STATIC_ASSERT_V(size(dst) == size(src));
+#pragma unroll
+  for (int i = 0; i < size(dst); i++) {
+    dst(i) = Allreduce<4>::run(src(i), op);
+  }
+}
+
+template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
+__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op) {
+  thread_reduce_<zero_init>(tensor, summary, op);
+  quad_allreduce_(summary, summary, op);
+}
+
+template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& max) {
+  MaxOp<float> max_op;
+  reduce_<zero_init>(tensor, max, max_op);
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& sum) {
+  SumOp<float> sum_op;
+  reduce_(tensor, sum, sum_op);
+}
+
+// Apply the exp to all the elements.
+template <bool Scale_max = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1> const& max, const float scale) {
+  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+  static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+  CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
+#pragma unroll
+  for (int mi = 0; mi < size<0>(tensor); ++mi) {
+    // If max is -inf, then all elements must have been -inf (possibly due to masking).
+    // We don't want (-inf - (-inf)) since that would give NaN.
+    // If we don't wrap M_LOG2E in float(), the multiplication is done in fp64.
+    const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
+#pragma unroll
+    for (int ni = 0; ni < size<1>(tensor); ++ni) {
+      // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
+      // max * log_2(e)) This allows the compiler to use the ffma
+      // instruction instead of fadd and fmul separately.
+      tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+    }
+  }
+}
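+
+// A sketch of the identity behind the exp2 rewrite above (with L = log2(e)):
+//   exp(x - m) = 2^(x*L - m*L) = exp2f(x * scale - max_scaled),
+// assuming the caller folds L into `scale`; when Scale_max is false, `max` is expected
+// to be pre-scaled, so only L is applied to it here.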
+
+// Apply the exp to all the elements.
+template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1>& max, Tensor<Engine1, Layout1>& sum, const float scale) {
+  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+  static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+  CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
+#pragma unroll
+  for (int mi = 0; mi < size<0>(tensor); ++mi) {
+    MaxOp<float> max_op;
+    max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
+#pragma unroll
+    for (int ni = 1; ni < size<1>(tensor); ni++) {
+      max(mi) = max_op(max(mi), tensor(mi, ni));
+    }
+    max(mi) = Allreduce<4>::run(max(mi), max_op);
+    // If max is -inf, then all elements must have been -inf (possibly due to masking).
+    // We don't want (-inf - (-inf)) since that would give NaN.
+    const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
+    sum(mi) = 0;
+#pragma unroll
+    for (int ni = 0; ni < size<1>(tensor); ++ni) {
+      // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
+      // max * log_2(e)) This allows the compiler to use the ffma
+      // instruction instead of fadd and fmul separately.
+      tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+      sum(mi) += tensor(mi, ni);
+    }
+    SumOp<float> sum_op;
+    sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
+  }
+}
+
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask(Tensor<Engine, Layout>& tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
+  // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+  static_assert(Layout::rank == 2, "Only support 2D Tensor");
+  const int lane_id = threadIdx.x % 32;
+  const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+#pragma unroll
+  for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+    const int col_idx_base = col_idx_offset + nj * 8;
+#pragma unroll
+    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+      const int col_idx = col_idx_base + j;
+      if (col_idx >= max_seqlen_k) {
+// Without the "make_coord" we get wrong results
+#pragma unroll
+        for (int mi = 0; mi < size<0>(tensor); ++mi) {
+          tensor(mi, make_coord(j, nj)) = -INFINITY;
+        }
+      }
+    }
+  }
+}
+
+template <bool HasWSLeft = true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout>& tensor, const int col_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset_,
+                                        const int max_seqlen_q, const int warp_row_stride,
+                                        const int window_size_left, const int window_size_right) {
+  // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+  static_assert(Layout::rank == 2, "Only support 2D Tensor");
+  const int lane_id = threadIdx.x % 32;
+  // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
+  const int row_idx_offset = row_idx_offset_;
+  const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+#pragma unroll
+  for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+#pragma unroll
+    for (int i = 0; i < size<0, 0>(tensor); ++i) {
+      const int row_idx = row_idx_base + i * 8;
+      const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+      const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
+#pragma unroll
+      for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const int col_idx_base = col_idx_offset + nj * 8;
+#pragma unroll
+        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+          const int col_idx = col_idx_base + j;
+          if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
+            tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+          }
+        }
+      }
+      // if (cute::thread0()) {
+      //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
+      //     print(tensor(make_coord(i, mi), _));
+      //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
+      // }
+    }
+  }
+}
+
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout>& tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_q, const int warp_row_stride) {
+  // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+  apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+                                        max_seqlen_q, warp_row_stride, -1, 0);
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+inline __device__ void apply_mask_causal_w_idx(
+    Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1> const& idx_rowcol,
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) {
+  // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+  static_assert(Layout1::rank == 2, "Only support 2D Tensor");
+  CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
+  CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
+#pragma unroll
+  for (int mi = 0; mi < size<0>(tensor); ++mi) {
+    const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+#pragma unroll
+    for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
+      if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
+        tensor(mi, ni) = -INFINITY;
+      }
+    }
+    // if (cute::thread0()) {
+    //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
+    //     print(tensor(_, make_coord(j, ni)));
+    //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
+    // }
+  }
+}
+
+}  // namespace flash
diff --git a/operators/cuda/attention_lib/flash_attention/static_switch.h b/operators/cuda/attention_lib/flash_attention/static_switch.h
new file mode 100644
index 000000000..5b7098894
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/static_switch.h
@@ -0,0 +1,64 @@
+// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+#pragma once
+
+/// @param COND       - a boolean expression to switch by
+/// @param CONST_NAME - the name given to the constexpr bool variable
+/// @param ...        - the code to execute; it sees CONST_NAME as a constexpr bool in either branch
+///
+/// Usage:
+/// ```
+/// BOOL_SWITCH(flag, BoolConst, [&] {
+///     some_function<BoolConst>(...);
+/// });
+/// ```
+#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
+  [&] {                                         \
+    if (COND) {                                 \
+      constexpr static bool CONST_NAME = true;  \
+      return __VA_ARGS__();                     \
+    } else {                                    \
+      constexpr static bool CONST_NAME = false; \
+      return __VA_ARGS__();                     \
+    }                                           \
+  }()
+
+#define FP16_SWITCH(COND, ...)               \
+  [&] {                                      \
+    if (COND) {                              \
+      using elem_type = cutlass::half_t;     \
+      return __VA_ARGS__();                  \
+    } else {                                 \
+      using elem_type = cutlass::bfloat16_t; \
+      return __VA_ARGS__();                  \
+    }                                        \
+  }()
+
+#define FWD_HEADDIM_SWITCH(HEADDIM, ...)   \
+  [&] {                                    \
+    if (HEADDIM <= 32) {                   \
+      constexpr static int kHeadDim = 32;  \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 64) {            \
+      constexpr static int kHeadDim = 64;  \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 96) {            \
+      constexpr static int kHeadDim = 96;  \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 128) {           \
+      constexpr static int kHeadDim = 128; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 160) {           \
+      constexpr static int kHeadDim = 160; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 192) {           \
+      constexpr static int kHeadDim = 192; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 224) {           \
+      constexpr static int kHeadDim = 224; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 256) {           \
+      constexpr static int kHeadDim = 256; \
+      return __VA_ARGS__();                \
+    }                                      \
+  }()
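+
+/// Illustrative composition of the switches (a sketch; `run_mha_fwd_` stands in for
+/// whichever templated launcher is being dispatched to):
+/// ```
+/// FP16_SWITCH(!params.is_bf16, [&] {
+///   FWD_HEADDIM_SWITCH(params.d, [&] {
+///     run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+///   });
+/// });
+/// ```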
diff --git a/operators/cuda/attention_lib/flash_attention/utils.h b/operators/cuda/attention_lib/flash_attention/utils.h
new file mode 100644
index 000000000..cd10bd534
--- /dev/null
+++ b/operators/cuda/attention_lib/flash_attention/utils.h
@@ -0,0 +1,499 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <cuda_fp16.h>
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+#endif
+
+#include <cute/algorithm/copy.hpp>
+#include <cute/algorithm/gemm.hpp>
+
+#include <cutlass/array.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/numeric_conversion.h>
+#include <cutlass/numeric_types.h>
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+namespace flash {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+inline __device__ uint32_t relu2(const uint32_t x);
+
+template <>
+inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
+  uint32_t res;
+  const uint32_t zero = 0u;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("max.f16x2 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(x), "r"(zero));
+#else
+  asm volatile(
+      "{\n"
+      "\t .reg .f16x2 sela;\n"
+      "\t set.gtu.u32.f16x2 sela, %1, %2;\n"
+      "\t and.b32 %0, sela, %1;\n"
+      "}\n"
+      : "=r"(res)
+      : "r"(x), "r"(zero));
+#endif
+  return res;
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template <>
+inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
+  uint32_t res;
+  const uint32_t zero = 0u;
+  asm volatile("max.bf16x2 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(x), "r"(zero));
+  return res;
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+
+template <typename T>
+inline __device__ uint32_t convert_relu2(const float2 x);
+
+template <>
+inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
+  uint32_t res;
+  const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+  const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+  asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(b), "r"(a));
+  return res;
+}
+
+template <>
+inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
+  uint32_t res;
+  const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+  const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+  asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(b), "r"(a));
+  return res;
+}
+
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+struct MaxOp {
+  __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; }
+};
+
+template <>
+struct MaxOp<float> {
+  // This is slightly faster
+  __device__ inline float operator()(float const& x, float const& y) { return max(x, y); }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+struct SumOp {
+  __device__ inline T operator()(T const& x, T const& y) { return x + y; }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int THREADS>
+struct Allreduce {
+  static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+  template <typename T, typename Operator>
+  static __device__ inline T run(T x, Operator& op) {
+    constexpr int OFFSET = THREADS / 2;
+    x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+    return Allreduce<OFFSET>::run(x, op);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+struct Allreduce<2> {
+  template <typename T, typename Operator>
+  static __device__ inline T run(T x, Operator& op) {
+    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+    return x;
+  }
+};
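+
+// Illustrative trace (for clarity only): Allreduce<4>::run(x, SumOp<float>{}) XOR-shuffles
+// with lane offsets 2 and then 1, so every group of 4 consecutive lanes ends up holding
+// the same 4-lane sum without touching shared memory.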
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
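+// Computes acc += A * B with the K dimension software-pipelined: the smem->register copy
+// of k-slice i+1 is issued before the MMA on slice i, so LDSM traffic overlaps with the
+// MMAs. A_in_regs / B_in_regs skip the copy for an operand already held in registers.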
+template <bool A_in_regs = false, bool B_in_regs = false, typename Tensor0, typename Tensor1,
+          typename Tensor2, typename Tensor3, typename Tensor4,
+          typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+          typename ThrCopyA, typename ThrCopyB>
+inline __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA,
+                            Tensor4 const& tCsB, TiledMma tiled_mma,
+                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
+  CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
+  CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
+  CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
+  Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
+  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
+  Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+  CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
+  if (!A_in_regs) {
+    cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
+  }
+  if (!B_in_regs) {
+    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+  }
+#pragma unroll
+  for (int i = 0; i < size<2>(tCrA); ++i) {
+    if (i < size<2>(tCrA) - 1) {
+      if (!A_in_regs) {
+        cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
+      }
+      if (!B_in_regs) {
+        cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+      }
+    }
+    cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
+          typename TiledMma, typename TiledCopy, typename ThrCopy>
+inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB,
+                                      TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                                      ThrCopy smem_thr_copy_B) {
+  CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
+  CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
+  CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
+  Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+  CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
+  cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+#pragma unroll
+  for (int i = 0; i < size<2>(tCrA); ++i) {
+    if (i < size<2>(tCrA) - 1) {
+      cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+    }
+    cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+template <typename Layout>
+inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
+  static_assert(decltype(size<0>(acc_layout))::value == 4);
+  static_assert(decltype(rank(acc_layout))::value == 3);
+  auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
+  // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
+  // "int_tuple.hpp(74): error: conversion to inaccessible base class"
+  // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+  return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+template <typename MMA_traits, typename Layout>
+inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
+  using X = Underscore;
+  static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
+  static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
+  constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
+  static_assert(mma_shape_K == 8 || mma_shape_K == 16);
+  constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
+  auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
+  // TD [2023-08-13]: Same error as above on Cutlass 3.2
+  // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+  //                    get<0, 1>(l),
+  //                    get<1, 1, 1>(l));
+  return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
+                     get<1>(get<0>(l)),
+                     get<1>(get<1>(get<1>(l))));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename To_type, typename Engine, typename Layout>
+inline __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
+  using From_type = typename Engine::value_type;
+  constexpr int numel = decltype(size(tensor))::value;
+  cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
+  // HACK: this requires tensor to be "contiguous"
+  auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
+  return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout>
+inline __device__ void relu_(Tensor<Engine, Layout>& tensor) {
+  constexpr int numel = decltype(size(tensor))::value;
+  static_assert(numel % 2 == 0);
+  using value_t = typename Engine::value_type;
+  // HACK: this requires tensor to be "contiguous"
+  Tensor tensor_uint32 = recast<uint32_t>(tensor);
+#pragma unroll
+  for (int i = 0; i < size(tensor_uint32); ++i) {
+    tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
+template <typename To_type, typename Engine, typename Layout>
+inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const& tensor) {
+  using From_type = typename Engine::value_type;
+  static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
+  static_assert(std::is_same_v<float, From_type>);
+  constexpr int numel = decltype(size(tensor))::value;
+  static_assert(numel % 2 == 0);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  // HACK: this requires tensor to be "contiguous"
+  Tensor tensor_float2 = recast<float2>(tensor);
+  Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
+#pragma unroll
+  for (int i = 0; i < size(out_uint32); ++i) {
+    out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
+  }
+  Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
+#else
+  Tensor out = flash::convert_type<To_type>(tensor);
+  flash::relu_(out);
+#endif
+  return out;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Blocks until all but N previous cp.async.commit_group operations have committed.
+// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
+// (which is equivalent to commit_group then wait_group 0).
+// Instead we just call cp.async.wait_group 0, which is slightly faster.
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
+template <int N>
+CUTE_HOST_DEVICE void cp_async_wait() {
+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
+#endif
+}
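+
+// Typical usage is a sketch like the following (assuming the copies were issued through a
+// cp.async-backed TiledCopy):
+//   cute::cp_async_fence();      // commit the outstanding cp.async group
+//   flash::cp_async_wait<0>();   // wait for all committed groups to complete
+//   __syncthreads();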
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_MN = true, bool Is_even_K = true, bool Clear_OOB_MN = false, bool Clear_OOB_K = true,
+          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const& S,
+                            Tensor<Engine1, Layout1>& D, Tensor<Engine2, Layout2> const& identity_MN,
+                            Tensor<Engine3, Layout3> const& predicate_K, const int max_MN = 0) {
+  CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
+  // There's no case where !Clear_OOB_K && Clear_OOB_MN
+  static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+#pragma unroll
+  for (int m = 0; m < size<1>(S); ++m) {
+    if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+#pragma unroll
+      for (int k = 0; k < size<2>(S); ++k) {
+        if (Is_even_K || predicate_K(k)) {
+          cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+        } else if (Clear_OOB_K) {
+          cute::clear(D(_, m, k));
+        }
+      }
+    } else if (Clear_OOB_MN) {
+      cute::clear(D(_, m, _));
+    }
+  }
+}
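+
+// Reading the flags above (descriptive note): Is_even_MN / Is_even_K skip the row / column
+// bounds checks when the tile is known to be fully in range, while Clear_OOB_MN / Clear_OOB_K
+// zero-fill the out-of-range destination fragments instead of leaving them uninitialized.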
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K = true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const& S,
+                                      Tensor<Engine1, Layout1>& D, Tensor<Engine2, Layout2> const& identity_MN,
+                                      Tensor<Engine3, Layout3> const& predicate_K,
+                                      const int max_MN = 0, const int min_MN = 0) {
+  CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
+// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+#pragma unroll
+  for (int m = 0; m < size<1>(S); ++m) {
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+    if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+#pragma unroll
+      for (int k = 0; k < size<2>(S); ++k) {
+        if (Is_even_K || predicate_K(k)) {
+          cute::copy(S(_, m, k), D(_, m, k));
+        }
+      }
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K = true, bool Clear_OOB_K = true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const& S,
+                                               Tensor<Engine1, Layout1>& D,
+                                               Tensor<Engine2, Layout2> const& Cos,
+                                               Tensor<Engine2, Layout2> const& Sin,
+                                               Tensor<Engine3, Layout3> const& identity_MN,
+                                               const int max_MN, const int min_MN,
+                                               const int dim, const int rotary_dim) {
+  CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));      // MMA
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));      // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));      // MMA_K
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));    // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));    // MMA_K
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));    // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));    // MMA_K
+  CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));  // MMA_K
+  static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
+  static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
+  Tensor rCos = make_fragment_like(Cos);
+  Tensor rSin = make_fragment_like(Sin);
+  Tensor rS = make_fragment_like(S);
+#pragma unroll
+  for (int m = 0; m < size<1>(S); ++m) {
+    if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+#pragma unroll
+      for (int k = 0; k < size<2>(S); ++k) {
+        if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+          cute::copy(S(_, m, k), rS(_, m, k));
+          if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+            cute::copy(Cos(_, m, k), rCos(_, m, k));
+            cute::copy(Sin(_, m, k), rSin(_, m, k));
+            Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+            Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+            Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+#pragma unroll
+            for (int i = 0; i < size<0>(rS) / 2; ++i) {
+              float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
+              float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
+              S_fp32(2 * i) = real;
+              S_fp32(2 * i + 1) = imag;
+            }
+            // Not sure why, but we need the extra copy for convert_type to work
+            Tensor S_fp32_copy = make_fragment_like(S_fp32);
+            cute::copy(S_fp32, S_fp32_copy);
+            using T = typename Engine0::value_type;
+            Tensor S_og_type = convert_type<T>(S_fp32_copy);
+            cute::copy(S_og_type, rS(_, m, k));
+          }
+          cute::copy(rS(_, m, k), D(_, m, k));
+        } else if (Clear_OOB_K) {
+          cute::clear(D(_, m, k));
+        }
+      }
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K = true, bool Clear_OOB_K = true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const& S,
+                                              Tensor<Engine1, Layout1>& D,
+                                              Tensor<Engine2, Layout2> const& Cos,
+                                              Tensor<Engine2, Layout2> const& Sin,
+                                              Tensor<Engine3, Layout3> const& identity_MN,
+                                              const int max_MN, const int min_MN,
+                                              const int dim, const int rotary_dim) {
+  CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));    // MMA
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));    // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));    // MMA_K
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));  // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));  // MMA_K
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));  // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));  // MMA_K
+  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));  // MMA
+  CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
+  static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
+  Tensor rCos = make_fragment_like(Cos);
+  Tensor rSin = make_fragment_like(Sin);
+  Tensor rS = make_fragment_like(S);
+  Tensor rS_other = make_fragment_like(rS(_, 0, 0));
+#pragma unroll
+  for (int m = 0; m < size<1>(S); ++m) {
+    if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+#pragma unroll
+      for (int k = 0; k < size<2>(S); ++k) {
+        if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+          cute::copy(S(_, m, k), rS(_, m, k));
+          if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+            const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
+            Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
+            cute::copy(gS_other, rS_other);
+            // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
+            Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
+            Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
+            cute::copy(gCos, rCos(_, m, k));
+            cute::copy(gSin, rSin(_, m, k));
+            // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
+            Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+            Tensor S_other_fp32 = convert_type<float>(rS_other);
+            Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+            Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+#pragma unroll
+            for (int i = 0; i < size<0>(rS); ++i) {
+              S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
+            }
+            // Not sure why, but we need the extra copy for convert_type to work
+            Tensor S_fp32_copy = make_fragment_like(S_fp32);
+            cute::copy(S_fp32, S_fp32_copy);
+            using T = typename Engine0::value_type;
+            Tensor S_og_type = convert_type<T>(S_fp32_copy);
+            cute::copy(S_og_type, rS(_, m, k));
+            // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
+          }
+          cute::copy(rS(_, m, k), D(_, m, k));
+        } else if (Clear_OOB_K) {
+          cute::clear(D(_, m, k));
+        }
+      }
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash
diff --git a/operators/contrib/contrib.cc b/operators/cuda/cuda_ops.cc
similarity index 100%
rename from operators/contrib/contrib.cc
rename to operators/cuda/cuda_ops.cc
diff --git a/operators/contrib/cuda/cuda_type.h b/operators/cuda/cuda_type.h
similarity index 100%
rename from operators/contrib/cuda/cuda_type.h
rename to operators/cuda/cuda_type.h
diff --git a/operators/contrib/cuda/device_prop.cuh b/operators/cuda/device_prop.cuh
similarity index 100%
rename from operators/contrib/cuda/device_prop.cuh
rename to operators/cuda/device_prop.cuh
diff --git a/operators/contrib/cuda/fast_gelu.h b/operators/cuda/fast_gelu.h
similarity index 100%
rename from operators/contrib/cuda/fast_gelu.h
rename to operators/cuda/fast_gelu.h
diff --git a/operators/contrib/cuda/fast_gelu_impl.cu b/operators/cuda/fast_gelu_impl.cu
similarity index 100%
rename from operators/contrib/cuda/fast_gelu_impl.cu
rename to operators/cuda/fast_gelu_impl.cu
diff --git a/operators/contrib/cuda/fast_gelu_impl.cuh b/operators/cuda/fast_gelu_impl.cuh
similarity index 100%
rename from operators/contrib/cuda/fast_gelu_impl.cuh
rename to operators/cuda/fast_gelu_impl.cuh
diff --git a/operators/contrib/cuda/scatter_nd_of_shape.cu b/operators/cuda/scatter_nd_of_shape.cu
similarity index 100%
rename from operators/contrib/cuda/scatter_nd_of_shape.cu
rename to operators/cuda/scatter_nd_of_shape.cu
diff --git a/operators/contrib/cuda/scatter_nd_of_shape.h b/operators/cuda/scatter_nd_of_shape.h
similarity index 100%
rename from operators/contrib/cuda/scatter_nd_of_shape.h
rename to operators/cuda/scatter_nd_of_shape.h
diff --git a/operators/contrib/cuda/utils.cuh b/operators/cuda/utils.cuh
similarity index 100%
rename from operators/contrib/cuda/utils.cuh
rename to operators/cuda/utils.cuh
diff --git a/test/static_test/test_cuda_eager.cc b/test/static_test/test_cuda_eager.cc
index 65de140de..3faf1e67f 100644
--- a/test/static_test/test_cuda_eager.cc
+++ b/test/static_test/test_cuda_eager.cc
@@ -9,7 +9,7 @@
 
 #ifdef USE_CUDA
 #include "math/cuda/negpos_def.h"
-#include "contrib/cuda/fast_gelu.h"
+#include "cuda/fast_gelu.h"
 #include <cuda.h>
 #include <cuda_runtime.h>