updated to follow algorithm 3 in the tree attn decoding paper, which is more concise
lucidrains committed Aug 14, 2024
1 parent d49499f commit 103379f
Showing 2 changed files with 11 additions and 20 deletions.
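Editor's note (not part of the commit): both the old and the new code combine per-rank partial attention the same way. Each rank r holds a locally normalized output O_r and the log-sum-exp l_r of its attention scores, and the final result is

    out = sum_r O_r * exp(l_r - c) / sum_r exp(l_r - c)

for any offset c shared by all ranks; the offset cancels after the sum all-reduces and only serves to keep the exponentials in floating-point range. The old code built c from a separately all-reduced global score max and then rescaled an already-materialized numerator and denominator, while the updated code takes c = max_r l_r directly from the all-reduced lse, which is what makes the rewrite shorter.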
29 changes: 10 additions & 19 deletions ring_attention_pytorch/tree_attn_decoding.py
@@ -51,15 +51,15 @@ def tree_attn_decode(
k, v = (k[rank], v[rank]) if rank < len(k) else (None, None)

if exists(k):
# calculate local output and derive numerator and denominator
# calculate local output and lse

use_triton = default(use_triton, q.is_cuda)
assert not (use_triton and not q.is_cuda), 'input needs to be on cuda if forcing the use of triton'

if use_triton and q.is_cuda:
from ring_attention_pytorch.triton_flash_attn import flash_attn_forward

local_out, local_max, lse = flash_attn_forward(
local_out, _, lse = flash_attn_forward(
q, k, v,
causal = False,
return_normalized_output = True,
@@ -72,34 +72,25 @@ def tree_attn_decode(
scale = q.shape[-1] ** -0.5
sim = einsum('... i d, ... j d -> ... i j', q, k) * scale

local_max = sim.amax(dim = -1, keepdim = True)
sim -= local_max
lse = sim.logsumexp(dim = -1, keepdim = True)

attn = sim.softmax(dim = -1)
local_out = einsum('... i j, ... j d -> ... i d', attn, v)

den = lse.exp()
num = local_out.float() * den

else:
# handle edge case where seq length < world size

num = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
den = q.new_zeros((*q.shape[:-1], 1), dtype = torch.float32)
local_max = torch.zeros_like(den)

# first get global max through an all reduce (max)
local_out = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
lse = torch.full_like(den, -torch.finfo(torch.float32).max)

global_max = local_max.clone()
dist.all_reduce(global_max, dist.ReduceOp.MAX)
# first get max(lse) through an all reduce

# renormalize the numerator and denominators
max_lse = lse.clone()
dist.all_reduce(max_lse, dist.ReduceOp.MAX)

renorm_factor = (local_max - global_max).exp()
# derive numerator and denominator

den *= renorm_factor
num *= renorm_factor
den = (lse - max_lse).exp()
num = local_out * den

# second and third all reduce (sum)

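To make the updated reduction concrete, here is a minimal single-process sketch of the same math, with the key/value sequence chunked into local "ranks" and plain tensor reductions standing in for the three all-reduces. It is an editorial illustration, not library code: the function name combine_shards and the shapes (batch, heads, 1, dim) are assumptions, and torch.nn.functional.scaled_dot_product_attention is only used to check the result.

import torch
from torch import einsum

# editorial sketch; combine_shards is a hypothetical helper, not part of ring_attention_pytorch
def combine_shards(q, k_shards, v_shards):
    # q: (batch, heads, 1, dim) decoding query
    # k_shards / v_shards: per-"rank" chunks of the keys and values along the sequence
    scale = q.shape[-1] ** -0.5
    outs, lses = [], []

    for k, v in zip(k_shards, v_shards):
        sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
        lse = sim.logsumexp(dim = -1, keepdim = True)   # local log-sum-exp, as in the new code path
        attn = sim.softmax(dim = -1)
        outs.append(einsum('... i j, ... j d -> ... i d', attn, v))
        lses.append(lse)

    out = torch.stack(outs)                             # (ranks, batch, heads, 1, dim)
    lse = torch.stack(lses)                             # (ranks, batch, heads, 1, 1)

    max_lse = lse.amax(dim = 0)                         # stands in for the max all-reduce
    den = (lse - max_lse).exp()                         # per-rank denominator
    num = out * den                                     # per-rank numerator

    return num.sum(dim = 0) / den.sum(dim = 0)          # stands in for the two sum all-reduces

# sanity check against full attention over the unsharded sequence
q = torch.randn(1, 8, 1, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
actual = combine_shards(q, k.chunk(4, dim = -2), v.chunk(4, dim = -2))
assert torch.allclose(expected, actual, atol = 1e-5)

Because every per-rank denominator shares the same max_lse offset, dividing the summed numerator by the summed denominator cancels that offset, so the sketch reproduces ordinary softmax attention up to float error.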
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.5.10',
version = '0.5.12',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
