fix issue with seq len < world size in tree attn decoding after new update, also start using tensor typing
lucidrains committed Aug 15, 2024
1 parent f3ed323 commit d91a37f
Showing 4 changed files with 45 additions and 18 deletions.
26 changes: 26 additions & 0 deletions ring_attention_pytorch/tensor_typing.py
@@ -0,0 +1,26 @@
+from torch import Tensor
+
+from jaxtyping import (
+    Float,
+    Int,
+    Bool
+)
+
+# jaxtyping is a misnomer, works for pytorch
+
+class TorchTyping:
+    def __init__(self, abstract_dtype):
+        self.abstract_dtype = abstract_dtype
+
+    def __getitem__(self, shapes: str):
+        return self.abstract_dtype[Tensor, shapes]
+
+Float = TorchTyping(Float)
+Int = TorchTyping(Int)
+Bool = TorchTyping(Bool)
+
+__all__ = [
+    Float,
+    Int,
+    Bool
+]
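As a reading aid (not part of the commit): the shim above simply forwards its subscript to jaxtyping, so Float['b h 1 d'] is shorthand for jaxtyping.Float[Tensor, 'b h 1 d'], which is how the annotations in the next file should be read. A minimal sketch, with scaled as a hypothetical function:

import torch
from ring_attention_pytorch.tensor_typing import Float

# Float['b h n dv'] resolves to jaxtyping.Float[torch.Tensor, 'b h n dv'],
# i.e. "a float tensor with named axes b, h, n, dv". The annotation is
# documentation only unless wrapped with a runtime checker.
def scaled(v: Float['b h n dv'], gamma: float) -> Float['b h n dv']:
    return v * gamma

out = scaled(torch.randn(2, 8, 16, 64), 0.5)    # b=2, h=8, n=16, dv=64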
26 changes: 13 additions & 13 deletions ring_attention_pytorch/tree_attn_decoding.py
@@ -6,6 +6,8 @@
 
 from ring_attention_pytorch.distributed import get_rank, get_world_size
 
+from ring_attention_pytorch.tensor_typing import Float
+
 # functions
 
 def exists(v):
@@ -18,21 +20,17 @@ def default(v, d):
 
 @torch.no_grad()
 def tree_attn_decode(
-    q: Tensor,
-    k: Tensor | None = None,
-    v: Tensor | None = None,
+    q: Float['b h 1 d'],
+    k: Float['b h n d'] | None = None,
+    v: Float['b h n dv'] | None = None,
     eps = 1e-8,
-    shard_kv_seq = False,
+    shard_kv_seq = True,
     use_triton = None
-):
-    dtype = q.dtype
+) -> Float['b h 1 dv']:
 
-    assert not (exists(k) ^ exists(v)), 'keys and values are either both None, or both present'
+    q_prec_dims, dtype = q.shape[:-1], q.dtype
 
-    if exists(k):
-        assert k.shape[:-1] == v.shape[:-1]
-        assert q.shape[-2:] == (1, k.shape[-1])
-        assert q.shape[:-2] == k.shape[:-2]
+    assert not (exists(k) ^ exists(v)), 'keys and values are either both None, or both present'
 
     """
     Algorithm 3 proposed in Tree Attention
@@ -43,6 +41,7 @@ def tree_attn_decode(
 
     if shard_kv_seq:
         assert exists(k), 'keys and values must be passed if not already sharded across sequence'
+        dim_v = v.shape[-1]
 
         rank, world_size = get_rank(), get_world_size()
         k = k.chunk(world_size, dim = -2)
@@ -68,6 +67,7 @@
                 remove_padding = True
             )
 
+            lse = rearrange(lse, '... -> ... 1')
         else:
             scale = q.shape[-1] ** -0.5
             sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
@@ -79,8 +79,8 @@
     else:
         # handle edge case where seq length < world size
 
-        local_out = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
-        lse = torch.full_like(den, -torch.finfo(torch.float32).max)
+        local_out = q.new_zeros((*q_prec_dims, dim_v), dtype = torch.float32)
+        lse = torch.full((*q_prec_dims, 1), -torch.finfo(torch.float32).max, device = q.device, dtype = torch.float32)
 
     # first get max(lse) through an all reduce
 
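To make the fix above concrete (my own single-process sketch, not code from this commit; combine_partial_attention is a hypothetical name): when the kv sequence is chunked across world_size ranks and the sequence is shorter than the world size, some ranks receive no chunk at all. Those ranks now contribute a zero numerator and an lse of -max_float, which act as neutral elements in the later max/sum all-reduces, so the combined result is unchanged:

import torch

def combine_partial_attention(local_outs, lses, eps = 1e-8):
    # local_outs: per-rank locally normalized outputs, each (b, h, 1, dv)
    # lses:       per-rank log-sum-exp of attention logits, each (b, h, 1, 1)
    lses = torch.stack(lses)                              # (ranks, b, h, 1, 1)
    global_max = lses.amax(dim = 0)                       # stand-in for all_reduce(MAX)
    weights = (lses - global_max).exp()                   # an "empty" rank gets weight ~0
    num = sum(w * o for w, o in zip(weights, local_outs)) # stand-in for all_reduce(SUM)
    den = weights.sum(dim = 0)                            # stand-in for all_reduce(SUM)
    return num / den.clamp(min = eps)

# a rank that held no kv chunk contributes zeros and -max_float, and drops out
b, h, dv = 1, 2, 4
out_full, lse_full = torch.randn(b, h, 1, dv), torch.randn(b, h, 1, 1)
out_empty = torch.zeros(b, h, 1, dv)
lse_empty = torch.full((b, h, 1, 1), -torch.finfo(torch.float32).max)
merged = combine_partial_attention([out_full, out_empty], [lse_full, lse_empty])
assert torch.allclose(merged, out_full, atol = 1e-6)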
6 changes: 3 additions & 3 deletions ring_attention_pytorch/triton_flash_attn.py
@@ -322,10 +322,10 @@ def flash_attn_forward(
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
 
     if head_first_dim:
-        q, k, v = tuple(rearrange(t, 'b n h d -> b h n d') for t in (q, k, v))
+        q, k, v = tuple(rearrange(t, 'b h n d -> b n h d') for t in (q, k, v))
 
         if exists(o):
-            o = rearrange(o, 'b n h d -> b h n d')
+            o = rearrange(o, 'b h n d -> b n h d')
 
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
@@ -421,7 +421,7 @@ def flash_attn_forward(
     )
 
     if head_first_dim:
-        o = rearrange(o, 'b h n d -> b n h d')
+        o = rearrange(o, 'b n h d -> b h n d')
 
     if remove_padding:
         m = m[..., :seqlen_q]
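For reference, a standalone sketch of the layout convention the corrected rearranges implement (illustrative tensors only, not repo code): with head_first_dim = True the caller passes (batch, heads, seq, dim) tensors, which are moved to the (batch, seq, heads, dim) layout the flash attention kernel works in, then moved back on the way out; the previous patterns had the two directions swapped.

import torch
from einops import rearrange

b, h, n, d = 2, 8, 128, 64
q_head_first = torch.randn(b, h, n, d)                     # caller layout: (batch, heads, seq, dim)

q_kernel = rearrange(q_head_first, 'b h n d -> b n h d')   # kernel layout: (batch, seq, heads, dim)
o_kernel = q_kernel                                        # stand-in for the attention output

o_head_first = rearrange(o_kernel, 'b n h d -> b h n d')   # restore the caller's layout
assert o_head_first.shape == (b, h, n, d)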
5 changes: 3 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.12',
+  version = '0.5.16',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
@@ -18,7 +18,8 @@
   install_requires=[
     'beartype',
     'einops>=0.8.0',
-    'torch>=2.0'
+    'jaxtyping',
+    'torch>=2.0',
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
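On the new dependency (a sketch of my own, assuming a jaxtyping version that provides the jaxtyped(typechecker=...) decorator; read_out is a hypothetical function): jaxtyping only supplies the annotation syntax, while the already listed beartype can enforce it at runtime when the two are combined.

import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped

@jaxtyped(typechecker = beartype)
def read_out(q: Float[torch.Tensor, 'b h 1 d']) -> Float[torch.Tensor, 'b h 1 d']:
    return q * 2

read_out(torch.randn(2, 4, 1, 64))      # ok: query length axis is 1
# read_out(torch.randn(2, 4, 3, 64))    # would raise: third axis must be exactly 1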
