Commit

upgrade flash_attn to 2.4.3
ustclight-sls committed Aug 19, 2024
1 parent 65930a5 commit 9ff00e4
Showing 3 changed files with 54 additions and 20 deletions.
36 changes: 26 additions & 10 deletions impl/torch/functions/functions_ext.cpp
@@ -120,7 +120,8 @@ diopiError_t diopiMultiHeadAttention(diopiContextHandle_t ctx, diopiTensorHandle
auto headSize = atQ.sizes()[3];
TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");

std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);
c10::optional<at::Tensor> nullOpt;
std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, nullOpt, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);

// PERF: these copy can be eliminated by modifying the flash_attn api
impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
@@ -147,10 +148,11 @@ diopiError_t diopiMultiHeadAttentionBackward(diopiContextHandle_t ctx, diopiCons
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
c10::optional<at::Tensor> nullOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
c10::optional<at::Tensor> nullStateOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
c10::optional<at::Tensor> nullSlopesOpt;

std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
mha_bwd, atGradOut, atQ, atK, atV, atOut, atLogsumexp, optGradQ, optGradK, optGradV, dropout_p, scale, is_causal, -1, -1, atGen, nullOpt);
mha_bwd, atGradOut, atQ, atK, atV, atOut, atLogsumexp, optGradQ, optGradK, optGradV, nullSlopesOpt, dropout_p, scale, is_causal, -1, -1, false, atGen, nullStateOpt);

return diopiSuccess;
}
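
For context on the nullStateOpt/nullSlopesOpt locals introduced above: flash_attn declares its optional tensor parameters as non-const lvalue references, so the call site cannot simply pass c10::nullopt (that would require binding a temporary c10::optional<at::Tensor> to a non-const reference); a named local optional has to be passed instead. A minimal sketch of the constraint, using a hypothetical takes_optional_ref declaration rather than the real flash_attn entry points:

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Mirrors the flash_attn parameter style: a non-const reference to an optional tensor.
void takes_optional_ref(c10::optional<at::Tensor>& opt);

void caller() {
    // takes_optional_ref(c10::nullopt);  // ill-formed: would bind a temporary optional<at::Tensor> to a non-const reference
    c10::optional<at::Tensor> nullOpt;    // default-constructed, holds no tensor
    takes_optional_ref(nullOpt);          // OK: a named lvalue binds to the non-const reference
}
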
@@ -172,8 +174,10 @@ diopiError_t diopiMultiHeadAttentionVarLen(diopiContextHandle_t ctx, diopiTensor
auto headSize = atQ.sizes()[3];
TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");

c10::optional<at::Tensor> nullSeqOpt;
c10::optional<at::Tensor> nullSlopesOpt;
std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
mha_varlen_fwd, atQ, atK, atV, optOut, atCumSeqQ, atCumSeqK, max_q, max_k, dropout_p, scale, false, is_causal, -1, -1, return_debug_mask, atGen);
mha_varlen_fwd, atQ, atK, atV, optOut, atCumSeqQ, atCumSeqK, nullSeqOpt, nullSlopesOpt, max_q, max_k, dropout_p, scale, false, is_causal, -1, -1, return_debug_mask, atGen);

// PERF: these copy can be eliminated by modifying the flash_attn api
impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
@@ -203,7 +207,8 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
c10::optional<at::Tensor> nullOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
c10::optional<at::Tensor> nullStateOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
c10::optional<at::Tensor> nullSlopesOpt;

std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_bwd,
atGradOut,
@@ -217,6 +222,7 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
optGradV,
atCumSeqQ,
atCumSeqK,
nullSlopesOpt,
max_q,
max_k,
dropout_p,
@@ -225,8 +231,9 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
is_causal,
-1,
-1,
false,
atGen,
nullOpt);
nullStateOpt);

return diopiSuccess;
}
@@ -236,15 +243,15 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
float p_dropout, float softmax_scale, bool is_causal, int32_t window_size_left, int32_t window_size_right) {
impl::aten::setCurStream(ctx);

// handle param[out]
// handle param[out]mha_bwd
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optOut, attention_out);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(linearBiasSlopes, alibi_slopes);

// handle param[in]
auto atGen = buildGeneratorForMha(ctx, gen, p_dropout);
auto atQ = impl::aten::buildATen(q);
auto atK = impl::aten::buildATen(k);
auto atV = impl::aten::buildATen(v);
DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");

auto headSize = atQ.sizes()[3];
DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");
@@ -254,6 +261,7 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
atK,
atV,
optOut,
linearBiasSlopes,
p_dropout,
softmax_scale,
is_causal,
@@ -279,14 +287,14 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(linearBiasSlopes, alibi_slopes);

// handle param[in]
auto atGradOut = impl::aten::buildATen(grad_output);
auto atGen = buildGeneratorForMha(ctx, gen, p_dropout);
auto atQ = impl::aten::buildATen(q);
auto atK = impl::aten::buildATen(k);
auto atV = impl::aten::buildATen(v);
DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
auto atOut = impl::aten::buildATen(attention_out);
auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);

@@ -301,11 +309,13 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
optGradQ,
optGradK,
optGradV,
linearBiasSlopes,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
/*deterministic=*/false,
atGen,
nullOpt);

@@ -320,6 +330,7 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand

// handle param[out]
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optOut, attention_out);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(linearBiasSlopes, alibi_slopes);

// handle param[in]
auto atGen = buildGeneratorForMha(ctx, gen, p_dropout);
@@ -333,13 +344,16 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
auto headSize = atQ.sizes()[3];
DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");

c10::optional<at::Tensor> nullOpt;
std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_fwd,
atQ,
atK,
atV,
optOut,
atCumSeqQ,
atCumSeqKV,
nullOpt,
linearBiasSlopes,
max_seqlen_q,
max_seqlen_kv,
p_dropout,
@@ -370,6 +384,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(linearBiasSlopes, alibi_slopes);

// handle param[in]
auto atGradOut = impl::aten::buildATen(grad_output);
@@ -379,7 +394,6 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
auto atV = impl::aten::buildATen(v);
auto atCumSeqQ = impl::aten::buildATen(cum_seq_q);
auto atCumSeqKV = impl::aten::buildATen(cum_seq_kv);
DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
auto atOut = impl::aten::buildATen(attention_out);
auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);

@@ -396,6 +410,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
optGradV,
atCumSeqQ,
atCumSeqKV,
linearBiasSlopes,
max_seqlen_q,
max_seqlen_kv,
p_dropout,
Expand All @@ -404,6 +419,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
is_causal,
window_size_left,
window_size_right,
/*deterministic=*/false,
atGen,
nullOpt);

Second changed file (CMake build configuration)
@@ -11,6 +11,19 @@ foreach(CC ${CURRENT_GPU_LIST})
endif()
endforeach()

set(FLASH_ATTN_LIB_PATH_2_0 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.0")
set(FLASH_ATTN_LIB_PATH_2_1 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.1")

if(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 0)
set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_0}")
elseif(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 1)
set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_1}")
else()
message(FATAL_ERROR "No valid torch version for setting FLASH_ATTN_LIB_PATH")
endif()

message(STATUS "FLASH_ATTN_LIB_PATH: ${FLASH_ATTN_LIB_PATH}")

if(ENABLE_TORCH_EXT_FLASH_ATTN)
# Note: it's really a bad idea to hardcode name and path here.
find_library(
@@ -20,7 +33,7 @@ if(ENABLE_TORCH_EXT_FLASH_ATTN)
flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so # fallback name of library
PATHS
ENV FLASH_ATTN_LIB_DIR # alternative path to search
/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention # fallback path
${FLASH_ATTN_LIB_PATH} # fallback path
)
endif()
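
The CMake change above pins a prebuilt flash-attention 2.4.3 library to the detected Torch version (2.0 or 2.1) and fails the configure step otherwise. If desired, the same constraint could also be asserted at compile time; a minimal sketch (not taken from this commit), assuming the TORCH_VERSION_MAJOR/TORCH_VERSION_MINOR macros from libtorch's torch/version.h header:

#include <torch/version.h>

#if !(TORCH_VERSION_MAJOR == 2 && (TORCH_VERSION_MINOR == 0 || TORCH_VERSION_MINOR == 1))
#error "flash-attention 2.4.3 prebuilt libraries are only provided for Torch 2.0 and 2.1"
#endif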

Third changed file: impl/torch/functions/functions_ext/flash-attention/include/flash_attn/flash_api.h
@@ -1,4 +1,4 @@
// Declare functions implemented in flash-attention library (v2.3.x).
// Declare functions implemented in flash-attention library (v2.4.3).

//
// WARNING: The flash-attention library exports these functions in global namespace. Bad practice. Nothing we can do about it.
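
The declarations below are marked __attribute__((weak)), which presumably lets this code build and link even when the flash-attention shared library is not present; an undefined weak symbol then resolves to a null address at runtime. A minimal sketch of that mechanism with a hypothetical function (how the DIOPI_EXT_CALL_FLASH macro reacts to a missing library is not shown in this diff):

#include <vector>
#include <ATen/ATen.h>

// Hypothetical weak declaration, in the same style as the flash_attn entry points below.
std::vector<at::Tensor> __attribute__((weak)) some_flash_fn(const at::Tensor& q);

bool flash_attn_available() {
    // If no definition was supplied at link/load time, the weak symbol resolves to nullptr.
    return some_flash_fn != nullptr;
}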

@@ -43,18 +43,21 @@ std::vector<at::Tensor> __attribute__((weak)) mha_fwd(at::Tensor& q,
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor>& out_, // batch_size x seqlen_q x num_heads x head_size
const float p_dropout, const float softmax_scale, bool is_causal, const int window_size_left,
c10::optional<at::Tensor>& alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, const float softmax_scale, bool is_causal, int window_size_left,
int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);

std::vector<at::Tensor> __attribute__((weak))
mha_varlen_fwd(const at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
mha_varlen_fwd(at::Tensor& q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor& k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor& v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor>& out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor& cu_seqlens_q, // b+1
const at::Tensor& cu_seqlens_k, // b+1
const int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, const bool is_causal,
const int window_size_left, int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);
c10::optional<at::Tensor>& seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, bool is_causal,
int window_size_left, int window_size_right, const bool return_softmax, c10::optional<at::Generator> gen_);

std::vector<at::Tensor> __attribute__((weak)) mha_bwd(const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og
const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
@@ -65,9 +68,10 @@ std::vector<at::Tensor> __attribute__((weak)) mha_bwd(const at::Tensor& dout,
c10::optional<at::Tensor>& dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor>& dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor>& dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale, const bool is_causal, const int window_size_left, int window_size_right,
c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
const float softmax_scale, const bool is_causal, int window_size_left, int window_size_right,
const bool deterministic, c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);

std::vector<at::Tensor> __attribute__((weak))
mha_varlen_bwd(const at::Tensor& dout, // total_q x num_heads, x head_size
@@ -81,11 +85,12 @@ mha_varlen_bwd(const at::Tensor& dout, // total_q x num_heads, x head_s
c10::optional<at::Tensor>& dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor& cu_seqlens_q, // b+1
const at::Tensor& cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale, const bool zero_tensors, const bool is_causal, const int window_size_left, int window_size_right,
c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);
const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right,
const bool deterministic, c10::optional<at::Generator> gen_, c10::optional<at::Tensor>& rng_state);

std::vector<at::Tensor> __attribute__((weak))
mha_fwd_kvcache(at::Tensor& q, const at::Tensor& kcache, const at::Tensor& vcache, const at::Tensor& k, const at::Tensor& v, const at::Tensor& seqlens_k,