
bug: start_position support for the fused attention kernel #329

Open
1 of 2 tasks
ipoletaev opened this issue Jul 21, 2023 · 0 comments
Description

Using a start position index with the fused attention kernel does not work.

Steps to reproduce

import math

import torch

START_IDX = 128


def attention_reference(q: torch.Tensor, k: torch.Tensor,
                        v: torch.Tensor) -> torch.Tensor:
    # Causal mask shifted by START_IDX: row i may attend to columns j <= i + START_IDX.
    mask_y = torch.full((1, 1, q.size(2), q.size(2)), float("-inf"))
    mask_y = torch.triu(mask_y, diagonal=START_IDX + 1).float()
    att_y = (q @ k.transpose(-2, -1)) * scale
    att_y = att_y + mask_y.to(att_y)  # match the scores' dtype and device
    att_y = torch.nn.functional.softmax(att_y, dim=-1)
    return att_y @ v


q = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
k = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
v = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
scale = 1 / math.sqrt(128)

x = triton_fa(q, k, v, scale, True, START_IDX)  # fused Triton kernel under test
y = attention_reference(q, k, v)
print(torch.max(torch.abs(x - y)))
print(torch.sum(x - y))
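For context (a minimal illustration, not part of the original report): the shifted mask above implements a causal pattern offset by START_IDX, so row i may attend to columns j <= i + START_IDX. On a tiny 4x4 example with START_IDX = 2:

```python
import torch

START_IDX = 2
T = 4  # tiny sequence length for illustration

# Same construction as in the reproduction above.
mask = torch.full((T, T), float("-inf"))
mask = torch.triu(mask, diagonal=START_IDX + 1)

# Only positions with j - i > START_IDX are masked; here that is just (0, 3).
print(mask)
```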

Expected Behavior

Output nearly identical to the vanilla implementation for any start position index.

Actual Behavior

The fused kernel returns NaN values for any START_IDX != 0.
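One plausible culprit (an assumption, not verified against the kernel source): flash-attention-style kernels compute a running softmax block by block, and a block whose scores are entirely -inf produces NaN during normalization. The failure mode is easy to reproduce with a plain softmax:

```python
import torch

# Softmax over a fully masked row: every score is -inf, so the
# normalizer is zero and every entry becomes NaN.
row = torch.full((4,), float("-inf"))
print(torch.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan])
```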

Your environment

torch==2.0.0
triton==2.0.0

Self-service

  • I would be willing to help fix this bug myself.

Code of Conduct

  • I agree to follow this project's Code of Conduct