add ring_flash_attn_cuda to __init__.py
lucidrains committed Apr 11, 2024
1 parent 24e1673 commit ea80588
Showing 3 changed files with 24 additions and 18 deletions.
ring_attention_pytorch/__init__.py: 5 additions & 0 deletions
@@ -10,3 +10,8 @@
     ring_flash_attn,
     ring_flash_attn_
 )
+
+from ring_attention_pytorch.ring_flash_attention_cuda import (
+    ring_flash_attn_cuda,
+    ring_flash_attn_cuda_
+)
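
A minimal sketch of what this re-export enables: the cuda variants can now be imported straight from the package root. No gpu work happens here, and actually calling them still requires cuda tensors and triton 2.1+ per the asserts further down; the prints are only illustrative.

    # the new package-level export; resolving these names does not import triton,
    # because the cuda module defers its triton imports (see the next file)

    from ring_attention_pytorch import ring_flash_attn_cuda, ring_flash_attn_cuda_

    print(ring_flash_attn_cuda)
    print(ring_flash_attn_cuda_)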
ring_attention_pytorch/ring_flash_attention_cuda.py: 4 additions & 16 deletions
@@ -44,21 +44,6 @@ def pad_at_dim(t, pad: Tuple[int, int], *, dim = -1, value = 0.):
 def is_empty(t: Tensor):
     return t.numel() == 0
 
-# make sure triton is installed
-
-import importlib
-from importlib.metadata import version
-
-assert exists(importlib.util.find_spec('triton')), 'latest triton must be installed. `pip install triton -U` first'
-
-triton_version = version('triton')
-assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1'), 'triton must be version 2.1 or above. `pip install triton -U` to upgrade'
-
-from ring_attention_pytorch.triton_flash_attn import (
-    flash_attn_backward,
-    flash_attn_forward
-)
-
 # ring + (flash) attention forwards and backwards
 
 # flash attention v1 - https://arxiv.org/abs/2205.14135
@@ -82,6 +67,8 @@ def forward(
         max_lookback_seq_len: Optional[int],
         ring_size: Optional[int]
     ):
+        from ring_attention_pytorch.triton_flash_attn import flash_attn_forward
+
         assert all([t.is_cuda for t in (q, k, v)]), 'inputs must be all on cuda'
 
         dtype = q.dtype
@@ -220,7 +207,8 @@ def forward(
     @staticmethod
     @torch.no_grad()
     def backward(ctx, do):
         """ Algorithm 2 in the v2 paper """
-
+        from ring_attention_pytorch.triton_flash_attn import flash_attn_backward
+
         (
             causal,
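
The two hunks above replace a module-level import with imports inside forward and backward. A small, generic sketch of that deferred-import pattern (the function and the json stand-in are illustrative, not part of the library): the optional dependency is only pulled in when the code path actually runs, so importing the module stays cheap and never fails on machines without it.

    def encode(record):
        # the import runs on first call, not at module import time, mirroring how
        # flash_attn_forward / flash_attn_backward are now imported above
        import json  # stand-in for a heavy or optional dependency such as triton
        return json.dumps(record)

    print(encode({'ring': 'attention'}))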
ring_attention_pytorch/triton_flash_attn.py: 15 additions & 2 deletions
@@ -7,8 +7,6 @@
 
 import torch
 from torch import Tensor
-import triton
-import triton.language as tl
 
 from einops import repeat
 
@@ -21,6 +19,21 @@ def default(val, d):
 def is_contiguous(x: Tensor):
     return x.stride(-1) == 1
 
+# make sure triton 2.1+ is installed
+
+import importlib
+from importlib.metadata import version
+
+assert exists(importlib.util.find_spec('triton')), 'latest triton must be installed. `pip install triton -U` first'
+
+triton_version = version('triton')
+assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1'), 'triton must be version 2.1 or above. `pip install triton -U` to upgrade'
+
+import triton
+import triton.language as tl
+
+# kernels
+
 @triton.heuristics(
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
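
The version gate moved into this file can be exercised on its own. A standalone sketch, assuming `pkg_version` refers to `packaging.version` (that import is not shown in this diff) and that the `packaging` package is installed:

    import importlib.util
    from importlib.metadata import version
    from packaging import version as pkg_version  # assumed source of pkg_version

    if importlib.util.find_spec('triton') is None:
        print('triton is not installed; the triton-backed kernels are unavailable')
    else:
        # same check as above: require triton 2.1 or newer before defining kernels
        assert pkg_version.parse(version('triton')) >= pkg_version.parse('2.1'), \
            'triton must be version 2.1 or above. `pip install triton -U` to upgrade'
        print('triton', version('triton'), 'satisfies the 2.1+ requirement')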
