Introduce flash attention and cutlass library (microsoft#708)
* refactor cuda ops, remove contrib folder

* introduce flash attention and cutlass

* resolve comments

---------

Co-authored-by: Lei Cao <[email protected]@onnxruntime-a10.bxgbzpva45kedp3rhbsbit4phb.jx.internal.cloudapp.net>
jslhcl and Lei Cao authored May 6, 2024
1 parent dfdf52e commit e645cda
Showing 52 changed files with 4,363 additions and 0 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -44,6 +44,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
include(CheckCXXCompilerFlag)
include(CheckLanguage)
include(CMakeDependentOption)

set(_ORTX_STANDALONE_PROJECT OFF)
if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
@@ -299,6 +300,7 @@ endmacro()

if(OCOS_USE_CUDA)
include(ext_cuda)
include(cutlass)
endif()

#######################################################################################################################
@@ -581,6 +583,10 @@ target_include_directories(ocos_operators PUBLIC
${PROJECT_SOURCE_DIR}/base
${PROJECT_SOURCE_DIR}/operators)

if (OCOS_USE_CUDA)
target_include_directories(ocos_operators PUBLIC ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
endif()

set(ocos_libraries)
set(OCOS_COMPILE_DEFINITIONS)

18 changes: 18 additions & 0 deletions cmake/ext_cuda.cmake
@@ -6,6 +6,13 @@ enable_language(CUDA)

set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
set(CMAKE_CUDA_STANDARD 17)
cmake_dependent_option(OCOS_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32" OFF)
option(OCOS_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message(STATUS "Turn off flash attention and memory efficient attention since CUDA compiler version < 11.6")
set(OCOS_USE_FLASH_ATTENTION OFF)
set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
@@ -22,3 +29,14 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=unsigned_co
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no_effect\"")

add_compile_definitions(USE_CUDA)

set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turned off to reduce build time; re-enable when these two libraries are actually in use
set(OCOS_USE_FLASH_ATTENTION OFF)
if (OCOS_USE_FLASH_ATTENTION)
message(STATUS "Enable flash attention")
add_compile_definitions(OCOS_USE_FLASH_ATTENTION)
endif()
if (OCOS_USE_MEMORY_EFFICIENT_ATTENTION)
message(STATUS "Enable memory efficient attention")
add_compile_definitions(OCOS_USE_MEMORY_EFFICIENT_ATTENTION)
endif()
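
A minimal sketch (not part of the commit) of how these compile definitions are consumed downstream: kernel sources test the macros with the preprocessor, so switching an option off removes the corresponding code path at compile time, mirroring the #if OCOS_USE_MEMORY_EFFICIENT_ATTENTION guards in the attention sources below. The helper names here are illustrative only.

#include <cstdio>

constexpr bool HasFlashAttention() {
#if OCOS_USE_FLASH_ATTENTION
  return true;
#else
  return false;
#endif
}

constexpr bool HasMemoryEfficientAttention() {
#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION
  return true;
#else
  return false;
#endif
}

int main() {
  std::printf("flash attention: %s, memory efficient attention: %s\n",
              HasFlashAttention() ? "on" : "off",
              HasMemoryEfficientAttention() ? "on" : "off");
  return 0;
}
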
10 changes: 10 additions & 0 deletions cmake/externals/cutlass.cmake
@@ -0,0 +1,10 @@
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG v3.1.0
)

FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
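
A quick smoke test, offered as a sketch rather than part of the change: with the FetchContent checkout above and ${cutlass_SOURCE_DIR}/include added to the include path in CMakeLists.txt, the CUTLASS version header should resolve without linking anything, since CUTLASS is header-only. The version macros are assumed to be those provided by cutlass/version.h in the pinned v3.1.0 tag.

#include <cstdio>
#include "cutlass/version.h"

int main() {
  // With the tag pinned to v3.1.0 this is expected to print "CUTLASS 3.1.0".
  std::printf("CUTLASS %d.%d.%d\n", CUTLASS_MAJOR, CUTLASS_MINOR, CUTLASS_PATCH);
  return 0;
}
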
276 changes: 276 additions & 0 deletions operators/cuda/attention_lib/cutlass_fmha/fmha_launch_template.h
@@ -0,0 +1,276 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION

#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif

#include "memory_efficient_attention.h"
#include "41_fused_multi_head_attention/kernel_forward.h"

namespace ort_extensions {
namespace cuda {

template <typename AttentionKernel, int kQueriesPerBlock>
struct RightPaddingBatchHook {
using scalar_t = typename AttentionKernel::scalar_t;
using accum_t = typename AttentionKernel::accum_t;
using lse_scalar_t = typename AttentionKernel::lse_scalar_t;
using output_t = typename AttentionKernel::output_t;
using output_accum_t = typename AttentionKernel::output_accum_t;

static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout;
static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias;
static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock;
static constexpr bool kIsAligned = AttentionKernel::kIsAligned;
static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration;
static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE; // block size of backward
static constexpr bool kPreloadV = AttentionKernel::kPreloadV;
static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF;
static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer;

template <typename Params>
static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) {
auto batch_id = blockIdx.z;
auto head_id = blockIdx.y;
auto query_start = blockIdx.x * kQueriesPerBlock;

auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE;

// Advance to current batch - in case of different sequence lengths
if (p.seqlen_k_ptr) {
p.num_keys = p.seqlen_k_ptr[batch_id];
}

if (query_start >= p.num_queries) {
return false;
}

// Advance to the current batch / head / query_start
p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH;
p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH;
p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH;
p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value;

if (kSupportsBias && p.attn_bias_ptr != nullptr) {
p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH);
}
if (p.output_accum_ptr != nullptr) {
p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) +
int64_t(query_start) * (p.head_dim_value * p.num_heads) +
head_id * p.head_dim_value;
} else {
// Accumulate directly in the destination buffer (eg for f32)
p.output_accum_ptr = (accum_t*)(p.output_ptr);
}

if (p.logsumexp_ptr != nullptr) {
// lse[batch_id, head_id, query_start]
p.logsumexp_ptr +=
batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start;
}

// Custom masking
if (p.causal_diagonal_ptr) {
p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id];
}
if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
p.causal_diagonal_offset += p.num_keys - p.num_queries;
}
if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft ||
p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
// the bottom row of the current block is query_start + kQueriesPerBlock
// the last active key is then query_start + causal_diagonal_offset +
// kQueriesPerBlock so num_keys is the min between actual num_keys and
// this to avoid extra computations
p.num_keys = cutlass::fast_min(
int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock),
p.num_keys);
}

p.num_queries -= query_start;
p.num_batches = 0; // no longer used after

// If num_queries == 1 and there is only one key head, we're wasting
// 15/16th of tensor core compute. In that case:
// - we only launch kernels for head_id % kQueriesPerBlock == 0
// - we iterate over heads instead of queries (strideM = strideH)
if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) {
if (head_id % kQueriesPerBlock != 0)
return false;
p.q_strideM = p.q_strideH;
p.num_queries = p.num_heads;
p.num_heads = 1; // unused but here for intent
// remove causal since n_query = 1
// otherwise, offset would change with head!
p.custom_mask_type = AttentionKernel::NoCustomMask;
p.o_strideM = p.head_dim_value;
}

// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
p.query_ptr = warp_uniform(p.query_ptr);
p.key_ptr = warp_uniform(p.key_ptr);
p.value_ptr = warp_uniform(p.value_ptr);
if (kSupportsBias) {
p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr);
}
p.output_ptr = warp_uniform(p.output_ptr);
p.output_accum_ptr = warp_uniform(p.output_accum_ptr);
p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr);
p.num_queries = warp_uniform(p.num_queries);
p.num_keys = warp_uniform(p.num_keys);
p.num_heads = warp_uniform(p.num_heads);
p.head_dim = warp_uniform(p.head_dim);
p.head_dim_value = warp_uniform(p.head_dim_value);
p.o_strideM = warp_uniform(p.o_strideM);
p.custom_mask_type = warp_uniform(p.custom_mask_type);
return true;
}
};

template <typename AK, int kQueriesPerBlock>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched_impl_right_padding(typename AK::Params p) {
if (!RightPaddingBatchHook<AK, kQueriesPerBlock>::AdvanceToBlockForGQA(p)) {
return;
}
AK::attention_kernel(p);
}

template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, bool single_value_iteration>
void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, single_value_iteration>;
typename Attention::Params p;
{ // set parameters
p.query_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.query));
p.key_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.key));
p.value_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.value));
p.attn_bias_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.attn_bias));
p.seqstart_q_ptr = params.seqstart_q_ptr;
p.seqstart_k_ptr = params.seqstart_k_ptr;
p.seqlen_k_ptr = params.seqlen_k_ptr;

p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward
p.output_ptr = reinterpret_cast<T*>(params.output);
if (Attention::kNeedsOutputAccumulatorBuffer) {
using Acc = typename Attention::accum_t;
// workspace size: batch_size * sequence_length * num_heads * v_head_size * sizeof(float)
// TODO: ORT_ENFORCE(params.workspace != nullptr, "Need output accumulator buffer but no workspace provided");
p.output_accum_ptr = reinterpret_cast<Acc*>(params.workspace);
} else {
p.output_accum_ptr = nullptr;
}
p.num_heads = params.num_heads;
p.num_batches = params.batch_size;
p.head_dim = params.qk_head_size;
p.head_dim_value = params.v_head_size;

p.scale = params.scale;

// When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel
p.num_queries = params.sequence_length;
p.num_keys = params.kv_sequence_length;

if (params.causal) {
p.custom_mask_type = Attention::CausalFromBottomRight;
}

// We use max_sequence_length to calculate KV stride
if (params.is_kv_bsnh) {
// Input Q, K, V format is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.qk_head_size;
p.v_strideH = params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;

p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.num_heads * params.qk_head_size;
p.v_strideM = params.num_heads * params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;

p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.max_sequence_length;
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.max_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
} else {
// Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.max_sequence_length * params.qk_head_size;
p.v_strideH = params.max_sequence_length * params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;

p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.qk_head_size;
p.v_strideM = params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;

p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length;
p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length;
p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
}
}

auto kernel_fn = attention_kernel_batched_impl<Attention>;
if (params.has_custom_right_padding) {
kernel_fn = attention_kernel_batched_impl_right_padding<Attention, queries_per_block>;
}

int smem_bytes = sizeof(typename Attention::SharedStorage);
if (smem_bytes > 0xc000) {
// TODO: ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!");
static bool once = [&]() {
cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
return true;
}();
}

// TODO: ORT_ENFORCE(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, params.stream>>>(p);
}

template <typename T, typename ArchTag, int queries_per_block, int keys_per_block, bool single_value_iteration>
void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, single_value_iteration>;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 6287)
#endif
// Run a more efficient kernel with `isAligned=True` when memory is correctly aligned.
bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 &&
params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
params.v_head_size % AlignedAK::kAlignmentV == 0;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
LaunchCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, single_value_iteration>(params);
}));
}

template <typename T, typename ArchTag>
void DispatchBlockSize(const MemoryEfficientAttentionParams& params) {
if (params.v_head_size <= 64) {
DispatchIsAligned<T, ArchTag, 64, 64, true>(params);
} else if (params.v_head_size <= 128) {
DispatchIsAligned<T, ArchTag, 32, 128, true>(params);
} else {
DispatchIsAligned<T, ArchTag, 32, 128, false>(params);
}
}

} // namespace cuda
} // namespace ort_extensions

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

#endif // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
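
A hedged usage sketch, not part of the commit: how a caller might populate MemoryEfficientAttentionParams for packed BxSxNxH float16 inputs and invoke the SM70 entry point defined in fmha_sm70.cu below. It assumes the struct and the run_memory_efficient_attention_sm70 declaration live in memory_efficient_attention.h under ort_extensions::cuda; the field names are the ones referenced by the launch template above, and error handling is omitted.

#include <cmath>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "memory_efficient_attention.h"

void RunFmhaSketch(cudaStream_t stream,
                   const half* q, const half* k, const half* v,  // device buffers, [B, S, N, H]
                   half* out,                                    // device buffer, [B, S, N, H]
                   int batch, int seq_len, int num_heads, int head_size) {
  ort_extensions::cuda::MemoryEfficientAttentionParams params = {};
  params.sm = 70;                        // compute capability of the target device
  params.is_half = true;
  params.is_kv_bsnh = true;              // Q, K and V all laid out as BxSxNxH
  params.batch_size = batch;
  params.num_heads = num_heads;
  params.sequence_length = seq_len;
  params.kv_sequence_length = seq_len;
  params.max_sequence_length = seq_len;
  params.qk_head_size = head_size;
  params.v_head_size = head_size;
  params.scale = 1.0f / sqrtf(static_cast<float>(head_size));
  params.causal = false;
  params.is_attn_bias_batched = false;
  params.has_custom_right_padding = false;
  params.query = q;
  params.key = k;
  params.value = v;
  params.attn_bias = nullptr;
  params.seqstart_q_ptr = nullptr;
  params.seqstart_k_ptr = nullptr;
  params.seqlen_k_ptr = nullptr;
  params.output = out;
  params.workspace = nullptr;            // required only when kNeedsOutputAccumulatorBuffer is true
                                         // (batch * seq * heads * v_head_size floats, see above)
  params.stream = stream;
  ort_extensions::cuda::run_memory_efficient_attention_sm70(params);
}
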
22 changes: 22 additions & 0 deletions operators/cuda/attention_lib/cutlass_fmha/fmha_sm50.cu
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION

#include "fmha_launch_template.h"

namespace ort_extensions {
namespace cuda {

void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params) {
if (params.is_half) {
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm50>(params);
} else {
DispatchBlockSize<float, cutlass::arch::Sm50>(params);
}
}

} // namespace cuda
} // namespace ort_extensions

#endif // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
22 changes: 22 additions & 0 deletions operators/cuda/attention_lib/cutlass_fmha/fmha_sm70.cu
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION

#include "fmha_launch_template.h"

namespace ort_extensions {
namespace cuda {

void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params) {
if (params.is_half) {
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm70>(params);
} else {
DispatchBlockSize<float, cutlass::arch::Sm70>(params);
}
}

} // namespace cuda
} // namespace ort_extensions

#endif // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
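
The two translation units above differ only in the architecture tag they instantiate. Below is a hedged sketch of the run-time dispatch a caller could layer on top, not part of this commit: it relies only on params.sm (the compute capability already referenced in fmha_launch_template.h) and the SM50/SM70 entry points defined above; additional SM-specific units added elsewhere in this change would slot into the same ladder.

#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION

#include "memory_efficient_attention.h"

namespace ort_extensions {
namespace cuda {

void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params) {
  // params.sm holds the device compute capability, e.g. 75 for Turing.
  if (params.sm >= 70) {
    run_memory_efficient_attention_sm70(params);
  } else {
    run_memory_efficient_attention_sm50(params);
  }
}

}  // namespace cuda
}  // namespace ort_extensions

#endif  // OCOS_USE_MEMORY_EFFICIENT_ATTENTION
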