Skip to content

Commit

Permalink
Change allocation of grouped mode lse from [H, M] to [1, H, M] to mat…
Browse files Browse the repository at this point in the history
…ch the xformers scripts
  • Loading branch information
qianfengz committed Aug 17, 2024
1 parent cbb557d commit 1a73f34
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@ efficient_attention_backward_ck(
p.max_seqlen_k = *max_seqlen_k_;

// unpadded lse layout required
TORCH_CHECK(p.Hq == logsumexp.size(0));
TORCH_CHECK(p.M == logsumexp.size(1));
TORCH_CHECK(p.Hq == logsumexp.size(1));
TORCH_CHECK(p.M == logsumexp.size(2));

if (scale.has_value())
p.scale = float(*scale);
Expand Down Expand Up @@ -384,8 +384,8 @@ efficient_attention_backward_ck(
static_cast<int>(grad_out.stride(3))};

p.lsed_strides = {
static_cast<int>(logsumexp.stride(0)),
static_cast<int>(logsumexp.stride(1))};
static_cast<int>(logsumexp.stride(1)),
static_cast<int>(logsumexp.stride(2))};

if (use_grad_q_f32) {
p.grad_q_f32_strides = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,11 @@ efficient_attention_forward_ck(
p.dropout_prob = 0.0f;

if (p.compute_logsumexp) {
logsumexp = at::empty({Hq, M}, opts.dtype(at::kFloat));
logsumexp = at::empty({1, Hq, M}, opts.dtype(at::kFloat));
p.logsumexp_ptr = logsumexp.data_ptr();
p.lse_strides = {
static_cast<int>(logsumexp.stride(0)),
static_cast<int>(logsumexp.stride(1))};
static_cast<int>(logsumexp.stride(1)),
static_cast<int>(logsumexp.stride(2))};
} else {
p.logsumexp_ptr = nullptr;
p.lse_strides = {0, 0};
Expand Down

0 comments on commit 1a73f34

Please sign in to comment.