diff --git a/impl/torch/functions/functions_ext.cpp b/impl/torch/functions/functions_ext.cpp
index bb246f149..3a4736fe9 100644
--- a/impl/torch/functions/functions_ext.cpp
+++ b/impl/torch/functions/functions_ext.cpp
@@ -120,7 +120,9 @@ diopiError_t diopiMultiHeadAttention(diopiContextHandle_t ctx, diopiTensorHandle
     auto headSize = atQ.sizes()[3];
     TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");
 
-    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);
+    c10::optional<at::Tensor> nullOpt;
+    std::vector<at::Tensor> result =
+        DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, nullOpt, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);
 
     // PERF: these copy can be eliminated by modifying the flash_attn api
     impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
@@ -147,10 +149,28 @@ diopiError_t diopiMultiHeadAttentionBackward(diopiContextHandle_t ctx, diopiCons
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
-    c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullStateOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullSlopesOpt;
 
-    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
-        mha_bwd, atGradOut, atQ, atK, atV, atOut, atLogsumexp, optGradQ, optGradK, optGradV, dropout_p, scale, is_causal, -1, -1, atGen, nullOpt);
+    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_bwd,
+                                                          atGradOut,
+                                                          atQ,
+                                                          atK,
+                                                          atV,
+                                                          atOut,
+                                                          atLogsumexp,
+                                                          optGradQ,
+                                                          optGradK,
+                                                          optGradV,
+                                                          nullSlopesOpt,
+                                                          dropout_p,
+                                                          scale,
+                                                          is_causal,
+                                                          -1,
+                                                          -1,
+                                                          false,
+                                                          atGen,
+                                                          nullStateOpt);
 
     return diopiSuccess;
 }
@@ -172,8 +192,27 @@ diopiError_t diopiMultiHeadAttentionVarLen(diopiContextHandle_t ctx, diopiTensor
     auto headSize = atQ.sizes()[3];
     TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");
 
-    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
-        mha_varlen_fwd, atQ, atK, atV, optOut, atCumSeqQ, atCumSeqK, max_q, max_k, dropout_p, scale, false, is_causal, -1, -1, return_debug_mask, atGen);
+    c10::optional<at::Tensor> nullSeqOpt;
+    c10::optional<at::Tensor> nullSlopesOpt;
+    std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_fwd,
+                                                          atQ,
+                                                          atK,
+                                                          atV,
+                                                          optOut,
+                                                          atCumSeqQ,
+                                                          atCumSeqK,
+                                                          nullSeqOpt,
+                                                          nullSlopesOpt,
+                                                          max_q,
+                                                          max_k,
+                                                          dropout_p,
+                                                          scale,
+                                                          false,
+                                                          is_causal,
+                                                          -1,
+                                                          -1,
+                                                          return_debug_mask,
+                                                          atGen);
 
     // PERF: these copy can be eliminated by modifying the flash_attn api
     impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
@@ -203,7 +242,8 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
-    c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullStateOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
+    c10::optional<at::Tensor> nullSlopesOpt;
 
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_bwd,
                                                           atGradOut,
@@ -217,6 +257,7 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
                                                           optGradV,
                                                           atCumSeqQ,
                                                           atCumSeqK,
+                                                          nullSlopesOpt,
                                                           max_q,
                                                           max_k,
                                                           dropout_p,
@@ -225,8 +266,9 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
                                                           is_causal,
                                                           -1,
                                                           -1,
+                                                          false,
                                                           atGen,
-                                                          nullOpt);
+                                                          nullStateOpt);
 
     return diopiSuccess;
 }
@@ -244,7 +286,7 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
     auto atQ = impl::aten::buildATen(q);
     auto atK = impl::aten::buildATen(k);
     auto atV = impl::aten::buildATen(v);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);
 
     auto headSize = atQ.sizes()[3];
     DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");
@@ -254,6 +296,7 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
                                                           atK,
                                                           atV,
                                                           optOut,
+                                                          optLinearBiasSlopes,
                                                           p_dropout,
                                                           softmax_scale,
                                                           is_causal,
@@ -286,9 +329,9 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
     auto atQ = impl::aten::buildATen(q);
     auto atK = impl::aten::buildATen(k);
     auto atV = impl::aten::buildATen(v);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
     auto atOut = impl::aten::buildATen(attention_out);
     auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);
     c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
 
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_bwd,
@@ -301,11 +344,13 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
                                                           optGradQ,
                                                           optGradK,
                                                           optGradV,
+                                                          optLinearBiasSlopes,
                                                           p_dropout,
                                                           softmax_scale,
                                                           is_causal,
                                                           window_size_left,
                                                           window_size_right,
+                                                          /*deterministic=*/false,
                                                           atGen,
                                                           nullOpt);
 
@@ -328,11 +373,12 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
     auto atV = impl::aten::buildATen(v);
     auto atCumSeqQ = impl::aten::buildATen(cum_seq_q);
     auto atCumSeqKV = impl::aten::buildATen(cum_seq_kv);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);
 
     auto headSize = atQ.sizes()[3];
     DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");
 
+    c10::optional<at::Tensor> nullOpt;
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_fwd,
                                                           atQ,
                                                           atK,
@@ -340,6 +386,8 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
                                                           optOut,
                                                           atCumSeqQ,
                                                           atCumSeqKV,
+                                                          nullOpt,
+                                                          optLinearBiasSlopes,
                                                           max_seqlen_q,
                                                           max_seqlen_kv,
                                                           p_dropout,
@@ -379,9 +427,9 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
     auto atV = impl::aten::buildATen(v);
     auto atCumSeqQ = impl::aten::buildATen(cum_seq_q);
     auto atCumSeqKV = impl::aten::buildATen(cum_seq_kv);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
     auto atOut = impl::aten::buildATen(attention_out);
     auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);
     c10::optional<at::Tensor> nullOpt;  // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
 
     std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_bwd,
@@ -396,6 +444,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
                                                           optGradV,
                                                           atCumSeqQ,
                                                           atCumSeqKV,
+                                                          optLinearBiasSlopes,
                                                           max_seqlen_q,
                                                           max_seqlen_kv,
                                                           p_dropout,
@@ -404,6 +453,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
                                                           is_causal,
                                                           window_size_left,
                                                           window_size_right,
+                                                          /*deterministic=*/false,
                                                           atGen,
                                                           nullOpt);
 
diff --git a/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt b/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
index bbbfd5086..83108553f 100644
--- a/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
+++ b/impl/torch/functions/functions_ext/flash-attention/CMakeLists.txt
@@ -11,6 +11,19 @@ foreach(CC ${CURRENT_GPU_LIST})
   endif()
 endforeach()
 
+set(FLASH_ATTN_LIB_PATH_2_0 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.0")
+set(FLASH_ATTN_LIB_PATH_2_1 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.1")
+
+if(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 0)
+  set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_0}")
+elseif(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 1)
+  set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_1}")
+else()
+  message(FATAL_ERROR "No valid torch version for setting FLASH_ATTN_LIB_PATH")
+endif()
+
+message(STATUS "FLASH_ATTN_LIB_PATH: ${FLASH_ATTN_LIB_PATH}")
+
 if(ENABLE_TORCH_EXT_FLASH_ATTN)
   # Note: it's really a bad idea to hardcode name and path here.
   find_library(
@@ -20,7 +33,7 @@ if(ENABLE_TORCH_EXT_FLASH_ATTN)
     flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so # fallback name of library
     PATHS ENV FLASH_ATTN_LIB_DIR # alternative path to search
-          /mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention # fallback path
+          ${FLASH_ATTN_LIB_PATH} # fallback path
   )
 endif()
 
diff --git a/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h b/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
index cb5a4a312..83d318601 100644
--- a/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
+++ b/impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
@@ -1,4 +1,4 @@
-// Declare functions implemented in flash-attention library (v2.3.x).
+// Declare functions implemented in flash-attention library (v2.4.3).
 //
 // WARNING: The flash-attention library exports these functions in global namespace. Bad practice. Nothing we can do about it.
 
@@ -39,52 +39,57 @@ enum {
 }  // namespace mha_fwd_ret_idx
 }  // namespace impl::cuda::fa
 
-std::vector<at::Tensor> __attribute__((weak)) mha_fwd(at::Tensor& q,                    // batch_size x seqlen_q x num_heads x head_size
-                                                      const at::Tensor& k,              // batch_size x seqlen_k x num_heads_k x head_size
-                                                      const at::Tensor& v,              // batch_size x seqlen_k x num_heads_k x head_size
-                                                      c10::optional<at::Tensor>& out_,  // batch_size x seqlen_q x num_heads x head_size
-                                                      const float p_dropout, const float softmax_scale, bool is_causal, const int window_size_left,
+std::vector<at::Tensor> __attribute__((weak)) mha_fwd(at::Tensor& q,                             // batch_size x seqlen_q x num_heads x head_size
+                                                      const at::Tensor& k,                       // batch_size x seqlen_k x num_heads_k x head_size
+                                                      const at::Tensor& v,                       // batch_size x seqlen_k x num_heads_k x head_size
+                                                      c10::optional<at::Tensor>& out_,           // batch_size x seqlen_q x num_heads x head_size
+                                                      c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or batch_size x num_heads
+                                                      const float p_dropout, const float softmax_scale, bool is_causal, int window_size_left,
                                                       int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
 
 std::vector<at::Tensor> __attribute__((weak))
-mha_varlen_fwd(const at::Tensor& q,              // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor& k,              // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor& v,              // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               c10::optional<at::Tensor>& out_,  // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor& cu_seqlens_q,   // b+1
-               const at::Tensor& cu_seqlens_k,   // b+1
-               const int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal,
-               const int window_size_left, int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
+mha_varlen_fwd(at::Tensor& q,                             // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor& k,                       // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor& v,                       // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor>& out_,           // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor& cu_seqlens_q,            // b+1
+               const at::Tensor& cu_seqlens_k,            // b+1
+               c10::optional<at::Tensor>& seqused_k,      // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or b x num_heads
+               int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, bool is_causal,
+               int window_size_left, int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
 
-std::vector<at::Tensor> __attribute__((weak)) mha_bwd(const at::Tensor& dout,          // batch_size x seqlen_q x num_heads, x head_size_og
-                                                      const at::Tensor& q,              // batch_size x seqlen_q x num_heads x head_size
-                                                      const at::Tensor& k,              // batch_size x seqlen_k x num_heads_k x head_size
-                                                      const at::Tensor& v,              // batch_size x seqlen_k x num_heads_k x head_size
-                                                      const at::Tensor& out,            // batch_size x seqlen_q x num_heads x head_size
-                                                      const at::Tensor& softmax_lse,    // b x h x seqlen_q
-                                                      c10::optional<at::Tensor>& dq_,   // batch_size x seqlen_q x num_heads x head_size
-                                                      c10::optional<at::Tensor>& dk_,   // batch_size x seqlen_k x num_heads_k x head_size
-                                                      c10::optional<at::Tensor>& dv_,   // batch_size x seqlen_k x num_heads_k x head_size
-                                                      const float p_dropout,            // probability to drop
-                                                      const float softmax_scale, const bool is_causal, const int window_size_left, int window_size_right,
-                                                      c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
+std::vector<at::Tensor> __attribute__((weak)) mha_bwd(const at::Tensor& dout,                    // batch_size x seqlen_q x num_heads, x head_size_og
+                                                      const at::Tensor& q,                       // batch_size x seqlen_q x num_heads x head_size
+                                                      const at::Tensor& k,                       // batch_size x seqlen_k x num_heads_k x head_size
+                                                      const at::Tensor& v,                       // batch_size x seqlen_k x num_heads_k x head_size
+                                                      const at::Tensor& out,                     // batch_size x seqlen_q x num_heads x head_size
+                                                      const at::Tensor& softmax_lse,             // b x h x seqlen_q
+                                                      c10::optional<at::Tensor>& dq_,            // batch_size x seqlen_q x num_heads x head_size
+                                                      c10::optional<at::Tensor>& dk_,            // batch_size x seqlen_k x num_heads_k x head_size
+                                                      c10::optional<at::Tensor>& dv_,            // batch_size x seqlen_k x num_heads_k x head_size
+                                                      c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or batch_size x num_heads
+                                                      const float p_dropout,                     // probability to drop
+                                                      const float softmax_scale, const bool is_causal, int window_size_left, int window_size_right,
+                                                      const bool deterministic, c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
 
 std::vector<at::Tensor> __attribute__((weak))
-mha_varlen_bwd(const at::Tensor& dout,           // total_q x num_heads, x head_size
-               const at::Tensor& q,              // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor& k,              // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor& v,              // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor& out,            // total_q x num_heads x head_size
-               const at::Tensor& softmax_lse,    // b x h x s softmax logsumexp
-               c10::optional<at::Tensor>& dq_,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               c10::optional<at::Tensor>& dk_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               c10::optional<at::Tensor>& dv_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor& cu_seqlens_q,   // b+1
-               const at::Tensor& cu_seqlens_k,   // b+1
+mha_varlen_bwd(const at::Tensor& dout,                    // total_q x num_heads, x head_size
+               const at::Tensor& q,                       // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor& k,                       // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor& v,                       // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor& out,                     // total_q x num_heads x head_size
+               const at::Tensor& softmax_lse,             // b x h x s softmax logsumexp
+               c10::optional<at::Tensor>& dq_,            // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor>& dk_,            // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor>& dv_,            // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor& cu_seqlens_q,            // b+1
+               const at::Tensor& cu_seqlens_k,            // b+1
+               c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or b x num_heads
                const int max_seqlen_q, const int max_seqlen_k,  // max sequence length to choose the kernel
                const float p_dropout,                           // probability to drop
-               const float softmax_scale, const bool zero_tensors, const bool is_causal, const int window_size_left, int window_size_right,
+               const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right, const bool deterministic,
               c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
 
 std::vector<at::Tensor> __attribute__((weak))
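
For context, a minimal caller sketch of the v2.4.3 mha_fwd declared in the header above, with the new alibi_slopes_ argument left empty, mirroring how the DIOPI call sites above pass nullOpt / nullSlopesOpt. The wrapper name, the include path, and the choice to let flash-attention allocate the output are illustrative assumptions, not part of this patch.

// Hypothetical usage sketch (not part of the patch): calls mha_fwd as declared in
// flash_attn/flash_api.h above, with an empty alibi_slopes_ optional (no ALiBi bias),
// matching the nullOpt/nullSlopesOpt workaround used in functions_ext.cpp.
#include <vector>

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

#include "flash_attn/flash_api.h"  // assumed include path for the header patched above

std::vector<at::Tensor> mhaFwdWithoutAlibi(at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
                                           float p_dropout, float softmax_scale, bool is_causal) {
    c10::optional<at::Tensor> out;          // empty: let the library allocate the output tensor
    c10::optional<at::Tensor> alibiSlopes;  // empty: no ALiBi slopes, i.e. pre-2.4 behaviour
    c10::optional<at::Generator> gen;       // empty: use the default generator
    return mha_fwd(q, k, v, out, alibiSlopes, p_dropout, softmax_scale, is_causal,
                   /*window_size_left=*/-1, /*window_size_right=*/-1,
                   /*return_softmax=*/false, gen);
}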