[DIOPI] Upgrade flash_attn to 2.4.3 #1348

Merged: 3 commits, Aug 21, 2024
impl/torch/functions/functions_ext.cpp (62 additions, 12 deletions)
@@ -120,7 +120,9 @@ diopiError_t diopiMultiHeadAttention(diopiContextHandle_t ctx, diopiTensorHandle
auto headSize = atQ.sizes()[3];
TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");

std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);
c10::optional<at::Tensor> nullOpt;
std::vector<at::Tensor> result =
DIOPI_EXT_CALL_FLASH(mha_fwd, atQ, atK, atV, optOut, nullOpt, dropout_p, scale, is_causal, -1, -1, return_debug_mask, atGen);

// PERF: these copy can be eliminated by modifying the flash_attn api
impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
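
Note (not part of the diff): flash-attention 2.4.3 adds an alibi_slopes optional to mha_fwd immediately after the output tensor, which is why the updated call passes an empty c10::optional<at::Tensor> (nullOpt) in that slot. Below is a minimal sketch of the parameter order as inferred from the call above; the parameter names and comments are illustrative assumptions, not copied from the flash_attn headers.

#include <ATen/ATen.h>
#include <vector>

// Inferred shape of the flash-attention 2.4.3 forward entry point.
// Only the argument order and rough types follow from the call site above.
std::vector<at::Tensor> mha_fwd(at::Tensor& q,                             // (batch, seqlen_q, num_heads, head_size)
                                const at::Tensor& k,
                                const at::Tensor& v,
                                c10::optional<at::Tensor>& out_,           // optional preallocated output
                                c10::optional<at::Tensor>& alibi_slopes_,  // new in 2.4.x: per-head ALiBi slopes
                                float p_dropout,
                                float softmax_scale,
                                bool is_causal,
                                int window_size_left,                      // -1 disables sliding-window attention
                                int window_size_right,
                                bool return_softmax,
                                c10::optional<at::Generator> gen_);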
@@ -147,10 +149,28 @@ diopiError_t diopiMultiHeadAttentionBackward(diopiContextHandle_t ctx, diopiCons
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
c10::optional<at::Tensor> nullOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
c10::optional<at::Tensor> nullStateOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
c10::optional<at::Tensor> nullSlopesOpt;

std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
mha_bwd, atGradOut, atQ, atK, atV, atOut, atLogsumexp, optGradQ, optGradK, optGradV, dropout_p, scale, is_causal, -1, -1, atGen, nullOpt);
std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_bwd,
atGradOut,
atQ,
atK,
atV,
atOut,
atLogsumexp,
optGradQ,
optGradK,
optGradV,
nullSlopesOpt,
dropout_p,
scale,
is_causal,
-1,
-1,
false,
atGen,
nullStateOpt);

return diopiSuccess;
}
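
Note (not part of the diff): relative to the 2.0.x call, mha_bwd gains the same alibi_slopes optional (after the dq/dk/dv outputs) and a deterministic flag just before the generator; the trailing rng_state optional is unchanged, which is what nullStateOpt covers here. Passing false keeps the previous, non-deterministic backward behaviour. A comment-only sketch of the tail of the argument list as implied by the call; the names are assumptions.

// mha_bwd(dout, q, k, v, out, softmax_lse,
//         dq_, dk_, dv_,
//         alibi_slopes_,       // new optional: empty disables ALiBi
//         p_dropout, softmax_scale, is_causal,
//         window_size_left, window_size_right,
//         deterministic,       // new flag: false here preserves the old (non-deterministic) path
//         gen_,
//         rng_state_);         // unchanged trailing optional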
@@ -172,8 +192,27 @@ diopiError_t diopiMultiHeadAttentionVarLen(diopiContextHandle_t ctx, diopiTensor
auto headSize = atQ.sizes()[3];
TORCH_CHECK(headSize % 8 == 0, "DIOPI now only support head sizes which are multiple of 8");

std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(
mha_varlen_fwd, atQ, atK, atV, optOut, atCumSeqQ, atCumSeqK, max_q, max_k, dropout_p, scale, false, is_causal, -1, -1, return_debug_mask, atGen);
c10::optional<at::Tensor> nullSeqOpt;
c10::optional<at::Tensor> nullSlopesOpt;
std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_fwd,
atQ,
atK,
atV,
optOut,
atCumSeqQ,
atCumSeqK,
nullSeqOpt,
nullSlopesOpt,
max_q,
max_k,
dropout_p,
scale,
false,
is_causal,
-1,
-1,
return_debug_mask,
atGen);

// PERF: these copy can be eliminated by modifying the flash_attn api
impl::aten::updateATen2Tensor(ctx, result[fa::mha_fwd_ret_idx::SOFTMAX_LSE], softmax_lse);
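
Note (not part of the diff): mha_varlen_fwd now takes two extra optionals directly after cu_seqlens_k: a second sequence-length tensor (left empty here via nullSeqOpt, presumably a per-batch "used key length" override) and the alibi_slopes optional (nullSlopesOpt). A comment-only sketch of the inferred order; the names are assumptions.

// mha_varlen_fwd(q, k, v, out_,
//                cu_seqlens_q, cu_seqlens_k,
//                seqused_k_,       // new optional: empty means "use cu_seqlens_k as-is"
//                alibi_slopes_,    // new optional: per-head ALiBi slopes
//                max_seqlen_q, max_seqlen_k,
//                p_dropout, softmax_scale,
//                zero_tensors, is_causal,    // the existing bool passed as false above
//                window_size_left, window_size_right,
//                return_softmax, gen_);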
@@ -203,7 +242,8 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradQ, grad_q);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradK, grad_k);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optGradV, grad_v);
c10::optional<at::Tensor> nullOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
c10::optional<at::Tensor> nullStateOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
c10::optional<at::Tensor> nullSlopesOpt;

std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_bwd,
atGradOut,
@@ -217,6 +257,7 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
optGradV,
atCumSeqQ,
atCumSeqK,
nullSlopesOpt,
max_q,
max_k,
dropout_p,
@@ -225,8 +266,9 @@ diopiError_t diopiMultiHeadAttentionVarLenBackward(diopiContextHandle_t ctx, dio
is_causal,
-1,
-1,
false,
atGen,
nullOpt);
nullStateOpt);

return diopiSuccess;
}
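
Note (not part of the diff): mha_varlen_bwd follows the same pattern as mha_bwd and mha_varlen_fwd above. The changed positions, sketched as a comment (names are assumptions):

// ..., cu_seqlens_q, cu_seqlens_k,
// alibi_slopes_,                  // new optional, empty here
// max_seqlen_q, max_seqlen_k, ..., window_size_right,
// deterministic /* false */,      // new flag before the generator
// gen_, rng_state_);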
@@ -244,7 +286,7 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
auto atQ = impl::aten::buildATen(q);
auto atK = impl::aten::buildATen(k);
auto atV = impl::aten::buildATen(v);
DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);

auto headSize = atQ.sizes()[3];
DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");
@@ -254,6 +296,7 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
atK,
atV,
optOut,
optLinearBiasSlopes,
p_dropout,
softmax_scale,
is_causal,
@@ -286,9 +329,9 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
auto atQ = impl::aten::buildATen(q);
auto atK = impl::aten::buildATen(k);
auto atV = impl::aten::buildATen(v);
DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
auto atOut = impl::aten::buildATen(attention_out);
auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);

c10::optional<at::Tensor> nullOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_bwd,
@@ -301,11 +344,13 @@ diopiError_t diopiFlashAttentionBackward(diopiContextHandle_t ctx, diopiTensorHa
optGradQ,
optGradK,
optGradV,
optLinearBiasSlopes,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
/*deterministic=*/false,
atGen,
nullOpt);

@@ -328,18 +373,21 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
auto atV = impl::aten::buildATen(v);
auto atCumSeqQ = impl::aten::buildATen(cum_seq_q);
auto atCumSeqKV = impl::aten::buildATen(cum_seq_kv);
DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);

auto headSize = atQ.sizes()[3];
DIOPI_CHECK(headSize % 8 == 0, "DIOPI torch impl now only support head sizes which are multiple of 8");

c10::optional<at::Tensor> nullOpt;
std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_fwd,
atQ,
atK,
atV,
optOut,
atCumSeqQ,
atCumSeqKV,
nullOpt,
optLinearBiasSlopes,
max_seqlen_q,
max_seqlen_kv,
p_dropout,
@@ -379,9 +427,9 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
auto atV = impl::aten::buildATen(v);
auto atCumSeqQ = impl::aten::buildATen(cum_seq_q);
auto atCumSeqKV = impl::aten::buildATen(cum_seq_kv);
DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
auto atOut = impl::aten::buildATen(attention_out);
auto atSoftmaxLse = impl::aten::buildATen(softmax_lse);
DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);

c10::optional<at::Tensor> nullOpt; // Workaround: flash_attn uses non-const optional& as args (which is a really bad idea)
std::vector<at::Tensor> result = DIOPI_EXT_CALL_FLASH(mha_varlen_bwd,
@@ -396,6 +444,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
optGradV,
atCumSeqQ,
atCumSeqKV,
optLinearBiasSlopes,
max_seqlen_q,
max_seqlen_kv,
p_dropout,
@@ -404,6 +453,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
is_causal,
window_size_left,
window_size_right,
/*deterministic=*/false,
atGen,
nullOpt);

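Note (not part of the diff): the diopiFlashAttention* variants above no longer reject a non-null alibi_slopes handle; they wrap it with DIOPI_IMPL_BUILD_ATEN_OPTIONAL and forward it to flash_attn. For orientation only, here is a minimal sketch of the conventional ALiBi slope schedule such a tensor would typically hold (as I understand it, flash-attention expects fp32 slopes, one per head, optionally per batch and head); the helper and its power-of-two assumption are illustrative and not part of this PR or of DIOPI.

#include <cmath>
#include <vector>

// Classic ALiBi slope schedule (geometric decay), assuming num_heads is a power of two.
std::vector<float> alibiSlopes(int num_heads) {
    std::vector<float> slopes(num_heads);
    const double base = std::pow(2.0, -8.0 / num_heads);  // slope of the first head
    for (int i = 0; i < num_heads; ++i) {
        slopes[i] = static_cast<float>(std::pow(base, i + 1));  // head i gets base^(i+1)
    }
    return slopes;
}
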
@@ -11,6 +11,19 @@ foreach(CC ${CURRENT_GPU_LIST})
endif()
endforeach()

set(FLASH_ATTN_LIB_PATH_2_0 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.0")
set(FLASH_ATTN_LIB_PATH_2_1 "/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention2.4.3_pt2.1")

if(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 0)
set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_0}")
elseif(${Torch_VERSION_MAJOR} EQUAL 2 AND ${Torch_VERSION_MINOR} EQUAL 1)
set(FLASH_ATTN_LIB_PATH "${FLASH_ATTN_LIB_PATH_2_1}")
else()
message(FATAL_ERROR "No valid torch version for setting FLASH_ATTN_LIB_PATH")
endif()

message(STATUS "FLASH_ATTN_LIB_PATH: ${FLASH_ATTN_LIB_PATH}")

if(ENABLE_TORCH_EXT_FLASH_ATTN)
# Note: it's really a bad idea to hardcode name and path here.
find_library(
@@ -20,7 +33,7 @@ if(ENABLE_TORCH_EXT_FLASH_ATTN)
flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so # fallback name of library
PATHS
ENV FLASH_ATTN_LIB_DIR # alternative path to search
/mnt/cache/share/platform/dep/DIOPI_pytorch/flash-attention # fallback path
${FLASH_ATTN_LIB_PATH} # fallback path
)
endif()
