From e645cdab8d423b63f9a8641f1b1b29a00d9b0a67 Mon Sep 17 00:00:00 2001 From: cao lei Date: Sun, 5 May 2024 22:52:28 -0700 Subject: [PATCH] Introduce flash attention and cutlass library (#708) * refactor cuda ops, remove contrib folder * introduce flash attention and cutlass * resolve comments --------- Co-authored-by: Lei Cao --- CMakeLists.txt | 6 + cmake/ext_cuda.cmake | 18 + cmake/externals/cutlass.cmake | 10 + .../cutlass_fmha/fmha_launch_template.h | 276 ++++ .../attention_lib/cutlass_fmha/fmha_sm50.cu | 22 + .../attention_lib/cutlass_fmha/fmha_sm70.cu | 22 + .../attention_lib/cutlass_fmha/fmha_sm75.cu | 22 + .../attention_lib/cutlass_fmha/fmha_sm80.cu | 22 + .../memory_efficient_attention.cu | 29 + .../cutlass_fmha/memory_efficient_attention.h | 59 + .../flash_attention/block_info.h | 44 + .../attention_lib/flash_attention/flash.h | 114 ++ .../flash_attention/flash_api.cc | 465 ++++++ .../attention_lib/flash_attention/flash_api.h | 92 ++ .../flash_fwd_hdim128_bf16_sm80.cu | 16 + .../flash_fwd_hdim128_fp16_sm80.cu | 16 + .../flash_fwd_hdim160_bf16_sm80.cu | 16 + .../flash_fwd_hdim160_fp16_sm80.cu | 16 + .../flash_fwd_hdim192_bf16_sm80.cu | 16 + .../flash_fwd_hdim192_fp16_sm80.cu | 16 + .../flash_fwd_hdim224_bf16_sm80.cu | 16 + .../flash_fwd_hdim224_fp16_sm80.cu | 16 + .../flash_fwd_hdim256_bf16_sm80.cu | 16 + .../flash_fwd_hdim256_fp16_sm80.cu | 16 + .../flash_fwd_hdim32_bf16_sm80.cu | 16 + .../flash_fwd_hdim32_fp16_sm80.cu | 16 + .../flash_fwd_hdim64_bf16_sm80.cu | 16 + .../flash_fwd_hdim64_fp16_sm80.cu | 16 + .../flash_fwd_hdim96_bf16_sm80.cu | 16 + .../flash_fwd_hdim96_fp16_sm80.cu | 16 + .../flash_attention/flash_fwd_kernel.h | 1259 +++++++++++++++++ .../flash_fwd_launch_template.h | 294 ++++ .../flash_fwd_split_hdim128_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim128_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim160_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim160_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim192_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim192_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim224_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim224_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim256_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim256_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim32_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim32_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim64_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim64_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim96_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim96_fp16_sm80.cu | 13 + .../flash_attention/kernel_traits.h | 367 +++++ .../attention_lib/flash_attention/softmax.h | 215 +++ .../flash_attention/static_switch.h | 64 + .../attention_lib/flash_attention/utils.h | 499 +++++++ 52 files changed, 4363 insertions(+) create mode 100644 cmake/externals/cutlass.cmake create mode 100644 operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h create mode 100644 operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu create mode 100644 operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu create mode 100644 operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu create mode 100644 operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu create mode 100644 operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu create mode 100644 operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h create mode 100644 operators/cuda/attention_lib/flash_attention/block_info.h create mode 100644 operators/cuda/attention_lib/flash_attention/flash.h create mode 100644 operators/cuda/attention_lib/flash_attention/flash_api.cc create mode 100644 operators/cuda/attention_lib/flash_attention/flash_api.h create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu create mode 100644 operators/cuda/attention_lib/flash_attention/kernel_traits.h create mode 100644 operators/cuda/attention_lib/flash_attention/softmax.h create mode 100644 operators/cuda/attention_lib/flash_attention/static_switch.h create mode 100644 operators/cuda/attention_lib/flash_attention/utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 49843d61f..49b6eb94e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,6 +44,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) include(CheckCXXCompilerFlag) include(CheckLanguage) +include(CMakeDependentOption) set(_ORTX_STANDALONE_PROJECT OFF) if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) @@ -299,6 +300,7 @@ endmacro() if(OCOS_USE_CUDA) include(ext_cuda) + include(cutlass) endif() ####################################################################################################################### @@ -581,6 +583,10 @@ target_include_directories(ocos_operators PUBLIC ${PROJECT_SOURCE_DIR}/base ${PROJECT_SOURCE_DIR}/operators) +if (OCOS_USE_CUDA) + target_include_directories(ocos_operators PUBLIC ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) +endif() + set(ocos_libraries) set(OCOS_COMPILE_DEFINITIONS) diff --git a/cmake/ext_cuda.cmake b/cmake/ext_cuda.cmake index 15e66ff99..aa7d3282c 100644 --- a/cmake/ext_cuda.cmake +++ b/cmake/ext_cuda.cmake @@ -6,6 +6,13 @@ enable_language(CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) set(CMAKE_CUDA_STANDARD 17) +cmake_dependent_option(OCOS_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32" OFF) +option(OCOS_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) +if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) + message(STATUS "Turn off flash attention and memory efficient attention since CUDA compiler version < 11.6") + set(OCOS_USE_FLASH_ATTENTION OFF) + set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) +endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11) @@ -22,3 +29,14 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=unsigned_co set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no_effect\"") add_compile_definitions(USE_CUDA) + +set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use +set(OCOS_USE_FLASH_ATTENTION OFF) +if (OCOS_USE_FLASH_ATTENTION) + message(STATUS "Enable flash attention") + add_compile_definitions(OCOS_USE_FLASH_ATTENTION) +endif() +if (OCOS_USE_MEMORY_EFFICIENT_ATTENTION) + message(STATUS "Enable memory efficient attention") + add_compile_definitions(OCOS_USE_MEMORY_EFFICIENT_ATTENTION) +endif() diff --git a/cmake/externals/cutlass.cmake b/cmake/externals/cutlass.cmake new file mode 100644 index 000000000..e36d42c89 --- /dev/null +++ b/cmake/externals/cutlass.cmake @@ -0,0 +1,10 @@ +FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git + GIT_TAG v3.1.0 +) + +FetchContent_GetProperties(cutlass) +if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) +endif() diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h b/operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h new file mode 100644 index 000000000..7a120ff16 --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "memory_efficient_attention.h" +#include "41_fused_multi_head_attention/kernel_forward.h" + +namespace ort_extensions { +namespace cuda { + +template +struct RightPaddingBatchHook { + using scalar_t = typename AttentionKernel::scalar_t; + using accum_t = typename AttentionKernel::accum_t; + using lse_scalar_t = typename AttentionKernel::lse_scalar_t; + using output_t = typename AttentionKernel::output_t; + using output_accum_t = typename AttentionKernel::output_accum_t; + + static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout; + static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias; + static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock; + static constexpr bool kIsAligned = AttentionKernel::kIsAligned; + static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration; + static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE; // block size of backward + static constexpr bool kPreloadV = AttentionKernel::kPreloadV; + static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF; + static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer; + + template + static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; + + // Advance to current batch - in case of different sequence lengths + if (p.seqlen_k_ptr) { + p.num_keys = p.seqlen_k_ptr[batch_id]; + } + + if (query_start >= p.num_queries) { + return false; + } + + // Advance to the current batch / head / query_start + p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH; + p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH; + p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH; + p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value; + + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH); + } + if (p.output_accum_ptr != nullptr) { + p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) + + int64_t(query_start) * (p.head_dim_value * p.num_heads) + + head_id * p.head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + p.output_accum_ptr = (accum_t*)(p.output_ptr); + } + + if (p.logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + p.logsumexp_ptr += + batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (p.causal_diagonal_ptr) { + p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id]; + } + if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + p.causal_diagonal_offset += p.num_keys - p.num_queries; + } + if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft || + p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + p.num_keys = cutlass::fast_min( + int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock), + p.num_keys); + } + + p.num_queries -= query_start; + p.num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + p.q_strideM = p.q_strideH; + p.num_queries = p.num_heads; + p.num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + p.custom_mask_type = AttentionKernel::NoCustomMask; + p.o_strideM = p.head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + p.query_ptr = warp_uniform(p.query_ptr); + p.key_ptr = warp_uniform(p.key_ptr); + p.value_ptr = warp_uniform(p.value_ptr); + if (kSupportsBias) { + p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr); + } + p.output_ptr = warp_uniform(p.output_ptr); + p.output_accum_ptr = warp_uniform(p.output_accum_ptr); + p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr); + p.num_queries = warp_uniform(p.num_queries); + p.num_keys = warp_uniform(p.num_keys); + p.num_heads = warp_uniform(p.num_heads); + p.head_dim = warp_uniform(p.head_dim); + p.head_dim_value = warp_uniform(p.head_dim_value); + p.o_strideM = warp_uniform(p.o_strideM); + p.custom_mask_type = warp_uniform(p.custom_mask_type); + return true; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl_right_padding(typename AK::Params p) { + if (!RightPaddingBatchHook::AdvanceToBlockForGQA(p)) { + return; + } + AK::attention_kernel(p); +} + +template +void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { + using Attention = AttentionKernel; + typename Attention::Params p; + { // set parameters + p.query_ptr = const_cast(reinterpret_cast(params.query)); + p.key_ptr = const_cast(reinterpret_cast(params.key)); + p.value_ptr = const_cast(reinterpret_cast(params.value)); + p.attn_bias_ptr = const_cast(reinterpret_cast(params.attn_bias)); + p.seqstart_q_ptr = params.seqstart_q_ptr; + p.seqstart_k_ptr = params.seqstart_k_ptr; + p.seqlen_k_ptr = params.seqlen_k_ptr; + + p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward + p.output_ptr = reinterpret_cast(params.output); + if (Attention::kNeedsOutputAccumulatorBuffer) { + using Acc = typename Attention::accum_t; + // workspace size: batch_size * sequence_length * num_heads * v_head_size * sizeof(float) + // TODO: ORT_ENFORCE(params.workspace != nullptr, "Need output accumulator buffer but no workspace provided"); + p.output_accum_ptr = reinterpret_cast(params.workspace); + } else { + p.output_accum_ptr = nullptr; + } + p.num_heads = params.num_heads; + p.num_batches = params.batch_size; + p.head_dim = params.qk_head_size; + p.head_dim_value = params.v_head_size; + + p.scale = params.scale; + + // When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel + p.num_queries = params.sequence_length; + p.num_keys = params.kv_sequence_length; + + if (params.causal) { + p.custom_mask_type = Attention::CausalFromBottomRight; + } + + // We use max_sequence_length to calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } + } + + auto kernel_fn = attention_kernel_batched_impl; + if (params.has_custom_right_padding) { + kernel_fn = attention_kernel_batched_impl_right_padding; + } + + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + // TODO: ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!"); + static bool once = [&]() { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + return true; + }(); + } + + // TODO: ORT_ENFORCE(Attention::check_supported(p)); + kernel_fn<<>>(p); +} + +template +void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { + using AlignedAK = AttentionKernel; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 6287) +#endif + // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned. + bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 && + params.qk_head_size % AlignedAK::kAlignmentK == 0 && + params.v_head_size % AlignedAK::kAlignmentV == 0; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { + LaunchCutlassFmha(params); + })); +} + +template +void DispatchBlockSize(const MemoryEfficientAttentionParams& params) { + if (params.v_head_size <= 64) { + DispatchIsAligned(params); + } else if (params.v_head_size <= 128) { + DispatchIsAligned(params); + } else { + DispatchIsAligned(params); + } +} + +} // namespace cuda +} // namespace ort_extensions + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#endif // OCOS_USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu new file mode 100644 index 000000000..669164ae7 --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace ort_extensions { +namespace cuda { + +void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace ort_extensions + +#endif // OCOS_USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu new file mode 100644 index 000000000..f561b8fa6 --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace ort_extensions { +namespace cuda { + +void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace ort_extensions + +#endif // OCOS_USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu new file mode 100644 index 000000000..e55fe4269 --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm75.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace ort_extensions { +namespace cuda { + +void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace ort_extensions + +#endif // OCOS_USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu new file mode 100644 index 000000000..76d12de99 --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/fmha_sm80.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace ort_extensions { +namespace cuda { + +void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace ort_extensions + +#endif // OCOS_USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu new file mode 100644 index 000000000..c50a5b46d --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.cu @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + +#include "memory_efficient_attention.h" +#include + +namespace ort_extensions { +namespace cuda { + +void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params) { + const int32_t& sm = params.sm; + if (sm >= 80) { + run_memory_efficient_attention_sm80(params); + } else if (sm >= 75) { + run_memory_efficient_attention_sm75(params); + } else if (sm >= 70) { + run_memory_efficient_attention_sm70(params); + } else if (sm >= 50) { + run_memory_efficient_attention_sm50(params); + } else { + assert(false); // shall not reach here. + } +} + +} // namespace cuda +} // namespace ort_extensions + +#endif // OCOS_USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h new file mode 100644 index 000000000..082740ccb --- /dev/null +++ b/operators/cuda/attention_lib/cutlass_fmha/memory_efficient_attention.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION +#include + +namespace ort_extensions { +namespace cuda { + +struct MemoryEfficientAttentionParams { + int32_t sm; + bool is_half; + bool is_kv_bsnh = true; + int32_t batch_size; + int32_t num_heads; + int32_t sequence_length; + int32_t kv_sequence_length; + int32_t max_sequence_length; + int32_t qk_head_size; + int32_t v_head_size; + bool causal; + // The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models. + bool is_attn_bias_batched; + + float scale; + + int32_t* seqstart_q_ptr; + int32_t* seqstart_k_ptr; + int32_t* seqlen_k_ptr; + + const void* query; // [B, S, N, H] + const void* key; // [B, L, N, H], where L is kv_sequence_length + const void* value; // [B, L, N, H_v] + const void* attn_bias; // [N, S, S*] or null + void* output; // [B, S, N, H_v] + void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise + cudaStream_t stream; + + static bool need_workspace(size_t v_head_size, bool is_float) { + return (v_head_size > 128 && !is_float); + } + + bool has_custom_right_padding = false; +}; + +void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params); + +inline bool has_memory_efficient_attention(int32_t sm, bool is_half) { + return sm >= (is_half ? 53 : 50); +} + +void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params); + +} +} +#endif diff --git a/operators/cuda/attention_lib/flash_attention/block_info.h b/operators/cuda/attention_lib/flash_attention/block_info.h new file mode 100644 index 000000000..1ec632658 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/block_info.h @@ -0,0 +1,44 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +namespace flash { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + template + __device__ BlockInfo(const Params& params, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), + sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), + actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), + actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { + } + + template + inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/flash.h b/operators/cuda/attention_lib/flash_attention/flash.h new file mode 100644 index 000000000..603a6e068 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash.h @@ -0,0 +1,114 @@ +#pragma once +#include + +namespace flash { +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void* __restrict__ q_ptr = nullptr; + void* __restrict__ k_ptr = nullptr; + void* __restrict__ v_ptr = nullptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride = 0; + index_t k_batch_stride = 0; + index_t v_batch_stride = 0; + index_t q_row_stride = 0; + index_t k_row_stride = 0; + index_t v_row_stride = 0; + index_t q_head_stride = 0; + index_t k_head_stride = 0; + index_t v_head_stride = 0; + + // The number of heads. + int h = 0; + int h_k = 0; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio = 0; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr = nullptr; + void* __restrict__ oaccum_ptr = nullptr; + + // The stride between rows of O. + index_t o_batch_stride = 0; + index_t o_row_stride = 0; + index_t o_head_stride = 0; + + // The pointer to the P matrix. + void* __restrict__ p_ptr = nullptr; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr = nullptr; + void* __restrict__ softmax_lseaccum_ptr = nullptr; + + // The dimensions. + int b = 0; + int seqlen_q = 0; + int seqlen_k = 0; + int seqlen_knew = 0; + int d = 0; + int seqlen_q_rounded = 0; + int seqlen_k_rounded = 0; + int d_rounded = 0; + int rotary_dim = 0; + + // The scaling factors for the kernel. + float scale_softmax = 0.0; + float scale_softmax_log2 = 0.0; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q = nullptr; + int* __restrict__ cu_seqlens_k = nullptr; + + int* __restrict__ blockmask = nullptr; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr = nullptr; + void* __restrict__ vnew_ptr = nullptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride = 0; + index_t vnew_batch_stride = 0; + index_t knew_row_stride = 0; + index_t vnew_row_stride = 0; + index_t knew_head_stride = 0; + index_t vnew_head_stride = 0; + + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr = nullptr; + void* __restrict__ rotary_sin_ptr = nullptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx = nullptr; + + // Local window size + int window_size_left = -1; + int window_size_right = -1; + + bool is_bf16 = false; + bool is_causal = false; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative = true; + + bool is_rotary_interleaved = false; + + int num_splits = 0; // For split-KV version + + const cudaDeviceProp* dprops = nullptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream); +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +} \ No newline at end of file diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.cc b/operators/cuda/attention_lib/flash_attention/flash_api.cc new file mode 100644 index 000000000..46812b560 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_api.cc @@ -0,0 +1,465 @@ +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_api.h" +#include "flash.h" +#include "static_switch.h" +#include + +namespace flash { + +void set_params_fprop(Flash_fwd_params& params, + // sizes + size_t batch_size, + size_t seqlen_q, + size_t seqlen_k, + size_t seqlen_q_rounded, + size_t seqlen_k_rounded, + size_t num_heads, + size_t num_heads_k, + size_t head_size, + size_t head_size_rounded, + // device pointers + void* q, + void* k, + void* v, + void* out, + void* cu_seqlens_q_d, + void* cu_seqlens_k_d, + void* p_d, + void* softmax_lse_d, + float softmax_scale, + bool is_causal, + bool is_bf16, + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + params.o_ptr = out; + + params.is_bf16 = is_bf16; + + // All stride are in elements, not bytes. + if (kv_bsnh) { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads_k * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } else { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = head_size; + params.v_row_stride = head_size; + params.q_head_stride = head_size; + params.k_head_stride = seqlen_k * head_size; + params.v_head_stride = seqlen_k * head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + } else { + params.q_batch_stride = 0; + params.k_batch_stride = 0; + params.v_batch_stride = 0; + params.o_batch_stride = 0; + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.h_h_k_ratio = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates + // local and causal, meaning when we have local window size + params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; +} + +size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) { + size_t bytes = sizeof(float) * batch_size * num_heads * seqlen; + return bytes; +} + +size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; + return bytes; +} + +size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded; + return bytes; +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); + }); +} + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, + int max_splits) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (seqlen_q + 64 - 1) / 64; + int batch_nheads_mblocks = batch_size * num_heads * num_m_blocks; + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + continue; + } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + +OrtStatusPtr mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16, + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh, + int local_window_size) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + /*cu_seqlens_q*/ nullptr, + /*cu_seqlens_k*/ nullptr, + nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + kv_bsnh, + local_window_size, + is_causal ? 0 : -1); + params.dprops = &dprops; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + run_mha_fwd(params, stream); + return nullptr; +} + +OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, head_size) + void* out, // half (total_q, num_heads, head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + cu_seqlens_q, + cu_seqlens_k, + nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + true, + -1, + is_causal ? 0 : -1); + params.dprops = &dprops; + params.num_splits = 0; + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + run_mha_fwd(params, stream); + return nullptr; +} + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); +} + +// This API is used when past key and value are present... since cached, these are assumed to have sequence length +// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. +OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // In kv-cache case, seqlen_k_max as kv sequence length + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + past_bsnh, + local_window_size, + is_causal ? 0 : -1); + params.dprops = &dprops; + + if (k_new != nullptr && v_new != nullptr) { + params.seqlen_knew = seqlen_k_new; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; + // All stride are in elements, not bytes. + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + } else { + params.seqlen_knew = 0; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + } + + params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; + if (seqlens_k_ != nullptr) { + params.cu_seqlens_k = static_cast(seqlens_k_); + } + + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = (head_size / 16) * 16; + } + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + // Only split kernel supports appending to KV cache + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); + + return nullptr; +} + +} // namespace flash + +#endif // OCOS_USE_FLASH_ATTENTION diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.h b/operators/cuda/attention_lib/flash_attention/flash_api.h new file mode 100644 index 000000000..4ad1b76e1 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_api.h @@ -0,0 +1,92 @@ +#pragma once + +#if OCOS_USE_FLASH_ATTENTION + +#include +#include +#include +#include "onnxruntime_c_api.h" + +namespace flash { + +OrtStatusPtr mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16, + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh = true, + int local_window_size = -1); + +OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, v_head_size) + void* out, // half (total_q, num_heads, v_head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16); + +OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* k, // batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false); + +size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); + +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); + +} // namespace flash + +#endif // OCOS_USE_FLASH_ATTENTION diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..8a9b32ee8 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..e643d97ec --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..bd716155d --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..2b61a318a --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..a08bc39d0 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..d9fab2d60 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..7d6d69378 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..83b77523b --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..954d752fb --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..80045c2f9 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..e27f7907e --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..187db09a6 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..34728f342 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..c62f1342d --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..6a7fb413e --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..68b751b87 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h new file mode 100644 index 000000000..c44a470f6 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h @@ -0,0 +1,1259 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#endif + +#include +#include +#include + +#include +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" + +namespace flash { +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, + Tensor2& acc_o, float softmax_scale_log2) { + if (Is_first) { + flash::template reduce_max(scores, scores_max); + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + flash::reduce_sum(scores, scores_sum); + } else { + cute::Tensor scores_max_prev = make_fragment_like(scores_max); + cute::copy(scores_max, scores_max_prev); + flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale; + } + } + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + cute::Tensor scores_sum_cur = make_fragment_like(scores_sum); + flash::reduce_sum(scores, scores_sum_cur); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_sum); ++mi) { + scores_sum(mi) += scores_sum_cur(mi); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + cute::Layout l = tOrP.layout(); + cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{}); + CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); +#pragma unroll + for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // We exit early and write 0 to gO and gLSE. + // Otherwise we might read OOB elements from gK and gV. + if (n_block_max <= n_block_min) { + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSE(row) = INFINITY; + } + } + return; + } + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + cute::Shape, cute::Int>{}, + make_stride(params.q_row_stride, _1{})); + cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + cute::Shape, cute::Int>{}, + make_stride(params.k_row_stride, _1{})); + cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + cute::Shape, cute::Int>{}, + make_stride(params.v_row_stride, _1{})); + cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + cute::Shape, cute::Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : cute::size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); + + cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + cute::Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + cute::Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + cute::Tensor scores_max = make_tensor(cute::Shape(acc_o)>>{}); + cute::Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // Construct identity layout for sQ and sK + cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + cute::Tensor tQpQ = make_tensor(make_shape(cute::size<2>(tQsQ))); + cute::Tensor tKVpKV = make_tensor(make_shape(cute::size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < cute::size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + + cute::Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { + cute::cp_async_fence(); + } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + // I can't get the stride from idx_row + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + cute::Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + + // Convert acc_o from fp32 to fp16/bf16 + cute::Tensor rO = flash::convert_type(acc_o); + cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + cute::Shape, cute::Int>{}, + make_stride(params.o_row_stride, _1{})); + cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + cute::Shape>{}, cute::Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + cute::Tensor tOrO = make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + cute::Tensor caccO = make_identity_tensor(cute::Shape, cute::Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(cute::size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < cute::size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSE(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tOpO = make_tensor(make_shape(cute::size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyOaccum, + typename Kernel_traits::GmemTiledCopyO>; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSEaccum(row) = Split ? -INFINITY : INFINITY; + } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + } + } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } + } + + // Read Q from gmem to smem, optionally apply rotary embedding. + Tensor tQrQ = make_fragment_like(tQgQ); + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + // if (cute::thread0()) { print(acc_o_rowcol); } + Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + // if (cute::thread0()) { print(lse); } + // if (cute::thread0()) { print(acc_o_rowcol); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum>; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + // __syncthreads(); + // if (cute::thread0()) { print(tOgOaccum); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_splitkv(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void combine_attn_seqk_parallel(const Params& params) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then tranpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { + sLSE[row][col] = lse; + } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // 16 rows, so each time we load we can load 8 rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_max = max(lse_max, lse_accum(l)); + } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_sum += expf(lse_accum(l) - lse_max); + } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } +// Store the scales exp(lse - lse_logsum) in shared memory. +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + sLSE[row][col] = expf(lse_accum(l) - lse_logsum); + } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { + tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; + } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print(tOrO); } + + Tensor rO = flash::convert_type(tOrO); +// Write to gO +#pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + +} // namespace flash + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h new file mode 100644 index 000000000..e2f2505a7 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h @@ -0,0 +1,294 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_kernel.h" + +namespace flash { + +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn_splitkv(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +#else + (void)params; +#endif +} + +template +void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); + }); + }); + }); + }); + }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int kBlockM = 64; // Fixed for all head dimensions + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 32; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 64; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 96; + const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 160; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 192; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 224; + int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 256; + size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; + size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +} // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..f818e584d --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..d7db5839d --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..855863403 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..86057f9a7 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..b6aec378c --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..0ba212ef7 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..1fd816e2c --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..98740bded --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..8982c23ab --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..ab1eec23d --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..2cb3a4d3d --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..e02735b11 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..721fdb9f7 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..37830f2f5 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..e90f2540e --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..394cca7c8 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if OCOS_USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/cuda/attention_lib/flash_attention/kernel_traits.h b/operators/cuda/attention_lib/flash_attention/kernel_traits.h new file mode 100644 index 000000000..48e899c2a --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/kernel_traits.h @@ -0,0 +1,367 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +using namespace cute; + +namespace flash { + +template +struct Flash_kernel_traits { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom>; + using ValLayoutMNK = cute::Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = cute::Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, _1, _1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype(composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomVtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; + + static constexpr int kSmemQCount = cute::size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + cute::Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_8, _1>>, + cute::Layout, // Thread layout, 16 threads per row + cute::Stride<_16, _1>>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, // 2x4x1 or 4x2x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQdO = decltype(composition(Swizzle{}, + cute::Layout>, + cute::Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + + using SmemLayoutAtomKV = decltype(composition(Swizzle{}, + cute::Layout, cute::Int>, + cute::Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomKtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); + using SmemLayoutKtransposed = decltype(tile_to_shape( + SmemLayoutAtomKtransposed{}, + make_shape(Int{}, Int{}))); + // Maybe the KtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + static_assert(kBlockN >= 64); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = 64; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype(composition(Swizzle{}, + cute::Layout, cute::Int>, + cute::Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomPdStransposed = decltype(composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); + using SmemLayoutPdStransposed = decltype(tile_to_shape( + SmemLayoutAtomPdStransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomQdOtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); + using SmemLayoutQdOtransposed = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + + using SmemLayoutAtomdKV = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + static constexpr int kSmemQdOCount = cute::size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ + static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; + static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{}); + static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{}); + static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{}); + // static constexpr int kSmemdPsumCount = kBlockM; + static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); + static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); + static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + // static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + kSmemdSSize + kSmemPSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_8, _1>>, + cute::Layout, // Thread layout, 16 threads per row + cute::Stride<_16, _1>>>; + using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + cute::Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom{}, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_32, _1>>{}, + cute::Layout>{})); // Val layout, 1 val per store +}; + +} // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/softmax.h b/operators/cuda/attention_lib/flash_attention/softmax.h new file mode 100644 index 000000000..9c31336c9 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/softmax.h @@ -0,0 +1,215 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ inline void quad_allreduce_(Tensor& dst, Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ inline void reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ inline void reduce_max(Tensor const& tensor, Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor& sum) { + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +inline __device__ void scale_apply_exp2(Tensor& tensor, Tensor const& max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Apply the exp to all the elements. +template +inline __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +inline __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { +// Without the "make_coord" we get wrong results +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +inline __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + // const int row_idx_offset = row_idx_offset_ + lane_id / 4; + const int row_idx_offset = row_idx_offset_; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +inline __device__ void apply_mask_causal_w_idx( + Tensor& tensor, Tensor const& idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); +#pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +} // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/static_switch.h b/operators/cuda/attention_lib/flash_attention/static_switch.h new file mode 100644 index 000000000..5b7098894 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/static_switch.h @@ -0,0 +1,64 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/operators/cuda/attention_lib/flash_attention/utils.h b/operators/cuda/attention_lib/flash_attention/utils.h new file mode 100644 index 000000000..cd10bd534 --- /dev/null +++ b/operators/cuda/attention_lib/flash_attention/utils.h @@ -0,0 +1,499 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t relu2(const uint32_t x); + +template <> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" + : "=r"(res) + : "r"(x), "r"(zero)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" + : "=r"(res) + : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" + : "=r"(res) + : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +inline __device__ uint32_t convert_relu2(const float2 x); + +template <> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" + : "=r"(res) + : "r"(b), "r"(a)); + return res; +} + +template <> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" + : "=r"(res) + : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ inline float operator()(float const& x, float const& y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ inline T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ inline T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + } +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +template +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void relu_(Tensor& tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); +#pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +inline __device__ auto convert_type_relu(Tensor const& tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); +#pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, const int max_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_w_min_idx(Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_interleaved(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash