diff --git a/ring_attention_pytorch/tensor_typing.py b/ring_attention_pytorch/tensor_typing.py
new file mode 100644
index 0000000..5985007
--- /dev/null
+++ b/ring_attention_pytorch/tensor_typing.py
@@ -0,0 +1,26 @@
+from torch import Tensor
+
+from jaxtyping import (
+    Float,
+    Int,
+    Bool
+)
+
+# jaxtyping is a misnomer, works for pytorch
+
+class TorchTyping:
+    def __init__(self, abstract_dtype):
+        self.abstract_dtype = abstract_dtype
+
+    def __getitem__(self, shapes: str):
+        return self.abstract_dtype[Tensor, shapes]
+
+Float = TorchTyping(Float)
+Int = TorchTyping(Int)
+Bool = TorchTyping(Bool)
+
+__all__ = [
+    Float,
+    Int,
+    Bool
+]
diff --git a/ring_attention_pytorch/tree_attn_decoding.py b/ring_attention_pytorch/tree_attn_decoding.py
index 82ae34c..5c46ac6 100644
--- a/ring_attention_pytorch/tree_attn_decoding.py
+++ b/ring_attention_pytorch/tree_attn_decoding.py
@@ -6,6 +6,8 @@
 
 from ring_attention_pytorch.distributed import get_rank, get_world_size
 
+from ring_attention_pytorch.tensor_typing import Float
+
 # functions
 
 def exists(v):
@@ -18,21 +20,17 @@ def default(v, d):
 
 @torch.no_grad()
 def tree_attn_decode(
-    q: Tensor,
-    k: Tensor | None = None,
-    v: Tensor | None = None,
+    q: Float['b h 1 d'],
+    k: Float['b h n d'] | None = None,
+    v: Float['b h n dv'] | None = None,
     eps = 1e-8,
-    shard_kv_seq = False,
+    shard_kv_seq = True,
     use_triton = None
-):
-    dtype = q.dtype
+) -> Float['b h 1 dv']:
 
-    assert not (exists(k) ^ exists(v)), 'keys and values are either both None, or both present'
+    q_prec_dims, dtype = q.shape[:-1], q.dtype
 
-    if exists(k):
-        assert k.shape[:-1] == v.shape[:-1]
-        assert q.shape[-2:] == (1, k.shape[-1])
-        assert q.shape[:-2] == k.shape[:-2]
+    assert not (exists(k) ^ exists(v)), 'keys and values are either both None, or both present'
 
     """
     Algorithm 3 proposed in Tree Attention
@@ -43,6 +41,7 @@ def tree_attn_decode(
 
     if shard_kv_seq:
         assert exists(k), 'keys and values must be passed if not already sharded across sequence'
+        dim_v = v.shape[-1]
 
         rank, world_size = get_rank(), get_world_size()
         k = k.chunk(world_size, dim = -2)
@@ -68,6 +67,7 @@ def tree_attn_decode(
                 remove_padding = True
             )
 
+            lse = rearrange(lse, '... -> ... 1')
         else:
             scale = q.shape[-1] ** -0.5
             sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
@@ -79,8 +79,8 @@ def tree_attn_decode(
     else:
         # handle edge case where seq length < world size
 
-        local_out = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
-        lse = torch.full_like(den, -torch.finfo(torch.float32).max)
+        local_out = q.new_zeros((*q_prec_dims, dim_v), dtype = torch.float32)
+        lse = torch.full((*q_prec_dims, 1), -torch.finfo(torch.float32).max, device = q.device, dtype = torch.float32)
 
     # first get max(lse) through an all reduce
 
diff --git a/ring_attention_pytorch/triton_flash_attn.py b/ring_attention_pytorch/triton_flash_attn.py
index 2679de3..74bf901 100644
--- a/ring_attention_pytorch/triton_flash_attn.py
+++ b/ring_attention_pytorch/triton_flash_attn.py
@@ -322,10 +322,10 @@ def flash_attn_forward(
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
 
     if head_first_dim:
-        q, k, v = tuple(rearrange(t, 'b n h d -> b h n d') for t in (q, k, v))
+        q, k, v = tuple(rearrange(t, 'b h n d -> b n h d') for t in (q, k, v))
 
         if exists(o):
-            o = rearrange(o, 'b n h d -> b h n d')
+            o = rearrange(o, 'b h n d -> b n h d')
 
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
@@ -421,7 +421,7 @@ def flash_attn_forward(
     )
 
     if head_first_dim:
-        o = rearrange(o, 'b h n d -> b n h d')
+        o = rearrange(o, 'b n h d -> b h n d')
 
     if remove_padding:
         m = m[..., :seqlen_q]
diff --git a/setup.py b/setup.py
index 14314b5..7a65acc 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.12',
+  version = '0.5.16',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
@@ -18,7 +18,8 @@
   install_requires=[
     'beartype',
     'einops>=0.8.0',
-    'torch>=2.0'
+    'jaxtyping',
+    'torch>=2.0',
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
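
Note on the new tensor_typing module: TorchTyping simply curries jaxtyping's annotation classes with torch.Tensor, so Float['b h 1 d'] reads as jaxtyping.Float[Tensor, 'b h 1 d']. Below is a minimal sketch of how the annotations added to tree_attn_decode are meant to be read; toy_decode and the concrete shapes are illustrative only, and the shape strings are not enforced at runtime unless a jaxtyping-aware checker such as beartype is hooked in.

# Illustrative usage of the TorchTyping wrapper added above; toy_decode is a
# hypothetical helper that mirrors the tree_attn_decode signature.
import torch
from ring_attention_pytorch.tensor_typing import Float

def toy_decode(
    q: Float['b h 1 d'],    # equivalent to jaxtyping.Float[torch.Tensor, 'b h 1 d']
    k: Float['b h n d'],
    v: Float['b h n dv']
) -> Float['b h 1 dv']:
    # plain single-query attention, no sequence sharding
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-1, -2) * scale).softmax(dim = -1)
    return attn @ v

q = torch.randn(2, 8, 1, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

out = toy_decode(q, k, v)   # shape (2, 8, 1, 64)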