diff --git a/ring_attention_pytorch/ring.py b/ring_attention_pytorch/ring.py
index cea719b..ffdc007 100644
--- a/ring_attention_pytorch/ring.py
+++ b/ring_attention_pytorch/ring.py
@@ -94,11 +94,12 @@ def null_ring_pass(*tensors, max_iters = None):
 def all_ring_pass(*tensors, max_iters = None):
     world_size = get_world_size()
     max_iters = default(max_iters, world_size)
 
-    total_iters = min(world_size, max_iters)
-    curr_ring_pos = get_rank()
+    # make sure iteration is between 1 and world size
+
+    total_iters = max(1, min(world_size, max_iters))
 
-    assert total_iters > 0
+    curr_ring_pos = get_rank()
 
     for ind in range(total_iters):
         is_last = ind == (total_iters - 1)
diff --git a/ring_attention_pytorch/ring_flash_attention.py b/ring_attention_pytorch/ring_flash_attention.py
index bec219c..81bdb3a 100644
--- a/ring_attention_pytorch/ring_flash_attention.py
+++ b/ring_attention_pytorch/ring_flash_attention.py
@@ -16,6 +16,11 @@
     get_rank
 )
 
+from ring_attention_pytorch.rotary import (
+    RotaryEmbedding,
+    apply_rotary_pos_emb
+)
+
 # constants
 
 EPSILON = 1e-10
diff --git a/ring_attention_pytorch/rotary.py b/ring_attention_pytorch/rotary.py
new file mode 100644
index 0000000..0b5bc21
--- /dev/null
+++ b/ring_attention_pytorch/rotary.py
@@ -0,0 +1,32 @@
+import torch
+from torch import nn, einsum
+from torch.nn import Module
+from torch.cuda.amp import autocast
+
+class RotaryEmbedding(Module):
+    def __init__(
+        self,
+        dim,
+        theta = 10000
+    ):
+        super().__init__()
+        inv_freq = theta ** -(torch.arange(0, dim, 2).float() / dim)
+        self.register_buffer('inv_freq', inv_freq)
+
+    @autocast(enabled = False)
+    def forward(
+        self,
+        seq_len,
+        offset = 0
+    ):
+        t = torch.arange(seq_len + offset, device = self.inv_freq.device).type_as(self.inv_freq)
+        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        return torch.cat((freqs, freqs), dim = -1)
+
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim = -1)
+    return torch.cat((-x2, x1), dim=-1)
+
+@autocast(enabled = False)
+def apply_rotary_pos_emb(pos, t):
+    return t * pos.cos() + rotate_half(t) * pos.sin()
diff --git a/setup.py b/setup.py
index b8464dd..b468254 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.8',
+  version = '0.1.9',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',