Commit
may be safe to accumulate dk and dv as same type as k and v in flash cuda kernel backwards, saving some communication costs

lucidrains committed Apr 11, 2024
1 parent d823c70 commit 24e1673
Showing 2 changed files with 16 additions and 17 deletions.
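A rough sense of the communication saving mentioned in the commit message (an illustrative back-of-the-envelope estimate, not a figure from the commit): with dk and dv kept in float32, each element position of the packed ring buffer carried 16-bit k and v plus 32-bit dk and dv; accumulating dk and dv in the same 16-bit dtype trims the per-hop payload by roughly a third.

# rough per-element byte count of the packed (k, v, dk, dv) ring buffer
# illustrative estimate only; assumes 16-bit (float16/bfloat16) k and v

bytes_before = 2 + 2 + 4 + 4   # k, v in 16 bits; dk, dv accumulated in float32
bytes_after  = 2 + 2 + 2 + 2   # dk, dv accumulated in the same 16-bit dtype as k and v

print(f"ring payload per element shrinks by {1 - bytes_after / bytes_before:.0%}")  # ~33%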
31 changes: 15 additions & 16 deletions ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -254,19 +254,17 @@ def backward(ctx, do):
         device = q.device
 
         dq = torch.zeros(q.shape, device = device, dtype = torch.float32)
-        dk = torch.zeros(k.shape, device = device, dtype = torch.float32)
-        dv = torch.zeros(v.shape, device = device, dtype = torch.float32)
+        dk = torch.zeros_like(k, device = device)
+        dv = torch.zeros_like(v, device = device)
 
-        # k and v will have 16 bits, while dk, dv needs to be kept at 32
-        # view everything as int for ring passing
-        # for minimizing communication
+        # k and v will have 16 bits, and dk and dv can also be accumulated safely with the same type, i think
+        # view everything as float32 for ring passing
 
-        k_dtype, v_dtype = k.dtype, v.dtype
+        assert k.dtype == v.dtype
+        kv_dtype = k.dtype
 
-        k, v = map(lambda t: t.view(torch.float32), (k, v))
-        kv = torch.cat((k, v), dim = -1)
-
-        kv_and_dkv = torch.stack((kv, dk, dv))
+        k, v, dk, dv = map(lambda t: t.view(torch.float32), (k, v, dk, dv))
+        kv_and_dkv = torch.stack((k, v, dk, dv))
 
         # receive buffers, to be alternated with sent buffer
 
@@ -279,12 +277,11 @@ def backward(ctx, do):
 
         for (ring_rank, _), ((kv_and_dkv, mask), (receive_kv_and_dkv, receive_mask)) in ring_pass_fn(kv_and_dkv, mask, receive_buffers = (receive_kv_and_dkv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):
 
-            kv, dk, dv = kv_and_dkv
+            k, v, dk, dv = kv_and_dkv
 
-            # reconstitute correct types for k, v, dk, dv
+            # view k, v, dk, dv as the correct type of either float16 or bfloat16
 
-            k, v = kv.chunk(2, dim = -1)
-            k, v = k.view(k_dtype), v.view(v_dtype)
+            k, v, dk, dv = map(lambda t: t.view(kv_dtype), (k, v, dk, dv))
 
             # translate key padding mask to bias
 
@@ -313,7 +310,7 @@ def backward(ctx, do):
             # use flash attention backwards kernel to calculate dq, dk, dv and accumulate
 
             if need_accum:
-                ring_dq = torch.empty_like(q)
+                ring_dq = torch.empty(q.shape, device = device, dtype = torch.float32)
                 ring_dk = torch.empty_like(k)
                 ring_dv = torch.empty_like(v)
 
@@ -344,13 +341,15 @@ def backward(ctx, do):
             if not ring_reduce_col:
                 continue
 
-            dkv = kv_and_dkv[1:]
+            dkv = kv_and_dkv[2:]
 
             max_ring_passes = default(max_ring_passes, ring_size)
             dkv = ring_pass(ring_size - max_ring_passes + 1, dkv)
 
             dk, dv = dkv
 
+            dk, dv = map(lambda t: t.view(kv_dtype), (dk, dv))
+
         dq, dk, dv = map(lambda t: t.to(dtype), (dq, dk, dv))
 
         return dq, dk, dv, None, None, None, None, None, None, None
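To make the packing in the diff above easier to follow, here is a minimal standalone sketch of the pattern the changed code appears to use; the shapes are made up and the actual ring send/receive (ring_pass) is omitted. The four half-precision tensors are bit-reinterpreted as float32 so they share one dtype, stacked into a single buffer for the ring exchange, then viewed back to the original dtype on the other side.

import torch

# minimal sketch of the (k, v, dk, dv) packing; hypothetical shapes,
# no actual distributed communication

b, n, h, d = 2, 1024, 8, 64

k = torch.randn(b, n, h, d, dtype = torch.bfloat16)
v = torch.randn(b, n, h, d, dtype = torch.bfloat16)

# dk, dv accumulated in the same 16-bit dtype as k and v

dk = torch.zeros_like(k)
dv = torch.zeros_like(v)

# reinterpret the 16-bit payloads as float32 (halving the last dimension,
# keeping the same bytes) so all four tensors share one dtype and can be
# stacked into a single buffer for the ring pass

packed = torch.stack([t.view(torch.float32) for t in (k, v, dk, dv)])

# ... packed would be exchanged around the ring here ...

# on the receiving side, view each slice back to the original dtype

k_r, v_r, dk_r, dv_r = (t.view(torch.bfloat16) for t in packed)

assert torch.equal(k_r, k) and torch.equal(dv_r, dv)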
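The 'i think' hedge in the new comment is worth keeping in mind: accumulating gradients in a 16-bit dtype can drop small contributions that a float32 accumulator would retain. A toy illustration of that effect (not from the repository):

import torch

# repeated accumulation in bfloat16 can stall once the running total is
# large relative to each increment; float32 keeps the contributions

acc_bf16 = torch.zeros((), dtype = torch.bfloat16)
acc_fp32 = torch.zeros((), dtype = torch.float32)

step = torch.full((), 1e-3)

for _ in range(10000):
    acc_bf16 += step.to(torch.bfloat16)
    acc_fp32 += step

print(acc_bf16.item(), acc_fp32.item())  # the bfloat16 total falls well short of ~10.0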
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.17',
+  version = '0.3.18',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',