diff --git a/impl/torch/functions/functions_ext.cpp b/impl/torch/functions/functions_ext.cpp
index a102b3ad9..b8a288930 100644
--- a/impl/torch/functions/functions_ext.cpp
+++ b/impl/torch/functions/functions_ext.cpp
@@ -278,7 +278,7 @@ diopiError_t diopiFlashAttention(diopiContextHandle_t ctx, diopiTensorHandle_t a
                                  float p_dropout, float softmax_scale, bool is_causal, int32_t window_size_left, int32_t window_size_right) {
     impl::aten::setCurStream(ctx);
-    // handle param[out]
+    // handle param[out]mha_bwd
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optOut, attention_out);

     // handle param[in]
@@ -373,7 +373,6 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
     auto atV = impl::aten::buildATen(v);
     auto atCumSeqQ = impl::aten::buildATen(cum_seq_q);
     auto atCumSeqKV = impl::aten::buildATen(cum_seq_kv);
-    DIOPI_CHECK(alibi_slopes == nullptr, "alibi_slopes is not yet supported in DIOPI torch impl");
     DIOPI_IMPL_BUILD_ATEN_OPTIONAL(optLinearBiasSlopes, alibi_slopes);

     auto headSize = atQ.sizes()[3];
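
For context on the removed `DIOPI_CHECK`: with the guard gone, a caller-supplied `alibi_slopes` handle now flows through the optional-tensor path instead of being rejected. The sketch below is a minimal, self-contained illustration of that nullable-handle to `c10::optional<at::Tensor>` pattern; `FakeConstTensorHandle` and `buildOptionalATen` are hypothetical stand-ins for illustration and do not reflect the actual `DIOPI_IMPL_BUILD_ATEN_OPTIONAL` expansion or the DIOPI handle types.

```cpp
// Illustration only: hypothetical stand-ins, not the real DIOPI macro or types.
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Stand-in for diopiConstTensorHandle_t, for the purpose of this sketch.
using FakeConstTensorHandle = const at::Tensor*;

c10::optional<at::Tensor> buildOptionalATen(FakeConstTensorHandle handle) {
    if (handle == nullptr) {
        return c10::nullopt;  // caller passed no alibi_slopes: absent optional
    }
    return *handle;  // wrap the provided tensor so the kernel can apply the ALiBi bias
}

int main() {
    at::Tensor slopes = at::rand({8});  // e.g. one ALiBi slope per attention head
    auto withSlopes = buildOptionalATen(&slopes);
    auto withoutSlopes = buildOptionalATen(nullptr);
    // withSlopes carries a tensor, withoutSlopes is empty.
    return (withSlopes.has_value() && !withoutSlopes.has_value()) ? 0 : 1;
}
```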