
Commit

prepare rotary embeddings
lucidrains committed Feb 20, 2024
1 parent 65a47a1 commit ebc5e58
Showing 4 changed files with 42 additions and 4 deletions.
7 changes: 4 additions & 3 deletions ring_attention_pytorch/ring.py
@@ -94,11 +94,12 @@ def null_ring_pass(*tensors, max_iters = None):
 def all_ring_pass(*tensors, max_iters = None):
     world_size = get_world_size()
     max_iters = default(max_iters, world_size)
-    total_iters = min(world_size, max_iters)
-
-    curr_ring_pos = get_rank()
+    # make sure iteration is between 1 and world size
+    total_iters = max(1, min(world_size, max_iters))
+    assert total_iters > 0
+    curr_ring_pos = get_rank()
 
     for ind in range(total_iters):
         is_last = ind == (total_iters - 1)
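
The new bound clamps total_iters into the range [1, world_size], so a zero or negative max_iters no longer yields an empty loop, and the assert records that invariant. A minimal sketch of the clamping behavior, where clamp_iters and the hard-coded world_size are stand-ins made up purely for illustration:

world_size = 4                     # stand-in value, just for this check

def clamp_iters(max_iters):
    # mirrors the new bound: keep the iteration count between 1 and world size
    return max(1, min(world_size, max_iters))

assert clamp_iters(0) == 1         # too small: clamped up, at least one ring pass runs
assert clamp_iters(3) == 3         # in range: used as-is
assert clamp_iters(100) == 4       # too large: capped at world_size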
5 changes: 5 additions & 0 deletions ring_attention_pytorch/ring_flash_attention.py
@@ -16,6 +16,11 @@
     get_rank
 )
 
+from ring_attention_pytorch.rotary import (
+    RotaryEmbedding,
+    apply_rotary_pos_emb
+)
+
 # constants
 
 EPSILON = 1e-10
32 changes: 32 additions & 0 deletions ring_attention_pytorch/rotary.py
@@ -0,0 +1,32 @@
import torch
from torch import nn, einsum
from torch.nn import Module
from torch.cuda.amp import autocast

class RotaryEmbedding(Module):
    def __init__(
        self,
        dim,
        theta = 10000
    ):
        super().__init__()
        inv_freq = theta ** -(torch.arange(0, dim, 2).float() / dim)
        self.register_buffer('inv_freq', inv_freq)

    @autocast(enabled = False)
    def forward(
        self,
        seq_len,
        offset = 0
    ):
        t = torch.arange(seq_len + offset, device = self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        return torch.cat((freqs, freqs), dim = -1)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim=-1)

@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()
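
For context on how the new module is meant to be used, here is a minimal usage sketch pairing RotaryEmbedding with apply_rotary_pos_emb on query/key tensors. The shapes are made up for illustration; the real call sites are wired up through the import added to ring_flash_attention.py above.

import torch
from ring_attention_pytorch.rotary import RotaryEmbedding, apply_rotary_pos_emb

rotary = RotaryEmbedding(dim = 64)   # dim should match the (even) attention head dimension

q = torch.randn(2, 8, 1024, 64)      # (batch, heads, seq, dim_head) - made-up shapes
k = torch.randn(2, 8, 1024, 64)

pos = rotary(q.shape[-2])            # (1024, 64) table of per-position rotation angles
q = apply_rotary_pos_emb(pos, q)     # rotates each (x1, x2) channel pair by its position's angle
k = apply_rotary_pos_emb(pos, k)

Note that forward(seq_len, offset) returns seq_len + offset positions starting from 0 rather than a shifted window, so a caller on a given ring rank would presumably slice out the rows for its own sequence chunk.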
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.8',
+  version = '0.1.9',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
