diff --git a/ring_attention_pytorch/ring_flash_attention_cuda.py b/ring_attention_pytorch/ring_flash_attention_cuda.py
index e0ffc37..da101d2 100644
--- a/ring_attention_pytorch/ring_flash_attention_cuda.py
+++ b/ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -457,7 +457,7 @@ def forward(
         max_lookback_seq_len: Optional[int],
         ring_size: Optional[int]
     ):
-        assert not exists(mask), 'key padding mask is not supported yet for ring flash attn cuda'
+        assert not (not causal and exists(mask)), 'key padding mask is not supported yet for ring flash attn cuda'
 
         assert all([t.is_cuda for t in (q, k, v)]), 'inputs must be all on cuda'
 
@@ -561,11 +561,15 @@ def forward(
                 bias = bias,
                 softmax_scale = softmax_scale,
                 causal_mask_diagonal = causal_mask_diagonal,
-                return_normalized_output = is_last,
+                return_normalized_output = False,
                 load_accumulated = not is_first
             )
 
         lse = lse[..., :q_seq_len]
+        m = m[..., :q_seq_len]
+
+        o_scale = torch.exp(m - lse)
+        o.mul_(rearrange('b h n -> b n h 1', o_scale))
 
         ctx.args = (
             causal,
@@ -643,6 +647,16 @@ def backward(ctx, do):
         receive_kv_and_dkv = None
         receive_mask = None
 
+        # hack for special causal mask for striped ring attention without having to modify cuda
+
+        if causal and striped_ring_attn:
+            # this is a hack that should also mask out the diagonal
+            # https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag
+            q = pad_at_dim(q, (0, 1), dim = 1)
+            o = pad_at_dim(o, (0, 1), dim = 1)
+            do = pad_at_dim(do, (0, 1), dim = 1)
+            lse = pad_at_dim(lse, (0, 1), dim = -1)
+
         for (ring_rank, _), ((kv_and_dkv, mask), (receive_kv_and_dkv, receive_mask)) in ring_pass_fn(kv_and_dkv, mask, receive_buffers = (receive_kv_and_dkv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):
             kv, dk, dv = kv_and_dkv
 
@@ -655,6 +669,8 @@ def backward(ctx, do):
             # determine whether to do causal mask or not
             # depends on whether it is striped attention, as well as current machine rank vs ring rank
 
+            n = row_length
+
             if causal or not exists(mask):
                 block_causal = False
 
@@ -665,12 +681,7 @@ def backward(ctx, do):
                     block_causal = True
 
                     if get_rank() < ring_rank:
-                        # this is a hack that should also mask out the diagonal
-                        # https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag
-                        q = pad_at_dim(q, (0, 1), dim = 1)
-                        o = pad_at_dim(o, (0, 1), dim = 1)
-                        do = pad_at_dim(do, (0, 1), dim = 1)
-                        lse = pad_at_dim(lse, (0, 1), dim = -1)
+                        n += 1
                 else:
                     block_causal = get_rank() == ring_rank
 
@@ -681,15 +692,15 @@ def backward(ctx, do):
 
             if need_accum:
                 ring_dq, ring_dk, ring_dv, *_ = _flash_attn_backward(
-                    dout = do,
-                    q = q,
+                    dout = do[:, :n],
+                    q = q[:, :n],
                     k = k,
                     v = v,
-                    out = o,
-                    softmax_lse = lse,
-                    dq = torch.zeros_like(q),
-                    dk = torch.zeros_like(k),
-                    dv = torch.zeros_like(v),
+                    out = o[:, :n],
+                    softmax_lse = lse[..., :n],
+                    dq = torch.empty_like(q[:, :n]),
+                    dk = torch.empty_like(k),
+                    dv = torch.empty_like(v),
                     dropout_p = 0.,
                     softmax_scale = softmax_scale,
                     causal = block_causal,
diff --git a/setup.py b/setup.py
index 8addce9..4925e60 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.25',
+  version = '0.2.27',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',