cache the preprocess backwards kernel output across the ring reduce
lucidrains committed Apr 10, 2024
1 parent 7e81a42 commit 6823794
Showing 4 changed files with 47 additions and 24 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -98,8 +98,8 @@ $ python assert.py --use-cuda --causal --striped-ring-attn
- [x] make sure cuda striped attention works for multiple buckets, otherwise flash attention is ineffective
- [x] for cuda striped attention, for backwards hack, pad the extra token once and index out when passing into Tri's cuda kernel
- [x] find a machine with 8 GPUs and test with a quarter million tokens first
- [x] see for cuda version whether softmax_D can be computed once and cached over the ring reduce. go for modified triton backwards if not

- [ ] see for cuda version whether softmax_D can be computed once and cached over the ring reduce. go for modified triton backwards if not
- [ ] think about how to craft a special `Dataset` that shards across sequence length (take into account labels for cross entropy loss) for ring transformer training
- [ ] add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
- [ ] figure out how to pytest distributed pytorch
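The checked-off item above is what this commit implements: softmax_D (named `delta` in the code) is the per-row dot product of the attention output with its incoming gradient, which the flash attention backward pass precomputes. Since it depends only on `o` and `do`, and not on the key/value blocks that rotate around the ring, it is identical on every hop of the ring reduce and only needs to be computed once. A minimal sketch of the cached quantity, assuming a `(batch, seqlen, heads, dim)` layout and ignoring the sequence-length padding the real kernel applies:

```python
import torch

# softmax_D ("delta" in the code) is the rowsum of do * o from the flash attention
# backward preprocess. it involves neither k nor v, so rotating key/value blocks
# around the ring cannot change it, and it only needs to be computed once
o  = torch.randn(2, 1024, 8, 64)    # (batch, seqlen, heads, dim) - assumed layout
do = torch.randn_like(o)
delta = (do.float() * o.float()).sum(dim = -1)   # same value on every ring hop
```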
9 changes: 7 additions & 2 deletions ring_attention_pytorch/ring_flash_attention_cuda.py
Expand Up @@ -94,7 +94,7 @@ def inverse_fn(y):
assert exists(importlib.util.find_spec('triton')), 'latest triton must be installed. `pip install triton -U` first'

triton_version = version('triton')
assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1')
assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1'), 'triton must be version 2.1 or above. `pip install triton -U` to upgrade'

import triton
import triton.language as tl
@@ -320,6 +320,10 @@ def backward(ctx, do):
receive_kv_and_dkv = None
receive_mask = None

# cache the delta (rowsum of do * o, needed by the backwards pass) across the ring reduce

delta = None

# if not causal and has key padding mask
# prepare row related tensors with unpad_input

@@ -371,7 +375,7 @@ def backward(ctx, do):
ring_dv = torch.empty_like(v)

with torch.inference_mode():
flash_attn_backward(
delta = flash_attn_backward(
do,
q,
k,
@@ -381,6 +385,7 @@ def backward(ctx, do):
ring_dq,
ring_dk,
ring_dv,
delta = delta,
causal = block_causal,
causal_mask_diagonal = causal_mask_diagonal,
softmax_scale = softmax_scale
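The hunk above is the caller side of the cache: `delta` starts out as `None`, the first call to `flash_attn_backward` computes it and hands it back, and every later hop of the ring reduce passes the cached value in so the preprocess kernel is skipped. A condensed sketch of that control flow, with `ring_pass` and `num_hops` as hypothetical stand-ins and the mask, striping and gradient-accumulation bookkeeping of the real `backward` elided:

```python
import torch
from ring_attention_pytorch.triton_flash_attn import flash_attn_backward

def ring_backward_sketch(do, q, k, v, o, lse, ring_dq, ring_dk, ring_dv,
                         ring_pass, num_hops, block_causal, softmax_scale):
    delta = None  # cached softmax_D, filled in on the first hop

    for _ in range(num_hops):
        k, v = ring_pass(k, v)  # rotate key / value blocks to the next rank

        with torch.inference_mode():
            # first hop: delta is None, so the preprocess kernel runs and its output is returned
            # later hops: the cached delta is passed back in and the preprocess kernel is skipped
            delta = flash_attn_backward(
                do, q, k, v, o, lse,
                ring_dq, ring_dk, ring_dv,
                delta = delta,
                causal = block_causal,
                softmax_scale = softmax_scale
            )
```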
58 changes: 38 additions & 20 deletions ring_attention_pytorch/triton_flash_attn.py
@@ -912,7 +912,20 @@ def _bwd_kernel(
)

def flash_attn_backward(
do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, causal_mask_diagonal=False, softmax_scale=None
do,
q,
k,
v,
o,
lse,
dq,
dk,
dv,
delta=None,
bias=None,
causal=False,
causal_mask_diagonal=False,
softmax_scale=None
):
# Make sure that the last dimension is contiguous
if do.stride(-1) != 1:
@@ -929,28 +942,31 @@ def flash_attn_backward(
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
# dq_accum = torch.zeros_like(q, dtype=torch.float32)
dq_accum = torch.empty_like(q, dtype=torch.float32)
delta = torch.empty_like(lse)

# delta = torch.zeros_like(lse)

BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
_bwd_preprocess_do_o_dot[grid](
o,
do,
delta,
o.stride(0),
o.stride(2),
o.stride(1),
do.stride(0),
do.stride(2),
do.stride(1),
nheads,
seqlen_q,
seqlen_q_rounded,
d,
BLOCK_M=128,
BLOCK_HEADDIM=BLOCK_HEADDIM,
)

if not exists(delta):
delta = torch.empty_like(lse)
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
_bwd_preprocess_do_o_dot[grid](
o,
do,
delta,
o.stride(0),
o.stride(2),
o.stride(1),
do.stride(0),
do.stride(2),
do.stride(1),
nheads,
seqlen_q,
seqlen_q_rounded,
d,
BLOCK_M=128,
BLOCK_HEADDIM=BLOCK_HEADDIM,
)

has_bias = bias is not None
bias_type = "none"
@@ -1030,3 +1046,5 @@ def flash_attn_backward(
# num_stages=1,
)
dq.copy_(dq_accum)

return delta
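On the callee side, `delta` becomes an optional argument: when the caller does not pass one, `flash_attn_backward` runs the `_bwd_preprocess_do_o_dot` kernel exactly as before, and it now returns the result either way so the ring loop can reuse it. A rough pure-PyTorch stand-in for that lazy-compute-and-return contract (`cached_or_computed_delta` is an illustrative helper, not part of the repository):

```python
import torch
from typing import Optional

def cached_or_computed_delta(
    o: torch.Tensor,                       # (batch, seqlen, heads, dim)
    do: torch.Tensor,                      # gradient of the output, same shape as o
    delta: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # same guard as above: only do the preprocess when no cached value was supplied.
    # the real function launches the _bwd_preprocess_do_o_dot triton kernel and pads
    # the sequence dimension to seqlen_q_rounded; a plain reduction stands in here
    if delta is None:
        delta = (do.float() * o.float()).sum(dim = -1).transpose(1, 2)  # (batch, heads, seqlen)
    return delta
```

Because `o` and `do` do not change while the key/value blocks travel around the ring, every call after the first returns the cached tensor untouched, which is what saves the repeated preprocess launches.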
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.3.15',
version = '0.3.16',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
