
micro optimization for striped ring cuda attn
lucidrains committed Apr 4, 2024
1 parent 1f3574b commit ee39854
Showing 2 changed files with 27 additions and 16 deletions.
41 changes: 26 additions & 15 deletions ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -457,7 +457,7 @@ def forward(
max_lookback_seq_len: Optional[int],
ring_size: Optional[int]
):
- assert not exists(mask), 'key padding mask is not supported yet for ring flash attn cuda'
+ assert not (not causal and exists(mask)), 'key padding mask is not supported yet for ring flash attn cuda'

assert all([t.is_cuda for t in (q, k, v)]), 'inputs must be all on cuda'

@@ -561,11 +561,15 @@ def forward(
bias = bias,
softmax_scale = softmax_scale,
causal_mask_diagonal = causal_mask_diagonal,
-     return_normalized_output = is_last,
+     return_normalized_output = False,
load_accumulated = not is_first
)

+ lse = lse[..., :q_seq_len]
+ m = m[..., :q_seq_len]
+
+ o_scale = torch.exp(m - lse)
+ o.mul_(rearrange('b h n -> b n h 1', o_scale))

ctx.args = (
causal,
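
The forward-pass change above is a normalization deferral: every ring pass now asks the kernel for the raw accumulator via return_normalized_output = False, and the output is rescaled a single time after the loop by exp(m - lse). A minimal sketch of why that one rescale recovers the softmax-normalized result, written in plain PyTorch on a toy single-block attention rather than the fused kernel:

import torch

torch.manual_seed(0)
q, k, v = (torch.randn(4, 8) for _ in range(3))

scores = q @ k.t()                                 # attention logits for one block
m = scores.amax(dim = -1, keepdim = True)          # row max, the "m" returned by the kernel
lse = m + (scores - m).exp().sum(dim = -1, keepdim = True).log()   # log-sum-exp per row

o_unnormalized = (scores - m).exp() @ v            # accumulator without the softmax denominator
o = o_unnormalized * (m - lse).exp()               # exp(m - lse) == 1 / sum(exp(scores - m))

assert torch.allclose(o, scores.softmax(dim = -1) @ v, atol = 1e-6)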
@@ -643,6 +647,16 @@ def backward(ctx, do):
receive_kv_and_dkv = None
receive_mask = None

+ # hack for special causal mask for striped ring attention without having to modify cuda
+
+ if causal and striped_ring_attn:
+     # this is a hack that should also mask out the diagonal
+     # https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag
+     q = pad_at_dim(q, (0, 1), dim = 1)
+     o = pad_at_dim(o, (0, 1), dim = 1)
+     do = pad_at_dim(do, (0, 1), dim = 1)
+     lse = pad_at_dim(lse, (0, 1), dim = -1)

for (ring_rank, _), ((kv_and_dkv, mask), (receive_kv_and_dkv, receive_mask)) in ring_pass_fn(kv_and_dkv, mask, receive_buffers = (receive_kv_and_dkv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):

kv, dk, dv = kv_and_dkv
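
The padding hoisted to the top of backward leans on the flash-attention behavior linked in the comment: with the causal flag set and seqlen_q != seqlen_k, the mask is aligned to the bottom-right corner. Appending one dummy query row therefore shifts the mask so each real query also loses its own diagonal position, which is the stricter mask striped ring attention wants when the current rank trails the sending ring rank. A small sketch of that shift, assuming an explicit boolean mask instead of the fused kernel:

import torch

def bottom_right_causal_mask(seqlen_q, seqlen_k):
    # True where attention is allowed: query i sees key j iff j <= i + seqlen_k - seqlen_q
    i = torch.arange(seqlen_q)[:, None]
    j = torch.arange(seqlen_k)[None, :]
    return j <= (i + seqlen_k - seqlen_q)

n = 4
plain  = bottom_right_causal_mask(n, n)      # usual causal mask, diagonal included
padded = bottom_right_causal_mask(n + 1, n)  # one dummy query row appended

# dropping the dummy row leaves a strictly-causal mask (diagonal excluded)
assert torch.equal(padded[:n], plain & ~torch.eye(n, dtype = torch.bool))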
@@ -655,6 +669,8 @@ def backward(ctx, do):
# determine whether to do causal mask or not
# depends on whether it is striped attention, as well as current machine rank vs ring rank

+ n = row_length

if causal or not exists(mask):

block_causal = False
@@ -665,12 +681,7 @@ def backward(ctx, do):
block_causal = True

if get_rank() < ring_rank:
-     # this is a hack that should also mask out the diagonal
-     # https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag
-     q = pad_at_dim(q, (0, 1), dim = 1)
-     o = pad_at_dim(o, (0, 1), dim = 1)
-     do = pad_at_dim(do, (0, 1), dim = 1)
-     lse = pad_at_dim(lse, (0, 1), dim = -1)
+     n += 1
else:
block_causal = get_rank() == ring_rank

@@ -681,15 +692,15 @@ def backward(ctx, do):

if need_accum:
ring_dq, ring_dk, ring_dv, *_ = _flash_attn_backward(
-     dout = do,
-     q = q,
+     dout = do[:, :n],
+     q = q[:, :n],
      k = k,
      v = v,
-     out = o,
-     softmax_lse = lse,
-     dq = torch.zeros_like(q),
-     dk = torch.zeros_like(k),
-     dv = torch.zeros_like(v),
+     out = o[:, :n],
+     softmax_lse = lse[..., :n],
+     dq = torch.empty_like(q[:, :n]),
+     dk = torch.empty_like(k),
+     dv = torch.empty_like(v),
dropout_p = 0.,
softmax_scale = softmax_scale,
causal = block_causal,
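The micro optimization in the commit title shows in the last two hunks of this file: rather than re-padding q, o, do and lse inside every ring iteration, the tensors are padded once before the loop, each iteration only chooses the effective row count n (row_length, or row_length + 1 when the diagonal hack applies) and slices, and the per-pass gradient buffers move from torch.zeros_like to torch.empty_like, presumably because the kernel writes them in full. A rough sketch of the pad-once-then-slice pattern, with made-up shapes and a stand-in for the repo's pad_at_dim helper:

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # stand-in for the repo's helper: pad a single dimension of t
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

q = torch.randn(2, 8, 4, 16)                # hypothetical (batch, seq, heads, dim)
row_length = q.shape[1]

q_padded = pad_at_dim(q, (0, 1), dim = 1)   # pad the sequence dimension once, outside the loop

for needs_diagonal_mask in (False, True, True, False):   # stand-in for the per-rank decision
    n = row_length + int(needs_diagonal_mask)
    q_block = q_padded[:, :n]               # slicing yields a view, no fresh pad per iteration
    assert q_block.shape[1] == n
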
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
-  version = '0.2.25',
+  version = '0.2.27',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
