just go entirely with triton flash attention forwards and backwards. validate that non-causal key padding works
lucidrains committed Apr 11, 2024
1 parent 6823794 commit 61fa967
Showing 4 changed files with 72 additions and 151 deletions.
28 changes: 20 additions & 8 deletions assert_flash.py
@@ -14,22 +14,31 @@
@click.option('--seq-len', default = 62)
@click.option('--dim-head', default = 16)
@click.option('--heads', default = 2)
@click.option('--rand-key-pad-mask', is_flag = True)
@click.option('--bucket_size', default = 4)
@click.option('--flash-cuda-kernel', is_flag = True)
@click.option('--cuda-kernel', is_flag = True)
def test(
causal: bool,
seq_len: int,
dim_head: int,
heads: int,
rand_key_pad_mask: bool,
bucket_size: int,
flash_cuda_kernel: bool
cuda_kernel: bool
):
# base qkv

q = torch.randn(2, seq_len, heads, dim_head)
k = torch.randn(2, seq_len, heads, dim_head)
v = torch.randn(2, seq_len, heads, dim_head)

# key padding mask

mask = None
if rand_key_pad_mask:
assert not causal
mask = torch.randint(0, 2, (2, seq_len)).bool()

# flash and regular qkv's

fq = q.clone().requires_grad_()
@@ -40,7 +49,7 @@ def test(
rk = k.clone().requires_grad_()
rv = v.clone().requires_grad_()

if flash_cuda_kernel:
if cuda_kernel:
assert torch.cuda.is_available()

fcq = q.clone().cuda().requires_grad_()
@@ -49,15 +58,18 @@

# forward

o = default_attention(rq, rk, rv, causal = causal)
fo = ring_flash_attn(fq, fk, fv, bucket_size = bucket_size, causal = causal)
o = default_attention(rq, rk, rv, causal = causal, mask = mask)
fo = ring_flash_attn(fq, fk, fv, bucket_size = bucket_size, causal = causal, mask = mask)

assert torch.allclose(o, fo, atol = 1e-6)

if flash_cuda_kernel:
if cuda_kernel:
from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda

fco = ring_flash_attn_cuda(fcq, fck, fcv, None, causal)
if mask is not None:
mask = mask.cuda()

fco = ring_flash_attn_cuda(fcq, fck, fcv, mask, causal)
fco.sum().backward()

assert torch.allclose(o, fco.cpu(), atol = 1e-2)
@@ -71,7 +83,7 @@ def test(
assert torch.allclose(rk.grad, fk.grad, atol = 1e-6)
assert torch.allclose(rv.grad, fv.grad, atol = 1e-6)

if flash_cuda_kernel:
if cuda_kernel:
assert torch.allclose(rq.grad, fcq.grad.cpu(), atol = 1e-2)
assert torch.allclose(rk.grad, fck.grad.cpu(), atol = 1e-2)
assert torch.allclose(rv.grad, fcv.grad.cpu(), atol = 1e-2)
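For reference, the check that the new --rand-key-pad-mask flag exercises boils down to the comparison sketched below (condensed from the test above; the import paths for default_attention and ring_flash_attn are assumptions, as this diff does not show them). The test itself would be invoked as python assert_flash.py --rand-key-pad-mask, adding --cuda-kernel on a machine with a GPU.

import torch
# assumed import paths, mirroring the style of the ring_flash_attn_cuda import in the test
from ring_attention_pytorch.ring_attention import default_attention
from ring_attention_pytorch.ring_flash_attention import ring_flash_attn

# random queries / keys / values, plus a random boolean key padding mask (False = do not attend)
q, k, v = (torch.randn(2, 62, 2, 16) for _ in range(3))
mask = torch.randint(0, 2, (2, 62)).bool()

# non-causal attention with key padding should agree between the naive and flash paths
out_regular = default_attention(q, k, v, causal = False, mask = mask)
out_flash = ring_flash_attn(q, k, v, bucket_size = 4, causal = False, mask = mask)

assert torch.allclose(out_regular, out_flash, atol = 1e-6)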
188 changes: 47 additions & 141 deletions ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -44,53 +44,11 @@ def pad_at_dim(t, pad: Tuple[int, int], *, dim = -1, value = 0.):
def is_empty(t: Tensor):
return t.numel() == 0

def padded_false_on_right_side(t: Tensor):
if t.shape[-1] <= 1:
return True

false_to_true = ~t[..., :-1] & t[..., 1:]
return not false_to_true.any()

# make sure flash attention is installed for backwards
# make sure triton is installed for forwards

import importlib
from importlib.metadata import version

assert exists(importlib.util.find_spec('flash_attn')), 'flash-attn must be installed. `pip install flash-attn --no-build-isolation` first'

flash_attn_version = version('flash_attn')
assert pkg_version.parse(flash_attn_version) >= pkg_version.parse('2.5.1')

from flash_attn.flash_attn_interface import (
_flash_attn_varlen_backward
)

from flash_attn.bert_padding import (
pad_input,
unpad_input
)

@beartype
def unpad_inputs_and_return_inverse_fn(
tensors: Tuple[Tensor, ...],
mask: Tensor
):
assert len(tensors) > 0
batch, seqlen, *_ = first(tensors).shape

outs = []

for tensor in tensors:
out, indices, cu_seqlens, max_seqlen = unpad_input(tensor, mask)
outs.append(out)

def inverse_fn(y):
return pad_input(y, indices, batch, seqlen)

return tuple(outs), cu_seqlens, max_seqlen, inverse_fn

# make sure triton is installed for forwards

assert exists(importlib.util.find_spec('triton')), 'latest triton must be installed. `pip install triton -U` first'

triton_version = version('triton')
@@ -128,7 +86,6 @@ def forward(
ring_size: Optional[int]
):
assert all([t.is_cuda for t in (q, k, v)]), 'inputs must be all on cuda'
assert not exists(mask) or padded_false_on_right_side(mask), 'key padding mask must only contain True (attend) on the left hand side, and False (not attend) on the right'

dtype = q.dtype
softmax_scale = q.shape[-1] ** -0.5
@@ -142,7 +99,6 @@
if v.dtype == torch.float32:
v = v.half()

max_neg_value = -torch.finfo(dtype).max
ring_size = default(ring_size, get_world_size())

cross_attn = q.shape[-3] != k.shape[-3]
@@ -206,7 +162,7 @@ def forward(
bias = None

if exists(mask):
bias = torch.where(mask, 0., max_neg_value)
bias = torch.where(mask, 0., float('-inf'))

# for non-striped attention
# if the kv ring rank is equal to the current rank (block diagonal), then turn on causal
@@ -324,22 +280,6 @@ def backward(ctx, do):

delta = None

# if not causal and has key padding mask
# prepare row related tensors with unpad_input

if not causal and exists(mask):
lse = rearrange(lse, 'b h n ... -> b n h ...')

(
(q, o, do, lse),
cu_seqlens_q,
cu_maxlen_q,
repad_q
) = unpad_inputs_and_return_inverse_fn(
(q, o, do, lse),
mask
)

for (ring_rank, _), ((kv_and_dkv, mask), (receive_kv_and_dkv, receive_mask)) in ring_pass_fn(kv_and_dkv, mask, receive_buffers = (receive_kv_and_dkv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):

kv, dk, dv = kv_and_dkv
@@ -349,90 +289,56 @@ def backward(ctx, do):
k, v = kv.chunk(2, dim = -1)
k, v = k.view(k_dtype), v.view(v_dtype)

# translate key padding mask to bias

bias = None

if exists(mask):
bias = torch.where(mask, 0., float('-inf'))
bias = rearrange(bias, 'b j -> b 1 1 j')

# determine whether to do causal mask or not
# depends on whether it is striped attention, as well as current machine rank vs ring rank

if causal or not exists(mask):

if causal and striped_ring_attn:
need_accum = True
block_causal = True
causal_mask_diagonal = get_rank() < ring_rank
elif causal:
need_accum = get_rank() >= ring_rank
block_causal = get_rank() == ring_rank
causal_mask_diagonal = False
else:
need_accum = True
block_causal = False
causal_mask_diagonal = False

# use flash attention backwards kernel to calculate dq, dk, dv and accumulate

if need_accum:
ring_dq = torch.empty_like(q)
ring_dk = torch.empty_like(k)
ring_dv = torch.empty_like(v)

with torch.inference_mode():
delta = flash_attn_backward(
do,
q,
k,
v,
o,
lse,
ring_dq,
ring_dk,
ring_dv,
delta = delta,
causal = block_causal,
causal_mask_diagonal = causal_mask_diagonal,
softmax_scale = softmax_scale
)
else:
ring_dq, ring_dk, ring_dv = 0., 0., 0.
if causal and striped_ring_attn:
need_accum = True
block_causal = True
causal_mask_diagonal = get_rank() < ring_rank
elif causal:
need_accum = get_rank() >= ring_rank
block_causal = get_rank() == ring_rank
causal_mask_diagonal = False
else:

(
(k, v),
cu_seqlens_k,
cu_maxlen_k,
repad_kv
) = unpad_inputs_and_return_inverse_fn(
(k, v),
mask
)

if not is_empty(q) and not is_empty(k):
ring_dq, ring_dk, ring_dv, *_ = _flash_attn_varlen_backward(
dout = do,
q = q,
k = k,
v = v,
out = o,
softmax_lse = lse,
dq = torch.empty_like(q),
dk = torch.empty_like(k),
dv = torch.empty_like(v),
cu_seqlens_q = cu_seqlens_q,
cu_seqlens_k = cu_seqlens_k,
max_seqlen_q = cu_maxlen_q,
max_seqlen_k = cu_maxlen_k,
dropout_p = 0.,
softmax_scale = softmax_scale,
causal = False,
window_size = (-1, -1),
alibi_slopes = None,
deterministic = False
need_accum = True
block_causal = False
causal_mask_diagonal = False

# use flash attention backwards kernel to calculate dq, dk, dv and accumulate

if need_accum:
ring_dq = torch.empty_like(q)
ring_dk = torch.empty_like(k)
ring_dv = torch.empty_like(v)

with torch.inference_mode():
delta = flash_attn_backward(
do,
q,
k,
v,
o,
lse,
ring_dq,
ring_dk,
ring_dv,
delta = delta,
bias = bias,
causal = block_causal,
causal_mask_diagonal = causal_mask_diagonal,
softmax_scale = softmax_scale
)

ring_dq = repad_q(ring_dq)
ring_dk = repad_kv(ring_dk)
ring_dv = repad_kv(ring_dv)

else:
ring_dq, ring_dk, ring_dv = 0., 0., 0.
else:
ring_dq, ring_dk, ring_dv = 0., 0., 0.

dq.add_(ring_dq)
dk.add_(ring_dk)
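The mechanism that makes non-causal key padding work here is visible in both passes above: the boolean key padding mask is translated into an additive bias (0 where a key may be attended to, -inf where it may not) and handed to the triton kernel, replacing the earlier route through flash-attn's varlen unpadding in the backward. A small standalone illustration of that translation, with made-up shapes:

import torch
from einops import rearrange

# (batch, key length) key padding mask, False = do not attend
mask = torch.tensor([[True, True, False, False]])

# masked positions receive -inf so they contribute nothing after the softmax
bias = torch.where(mask, 0., float('-inf'))
bias = rearrange(bias, 'b j -> b 1 1 j')   # broadcast over heads and query positions

sim = torch.randn(1, 2, 3, 4)              # (batch, heads, queries, keys) attention logits
attn = (sim + bias).softmax(dim = -1)

# the two masked key columns get exactly zero attention weight
assert attn[..., 2:].abs().sum() == 0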
5 changes: 4 additions & 1 deletion ring_attention_pytorch/triton_flash_attn.py
@@ -1,6 +1,7 @@
# taken from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
# with fixes for triton 2.3 and preparing for modifications to backwards
# with fixes for triton 2.3
# forward is modified to return unnormalized accumulation, row maxes, row lse - reduced over passed rings
# both forwards and backwards are modified to allow for masking out the diagonal for striped ring attention

import math

@@ -9,6 +10,8 @@
import triton
import triton.language as tl

from einops import repeat

def exists(v):
return v is not None

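The header comment above describes the contract the rest of this commit now leans on end to end: the triton forward hands back an unnormalized accumulation plus per-row maxes and log-sum-exps, so partial results from successive ring passes can be reduced into the final attention output. A plain-torch sketch of that reduction for a single query row (illustration only, not the kernel's code):

import torch

# one row of attention logits and its values, split across two ring passes
s1, s2 = torch.randn(5), torch.randn(5)
v1, v2 = torch.randn(5, 3), torch.randn(5, 3)

def partial(s, v):
    # per-pass unnormalized accumulation, row max and row log-sum-exp
    m = s.max()
    acc = ((s - m).exp()[:, None] * v).sum(dim = 0)
    lse = m + (s - m).exp().sum().log()
    return acc, m, lse

acc1, m1, lse1 = partial(s1, v1)
acc2, m2, lse2 = partial(s2, v2)

# reduce the two passes, then normalize once at the end
m = torch.maximum(m1, m2)
acc = acc1 * (m1 - m).exp() + acc2 * (m2 - m).exp()
lse = m + ((lse1 - m).exp() + (lse2 - m).exp()).log()
out = acc * (m - lse).exp()

# agrees with softmax attention over the full, concatenated sequence
full = (torch.cat((s1, s2)).softmax(dim = -1)[:, None] * torch.cat((v1, v2))).sum(dim = 0)
assert torch.allclose(out, full, atol = 1e-5)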
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.3.16',
version = '0.3.17',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
