diff --git a/ring_attention_pytorch/tree_attn_decoding.py b/ring_attention_pytorch/tree_attn_decoding.py
index 64ce136..82ae34c 100644
--- a/ring_attention_pytorch/tree_attn_decoding.py
+++ b/ring_attention_pytorch/tree_attn_decoding.py
@@ -51,7 +51,7 @@ def tree_attn_decode(
     k, v = (k[rank], v[rank]) if rank < len(k) else (None, None)
 
     if exists(k):
-        # calculate local output and derive numerator and denominator
+        # calculate local output and lse
 
         use_triton = default(use_triton, q.is_cuda)
         assert not (use_triton and not q.is_cuda), 'input needs to be on cuda if forcing the use of triton'
@@ -59,7 +59,7 @@ def tree_attn_decode(
         if use_triton and q.is_cuda:
             from ring_attention_pytorch.triton_flash_attn import flash_attn_forward
 
-            local_out, local_max, lse = flash_attn_forward(
+            local_out, _, lse = flash_attn_forward(
                 q, k, v,
                 causal = False,
                 return_normalized_output = True,
@@ -72,34 +72,25 @@ def tree_attn_decode(
             scale = q.shape[-1] ** -0.5
             sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
 
-            local_max = sim.amax(dim = -1, keepdim = True)
-            sim -= local_max
             lse = sim.logsumexp(dim = -1, keepdim = True)
-
             attn = sim.softmax(dim = -1)
 
             local_out = einsum('... i j, ... j d -> ... i d', attn, v)
 
-        den = lse.exp()
-        num = local_out.float() * den
-
     else:
         # handle edge case where seq length < world size
 
-        num = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
-        den = q.new_zeros((*q.shape[:-1], 1), dtype = torch.float32)
-        local_max = torch.zeros_like(den)
-
-    # first get global max through an all reduce (max)
+        local_out = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
+        lse = torch.full_like(local_out[..., :1], -torch.finfo(torch.float32).max)
 
-    global_max = local_max.clone()
-    dist.all_reduce(global_max, dist.ReduceOp.MAX)
+    # first get max(lse) through an all reduce
 
-    # renormalize the numerator and denominators
+    max_lse = lse.clone()
+    dist.all_reduce(max_lse, dist.ReduceOp.MAX)
 
-    renorm_factor = (local_max - global_max).exp()
+    # derive numerator and denominator
 
-    den *= renorm_factor
-    num *= renorm_factor
+    den = (lse - max_lse).exp()
+    num = local_out * den
 
     # second and third all reduce (sum)
diff --git a/setup.py b/setup.py
index 4b5296c..14314b5 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.10',
+  version = '0.5.12',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
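
The rewritten decode path combines per-rank results correctly because each rank's locally normalized output o_r and its log-sum-exp lse_r satisfy o_r * exp(lse_r - max_lse) = sum_j exp(s_j - max_lse) * v_j over that rank's keys, so summing the rescaled numerators and denominators across ranks and dividing reproduces softmax attention over the full sequence. The following is a minimal single-process sketch of that combine, not taken from the repository: per-rank shards are emulated with chunks of the key / value sequence, the all_reduce(MAX) with a max over the chunk dimension, and the two all_reduce(SUM) calls with sums. Tensor shapes and the chunk count are arbitrary choices for illustration.

# illustrative sketch of the lse-based combine, checked against full attention

import torch
from torch import einsum

q = torch.randn(1, 8, 1, 64)        # (batch, heads, one decoding query, dim)
k = torch.randn(1, 8, 256, 64)
v = torch.randn(1, 8, 256, 64)

scale = q.shape[-1] ** -0.5

# reference: softmax attention over the full key / value sequence

sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
ref = einsum('... i j, ... j d -> ... i d', sim.softmax(dim = -1), v)

# each "rank" computes a local output and local lse over its shard

local_outs, lses = [], []

for ck, cv in zip(k.chunk(4, dim = -2), v.chunk(4, dim = -2)):
    local_sim = einsum('... i d, ... j d -> ... i j', q, ck) * scale
    lses.append(local_sim.logsumexp(dim = -1, keepdim = True))
    local_outs.append(einsum('... i j, ... j d -> ... i d', local_sim.softmax(dim = -1), cv))

local_out, lse = torch.stack(local_outs), torch.stack(lses)

# emulate all_reduce(MAX), then renormalize and emulate the two all_reduce(SUM)

max_lse = lse.amax(dim = 0, keepdim = True)
den = (lse - max_lse).exp()
num = local_out * den

out = num.sum(dim = 0) / den.sum(dim = 0)

assert torch.allclose(out, ref, atol = 1e-5)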