From b9b1d462041d8f5cac8e6d91381d3fdf24c117a9 Mon Sep 17 00:00:00 2001
From: ustclight-sls
Date: Mon, 19 Aug 2024 16:21:51 +0800
Subject: [PATCH 1/2] upgrade flash_attn to 2.4.3

---
 impl/torch/functions/functions_ext.cpp | 75 ++++++++++++++++---
 .../flash-attention/CMakeLists.txt     | 15 +++-
 .../include/flash_attn/flash_api.h     | 23 +++---
 3 files changed, 91 insertions(+), 22 deletions(-)

diff --git a/impl/torch/functions/functions_ext.cpp b/impl/torch/functions/functions_ext.cpp
index bb246f1497..786b1248a1 100644
--- a/impl/torch/functions/functions_ext.cpp
+++ b/impl/torch/functions/functions_ext.cpp
@@ -120,7 +120,9 @@ diopiError_t diopiMultiHeadAttention(diopiContextHandle_t ctx, diopiTensorHandle
     auto headSize = atQ.sizes()[3];
     TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");
 
-    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);
+    c10::optional<at::Tensor> nullOpt;
+    std::vector<at::Tensor> result =
+        DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, nullOpt, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);
 
     // PERF: these copy can be eliminated by modifying the flash_attn api
     impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
@@ -147,10 +149,28 @@ diopiError_t diopiMultiHeadAttentionBackward(diopiContextHandle_t ctx, diopiCons
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
-    c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullStateOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullSlopesOpt;
 
-    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
-        mha_bwd, atGradOut, atQ, atK, atV, atOut, atLogsumexp, optGradQ, optGradK, optGradV, dropout_p, scale, is_causal, -1, -1, atGen, nullOpt);
+    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_bwd,
+                                                          atGradOut,
+                                                          atQ,
+                                                          atK,
+                                                          atV,
+                                                          atOut,
+                                                          atLogsumexp,
+                                                          optGradQ,
+                                                          optGradK,
+                                                          optGradV,
+                                                          nullSlopesOpt,
+                                                          dropout_p,
+                                                          scale,
+                                                          is_causal,
+                                                          -1,
+                                                          -1,
+                                                          false,
+                                                          atGen,
+                                                          nullStateOpt);
 
     return diopiSuccess;
 }
@@ -172,8 +192,27 @@ diopiError_t diopiMultiHeadAttentionVarLen(diopiContextHandle_t ctx, diopiTensor
     auto headSize = atQ.sizes()[3];
     TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");
 
-    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
-        mha_varlen_fwd, atQ, atK, atV, optOut, atCumSeqQ, atCumSeqK, max_q, max_k, dropout_p, scale, false, is_causal, -1, -1, return_debug_mask, atGen);
+    c10::optional<at::Tensor> nullSeqOpt;
+    c10::optional<at::Tensor> nullSlopesOpt;
+    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_fwd,
+                                                          atQ,
+                                                          atK,
+                                                          atV,
+                                                          optOut,
+                                                          atCumSeqQ,
+                                                          atCumSeqK,
+                                                          nullSeqOpt,
+                                                          nullSlopesOpt,
+                                                          max_q,
+                                                          max_k,
+                                                          dropout_p,
+                                                          scale,
+                                                          false,
+                                                          is_causal,
+                                                          -1,
+                                                          -1,
+                                                          return_debug_mask,
+                                                          atGen);
 
     // PERF: these copy can be eliminated by modifying the flash_attn api
     impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
@@ -203,7 +242,8 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
-    c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullStateOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullSlopesOpt;
 
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_bwd,
                                                           atGradOut,
@@ -217,6 +257,7 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
                                                           optGradV,
                                                           atCumSeqQ,
                                                           atCumSeqK,
+                                                          nullSlopesOpt,
                                                           max_q,
                                                           max_k,
                                                           dropout_p,
@@ -225,8 +266,9 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
                                                           is_causal,
                                                           -1,
                                                           -1,
+                                                          false,
                                                           atGen,
-                                                          nullOpt);
+                                                          nullStateOpt);
 
     return diopiSuccess;
 }
@@ -236,15 +278,15 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
                                  float p_dropout, float softmax_scale, bool is_causal, int32_t window_size_left, int32_t window_size_right) {
     impl::aten::setCurStream(ctx);
 
-    // handle param[out]
+    // handle param[out]mha_bwd
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optOut, attention_out);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(linearBiasSlopes, alibi_slopes);
 
     // handle param[in]
     auto atGen = buildGeneratorForMha(ctx, gen, p_dropout);
     auto atQ = impl::aten::buildATen(q);
     auto atK = impl::aten::buildATen(k);
     auto atV = impl::aten::buildATen(v);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
 
     auto headSize = atQ.sizes()[3];
     DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");
@@ -254,6 +296,7 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
                                                           atK,
                                                           atV,
                                                           optOut,
+                                                          linearBiasSlopes,
                                                           p_dropout,
                                                           softmax_scale,
                                                           is_causal,
@@ -279,6 +322,7 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(linearBiasSlopes, alibi_slopes);
 
     // handle param[in]
     auto atGradOut = impl::aten::buildATen(grad_output);
@@ -286,7 +330,6 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
     auto atQ = impl::aten::buildATen(q);
     auto atK = impl::aten::buildATen(k);
     auto atV = impl::aten::buildATen(v);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
 
     auto atOut = impl::aten::buildATen(attention_out);
     auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);
@@ -301,11 +344,13 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
                                                           optGradQ,
                                                           optGradK,
                                                           optGradV,
+                                                          linearBiasSlopes,
                                                           p_dropout,
                                                           softmax_scale,
                                                           is_causal,
                                                           window_size_left,
                                                           window_size_right,
+                                                          /*deterministic=*/false,
                                                           atGen,
                                                           nullOpt);
 
@@ -320,6 +365,7 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
 
     // handle param[out]
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optOut, attention_out);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(linearBiasSlopes, alibi_slopes);
 
     // handle param[in]
     auto atGen = buildGeneratorForMha(ctx, gen, p_dropout);
@@ -333,6 +379,7 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
     auto headSize = atQ.sizes()[3];
     DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");
 
+    c10::optional<at::Tensor> nullOpt;
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_fwd,
                                                           atQ,
                                                           atK,
@@ -340,6 +387,8 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
                                                           optOut,
                                                           atCumSeqQ,
                                                           atCumSeqKV,
+                                                          nullOpt,
+                                                          linearBiasSlopes,
                                                           max_seqlen_q,
                                                           max_seqlen_kv,
                                                           p_dropout,
@@ -370,6 +419,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(linearBiasSlopes, alibi_slopes);
 
     // handle param[in]
     auto atGradOut = impl::aten::buildATen(grad_output);
@@ -379,7 +429,6 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
     auto atV = impl::aten::buildATen(v);
     auto atCumSeqQ = impl::aten::buildATen(cum_seq_q);
     auto atCumSeqKV = impl::aten::buildATen(cum_seq_kv);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
 
     auto atOut = impl::aten::buildATen(attention_out);
     auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);
@@ -396,6 +445,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
                                                           optGradV,
                                                           atCumSeqQ,
                                                           atCumSeqKV,
+                                                          linearBiasSlopes,
                                                           max_seqlen_q,
                                                           max_seqlen_kv,
                                                           p_dropout,
@@ -404,6 +454,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
                                                           is_causal,
                                                           window_size_left,
                                                           window_size_right,
+                                                          /*deterministic=*/false,
                                                           atGen,
                                                           nullOpt);
 
diff --git a/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt b/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
index bbbfd50864..83108553f4 100644
--- a/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
+++ b/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
@@ -11,6 +11,19 @@ foreach(CC ${CURRENT_GPU_LIST})
     endif()
 endforeach()
 
+set(FLASH_ATTN_LIB_PATH_2_0 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.0")
+set(FLASH_ATTN_LIB_PATH_2_1 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.1")
+
+if(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 0)
+    set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_0}")
+elseif(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 1)
+    set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_1}")
+else()
+    message(FATAL_ERROR "No valid torch version for setting FLASH_ATTN_LIB_PATH")
+endif()
+
+message(STATUS "FLASH_ATTN_LIB_PATH: ${FLASH_ATTN_LIB_PATH}")
+
 if(ENABLE_TORCH_EXT_FLASH_ATTN)
     # Note: it's really a bad idea to hardcode name and path here.
     find_library(
@@ -20,7 +33,7 @@ if(ENABLE_TORCH_EXT_FLASH_ATTN)
         flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so # fallback name of library
         PATHS
         ENV FLASH_ATTN_LIB_DIR # alternative path to search
-        /mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention # fallback path
+        ${FLASH_ATTN_LIB_PATH} # fallback path
     )
 endif()
 
diff --git a/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h b/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
index cb5a4a3127..511b42bdd7 100644
--- a/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
+++ b/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
@@ -1,4 +1,4 @@
-// Declare functions implemented in flash-attention library (v2.3.x).
+// Declare functions implemented in flash-attention library (v2.4.3).
 //
 // WARNING: The flash-attention library exports these functions in global namespace. Bad practice. Nothing we can do about it.
 
@@ -43,18 +43,21 @@ std::vector<at::Tensor> __attribute__((weak)) mha_fwd(at::Tensor& q,
                                                       const at::Tensor& k,  // batch_size x seqlen_k x num_heads_k x head_size
                                                       const at::Tensor& v,  // batch_size x seqlen_k x num_heads_k x head_size
                                                       c10::optional<at::Tensor>& out_,  // batch_size x seqlen_q x num_heads x head_size
-                                                      const float p_dropout, const float softmax_scale, bool is_causal, const int window_size_left,
+                                                      c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or batch_size x num_heads
+                                                      const float p_dropout, const float softmax_scale, bool is_causal, int window_size_left,
                                                       int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
 
 std::vector<at::Tensor> __attribute__((weak))
-mha_varlen_fwd(const at::Tensor& q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+mha_varlen_fwd(at::Tensor& q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                const at::Tensor& k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor& v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                c10::optional<at::Tensor>& out_,  // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor& cu_seqlens_q,  // b+1
                const at::Tensor& cu_seqlens_k,  // b+1
-               const int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal,
-               const int window_size_left, int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
+               c10::optional<at::Tensor>& seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or b x num_heads
+               int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, bool is_causal,
+               int window_size_left, int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
 
 std::vector<at::Tensor> __attribute__((weak)) mha_bwd(const at::Tensor& dout,  // batch_size x seqlen_q x num_heads, x head_size_og
                                                        const at::Tensor& q,  // batch_size x seqlen_q x num_heads x head_size
@@ -65,9 +68,10 @@ std::vector<at::Tensor> __attribute__((weak)) mha_bwd(const at::Tensor& dout,
                                                        c10::optional<at::Tensor>& dq_,  // batch_size x seqlen_q x num_heads x head_size
                                                        c10::optional<at::Tensor>& dk_,  // batch_size x seqlen_k x num_heads_k x head_size
                                                        c10::optional<at::Tensor>& dv_,  // batch_size x seqlen_k x num_heads_k x head_size
+                                                       c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
                                                        const float p_dropout,  // probability to drop
-                                                       const float softmax_scale, const bool is_causal, const int window_size_left, int window_size_right,
-                                                       c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
+                                                       const float softmax_scale, const bool is_causal, int window_size_left, int window_size_right,
+                                                       const bool deterministic, c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
 
 std::vector<at::Tensor> __attribute__((weak))
 mha_varlen_bwd(const at::Tensor& dout,  // total_q x num_heads, x head_size
@@ -81,11 +85,12 @@ mha_varlen_bwd(const at::Tensor& dout,  // total_q x num_heads, x head_s
                c10::optional<at::Tensor>& dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor& cu_seqlens_q,  // b+1
                const at::Tensor& cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
                const int max_seqlen_q, const int max_seqlen_k,  // max sequence length to choose the kernel
                const float p_dropout,  // probability to drop
-               const float softmax_scale, const bool zero_tensors, const bool is_causal, const int window_size_left, int window_size_right,
-               c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
+               const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right,
+               const bool deterministic, c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
 
 std::vector<at::Tensor> __attribute__((weak))
 mha_fwd_kvcache(at::Tensor& q, const at::Tensor& kcache, const at::Tensor& vcache, const at::Tensor& k, const at::Tensor& v, const at::Tensor& seqlens_k,

From 55f15a534d294073d8c79b002556386814328dae Mon Sep 17 00:00:00 2001
From: ustclight-sls
Date: Mon, 19 Aug 2024 16:21:51 +0800
Subject: [PATCH 2/2] upgrade flash_attn to 2.4.3

---
 impl/torch/functions/functions_ext.cpp | 73 ++++++++++++++++---
 .../flash-attention/CMakeLists.txt     | 15 +++-
 .../include/flash_attn/flash_api.h     | 23 +++---
 3 files changed, 90 insertions(+), 21 deletions(-)

diff --git a/impl/torch/functions/functions_ext.cpp b/impl/torch/functions/functions_ext.cpp
index bb246f1497..a102b3ad95 100644
--- a/impl/torch/functions/functions_ext.cpp
+++ b/impl/torch/functions/functions_ext.cpp
@@ -120,7 +120,9 @@ diopiError_t diopiMultiHeadAttention(diopiContextHandle_t ctx, diopiTensorHandle
     auto headSize = atQ.sizes()[3];
     TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");
 
-    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);
+    c10::optional<at::Tensor> nullOpt;
+    std::vector<at::Tensor> result =
+        DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, nullOpt, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);
 
     // PERF: these copy can be eliminated by modifying the flash_attn api
     impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
@@ -147,10 +149,28 @@ diopiError_t diopiMultiHeadAttentionBackward(diopiContextHandle_t ctx, diopiCons
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
-    c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullStateOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullSlopesOpt;
 
-    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
-        mha_bwd, atGradOut, atQ, atK, atV, atOut, atLogsumexp, optGradQ, optGradK, optGradV, dropout_p, scale, is_causal, -1, -1, atGen, nullOpt);
+    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_bwd,
+                                                          atGradOut,
+                                                          atQ,
+                                                          atK,
+                                                          atV,
+                                                          atOut,
+                                                          atLogsumexp,
+                                                          optGradQ,
+                                                          optGradK,
+                                                          optGradV,
+                                                          nullSlopesOpt,
+                                                          dropout_p,
+                                                          scale,
+                                                          is_causal,
+                                                          -1,
+                                                          -1,
+                                                          false,
+                                                          atGen,
+                                                          nullStateOpt);
 
     return diopiSuccess;
 }
@@ -172,8 +192,27 @@ diopiError_t diopiMultiHeadAttentionVarLen(diopiContextHandle_t ctx, diopiTensor
     auto headSize = atQ.sizes()[3];
     TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");
 
-    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
-        mha_varlen_fwd, atQ, atK, atV, optOut, atCumSeqQ, atCumSeqK, max_q, max_k, dropout_p, scale, false, is_causal, -1, -1, return_debug_mask, atGen);
+    c10::optional<at::Tensor> nullSeqOpt;
+    c10::optional<at::Tensor> nullSlopesOpt;
+    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_fwd,
+                                                          atQ,
+                                                          atK,
+                                                          atV,
+                                                          optOut,
+                                                          atCumSeqQ,
+                                                          atCumSeqK,
+                                                          nullSeqOpt,
+                                                          nullSlopesOpt,
+                                                          max_q,
+                                                          max_k,
+                                                          dropout_p,
+                                                          scale,
+                                                          false,
+                                                          is_causal,
+                                                          -1,
+                                                          -1,
+                                                          return_debug_mask,
+                                                          atGen);
 
     // PERF: these copy can be eliminated by modifying the flash_attn api
     impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
@@ -203,7 +242,8 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
-    c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullStateOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullSlopesOpt;
 
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_bwd,
                                                           atGradOut,
@@ -217,6 +257,7 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
                                                           optGradV,
                                                           atCumSeqQ,
                                                           atCumSeqK,
+                                                          nullSlopesOpt,
                                                           max_q,
                                                           max_k,
                                                           dropout_p,
@@ -225,8 +266,9 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
                                                           is_causal,
                                                           -1,
                                                           -1,
+                                                          false,
                                                           atGen,
-                                                          nullOpt);
+                                                          nullStateOpt);
 
     return diopiSuccess;
 }
@@ -244,7 +286,7 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
     auto atGen = buildGeneratorForMha(ctx, gen, p_dropout);
     auto atQ = impl::aten::buildATen(q);
     auto atK = impl::aten::buildATen(k);
     auto atV = impl::aten::buildATen(v);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);
 
     auto headSize = atQ.sizes()[3];
     DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");
@@ -254,6 +296,7 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
                                                           atK,
                                                           atV,
                                                           optOut,
+                                                          optLinearBiasSlopes,
                                                           p_dropout,
                                                           softmax_scale,
                                                           is_causal,
@@ -286,9 +329,9 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
     auto atQ = impl::aten::buildATen(q);
     auto atK = impl::aten::buildATen(k);
     auto atV = impl::aten::buildATen(v);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
 
     auto atOut = impl::aten::buildATen(attention_out);
     auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);
     c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
 
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_bwd,
@@ -301,11 +344,13 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
                                                           optGradQ,
                                                           optGradK,
                                                           optGradV,
+                                                          optLinearBiasSlopes,
                                                           p_dropout,
                                                           softmax_scale,
                                                           is_causal,
                                                           window_size_left,
                                                           window_size_right,
+                                                          /*deterministic=*/false,
                                                           atGen,
                                                           nullOpt);
 
@@ -329,10 +374,12 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
     auto atCumSeqQ = impl::aten::buildATen(cum_seq_q);
     auto atCumSeqKV = impl::aten::buildATen(cum_seq_kv);
     DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);
 
     auto headSize = atQ.sizes()[3];
     DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");
 
+    c10::optional<at::Tensor> nullOpt;
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_fwd,
                                                           atQ,
                                                           atK,
@@ -340,6 +387,8 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
                                                           optOut,
                                                           atCumSeqQ,
                                                           atCumSeqKV,
+                                                          nullOpt,
+                                                          optLinearBiasSlopes,
                                                           max_seqlen_q,
                                                           max_seqlen_kv,
                                                           p_dropout,
@@ -379,9 +428,9 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
     auto atV = impl::aten::buildATen(v);
     auto atCumSeqQ = impl::aten::buildATen(cum_seq_q);
     auto atCumSeqKV = impl::aten::buildATen(cum_seq_kv);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
 
     auto atOut = impl::aten::buildATen(attention_out);
     auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);
     c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
 
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_bwd,
@@ -396,6 +445,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
                                                           optGradV,
                                                           atCumSeqQ,
                                                           atCumSeqKV,
+                                                          optLinearBiasSlopes,
                                                           max_seqlen_q,
                                                           max_seqlen_kv,
                                                           p_dropout,
@@ -404,6 +454,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
                                                           is_causal,
                                                           window_size_left,
                                                           window_size_right,
+                                                          /*deterministic=*/false,
                                                           atGen,
                                                           nullOpt);
 
diff --git a/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt b/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
index bbbfd50864..83108553f4 100644
--- a/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
+++ b/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
@@ -11,6 +11,19 @@ foreach(CC ${CURRENT_GPU_LIST})
     endif()
 endforeach()
 
+set(FLASH_ATTN_LIB_PATH_2_0 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.0")
+set(FLASH_ATTN_LIB_PATH_2_1 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.1")
+
+if(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 0)
+    set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_0}")
+elseif(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 1)
+    set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_1}")
+else()
+    message(FATAL_ERROR "No valid torch version for setting FLASH_ATTN_LIB_PATH")
+endif()
+
+message(STATUS "FLASH_ATTN_LIB_PATH: ${FLASH_ATTN_LIB_PATH}")
+
 if(ENABLE_TORCH_EXT_FLASH_ATTN)
     # Note: it's really a bad idea to hardcode name and path here.
     find_library(
@@ -20,7 +33,7 @@ if(ENABLE_TORCH_EXT_FLASH_ATTN)
         flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so # fallback name of library
         PATHS
         ENV FLASH_ATTN_LIB_DIR # alternative path to search
-        /mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention # fallback path
+        ${FLASH_ATTN_LIB_PATH} # fallback path
     )
 endif()
 
diff --git a/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h b/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
index cb5a4a3127..511b42bdd7 100644
--- a/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
+++ b/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
@@ -1,4 +1,4 @@
-// Declare functions implemented in flash-attention library (v2.3.x).
+// Declare functions implemented in flash-attention library (v2.4.3).
 //
 // WARNING: The flash-attention library exports these functions in global namespace. Bad practice. Nothing we can do about it.
 
@@ -43,18 +43,21 @@ std::vector<at::Tensor> __attribute__((weak)) mha_fwd(at::Tensor& q,
                                                       const at::Tensor& k,  // batch_size x seqlen_k x num_heads_k x head_size
                                                       const at::Tensor& v,  // batch_size x seqlen_k x num_heads_k x head_size
                                                       c10::optional<at::Tensor>& out_,  // batch_size x seqlen_q x num_heads x head_size
-                                                      const float p_dropout, const float softmax_scale, bool is_causal, const int window_size_left,
+                                                      c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or batch_size x num_heads
+                                                      const float p_dropout, const float softmax_scale, bool is_causal, int window_size_left,
                                                       int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
 
 std::vector<at::Tensor> __attribute__((weak))
-mha_varlen_fwd(const at::Tensor& q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+mha_varlen_fwd(at::Tensor& q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor& k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor& v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor>& out_,  // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor& cu_seqlens_q,  // b+1
               const at::Tensor& cu_seqlens_k,  // b+1
-              const int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal,
-              const int window_size_left, int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
+              c10::optional<at::Tensor>& seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
+              c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or b x num_heads
+              int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, bool is_causal,
+              int window_size_left, int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
 
 std::vector<at::Tensor> __attribute__((weak)) mha_bwd(const at::Tensor& dout,  // batch_size x seqlen_q x num_heads, x head_size_og
                                                        const at::Tensor& q,  // batch_size x seqlen_q x num_heads x head_size
@@ -65,9 +68,10 @@ std::vector<at::Tensor> __attribute__((weak)) mha_bwd(const at::Tensor& dout,
                                                        c10::optional<at::Tensor>& dq_,  // batch_size x seqlen_q x num_heads x head_size
                                                        c10::optional<at::Tensor>& dk_,  // batch_size x seqlen_k x num_heads_k x head_size
                                                        c10::optional<at::Tensor>& dv_,  // batch_size x seqlen_k x num_heads_k x head_size
+                                                       c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
                                                        const float p_dropout,  // probability to drop
-                                                       const float softmax_scale, const bool is_causal, const int window_size_left, int window_size_right,
-                                                       c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
+                                                       const float softmax_scale, const bool is_causal, int window_size_left, int window_size_right,
+                                                       const bool deterministic, c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
 
 std::vector<at::Tensor> __attribute__((weak))
 mha_varlen_bwd(const at::Tensor& dout,  // total_q x num_heads, x head_size
@@ -81,11 +85,12 @@ mha_varlen_bwd(const at::Tensor& dout,  // total_q x num_heads, x head_s
               c10::optional<at::Tensor>& dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor& cu_seqlens_q,  // b+1
               const at::Tensor& cu_seqlens_k,  // b+1
+              c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
               const int max_seqlen_q, const int max_seqlen_k,  // max sequence length to choose the kernel
               const float p_dropout,  // probability to drop
-              const float softmax_scale, const bool zero_tensors, const bool is_causal, const int window_size_left, int window_size_right,
-              c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
+              const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right,
+              const bool deterministic, c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
 
 std::vector<at::Tensor> __attribute__((weak))
 mha_fwd_kvcache(at::Tensor& q, const at::Tensor& kcache, const at::Tensor& vcache, const at::Tensor& k, const at::Tensor& v, const at::Tensor& seqlens_k,