From 1a73f34089a3da0976a06d46060a720880a51114 Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Sat, 17 Aug 2024 15:30:29 +0000
Subject: [PATCH] Change allocation of grouped mode lse from [H, M] to [1, H, M] to match the xformers scripts

---
 .../hip_fmha/attention_backward_generic_ck_tiled.cpp | 8 ++++----
 .../hip_fmha/attention_forward_generic_ck_tiled.cpp | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp b/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp
index 700adeba5..a1c542177 100644
--- a/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp
+++ b/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp
@@ -354,8 +354,8 @@ efficient_attention_backward_ck(
     p.max_seqlen_k = *max_seqlen_k_;
 
     // unpadded lse layout required
-    TORCH_CHECK(p.Hq == logsumexp.size(0));
-    TORCH_CHECK(p.M == logsumexp.size(1));
+    TORCH_CHECK(p.Hq == logsumexp.size(1));
+    TORCH_CHECK(p.M == logsumexp.size(2));
 
     if (scale.has_value())
       p.scale = float(*scale);
@@ -384,8 +384,8 @@ efficient_attention_backward_ck(
         static_cast(grad_out.stride(3))};
 
     p.lsed_strides = {
-        static_cast(logsumexp.stride(0)),
-        static_cast(logsumexp.stride(1))};
+        static_cast(logsumexp.stride(1)),
+        static_cast(logsumexp.stride(2))};
 
     if (use_grad_q_f32) {
       p.grad_q_f32_strides = {
diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp
index fa6e0127a..4bbfe71ad 100644
--- a/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp
+++ b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp
@@ -316,11 +316,11 @@ efficient_attention_forward_ck(
       p.dropout_prob = 0.0f;
 
     if (p.compute_logsumexp) {
-      logsumexp = at::empty({Hq, M}, opts.dtype(at::kFloat));
+      logsumexp = at::empty({1, Hq, M}, opts.dtype(at::kFloat));
       p.logsumexp_ptr = logsumexp.data_ptr();
       p.lse_strides = {
-          static_cast(logsumexp.stride(0)),
-          static_cast(logsumexp.stride(1))};
+          static_cast(logsumexp.stride(1)),
+          static_cast(logsumexp.stride(2))};
     } else {
       p.logsumexp_ptr = nullptr;
       p.lse_strides = {0, 0};