diff --git a/ring_attention_pytorch/ring_flash_attention_cuda.py b/ring_attention_pytorch/ring_flash_attention_cuda.py
index df56247..7c33f0f 100644
--- a/ring_attention_pytorch/ring_flash_attention_cuda.py
+++ b/ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -44,9 +44,6 @@ def pad_at_dim(t, pad: Tuple[int, int], *, dim = -1, value = 0.):
 def is_empty(t: Tensor):
     return t.numel() == 0
 
-def is_contiguous(x: Tensor):
-    return x.stride(-1) == 1
-
 def padded_false_on_right_side(t: Tensor):
     if t.shape[-1] <= 1:
         return True
@@ -65,8 +62,7 @@ def padded_false_on_right_side(t: Tensor):
 assert pkg_version.parse(flash_attn_version) >= pkg_version.parse('2.5.1')
 
 from flash_attn.flash_attn_interface import (
-    _flash_attn_varlen_backward,
-    _flash_attn_backward
+    _flash_attn_varlen_backward
 )
 
 from flash_attn.bert_padding import (
@@ -369,15 +365,13 @@ def backward(ctx, do):
 
             # use flash attention backwards kernel to calculate dq, dk, dv and accumulate
 
-            from ring_attention_pytorch.triton_flash_attn import _flash_attn_backward
-
             if need_accum:
                 ring_dq = torch.empty_like(q)
                 ring_dk = torch.empty_like(k)
                 ring_dv = torch.empty_like(v)
 
                 with torch.inference_mode():
-                    _flash_attn_backward(
+                    flash_attn_backward(
                         do,
                         q,
                         k,
diff --git a/ring_attention_pytorch/triton_flash_attn.py b/ring_attention_pytorch/triton_flash_attn.py
index 1921879..112e12f 100644
--- a/ring_attention_pytorch/triton_flash_attn.py
+++ b/ring_attention_pytorch/triton_flash_attn.py
@@ -5,9 +5,19 @@
 import math
 
 import torch
+from torch import Tensor
 import triton
 import triton.language as tl
 
+def exists(v):
+    return v is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def is_contiguous(x: Tensor):
+    return x.stride(-1) == 1
+
 @triton.heuristics(
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
@@ -901,7 +911,7 @@ def _bwd_kernel(
             BLOCK_N=BLOCK_N,
         )
 
-def _flash_attn_backward(
+def flash_attn_backward(
     do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, causal_mask_diagonal=False, softmax_scale=None
 ):
     # Make sure that the last dimension is contiguous
diff --git a/setup.py b/setup.py
index 5d98217..57c8dd7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.14',
+  version = '0.3.15',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
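
For context, a minimal sketch of calling the renamed flash_attn_backward wrapper directly, mirroring the preallocated-gradient and torch.inference_mode() pattern used in ring_flash_attention_cuda.py above. The signature comes from the diff; the tensor layout, dtypes, and sizes below are illustrative assumptions, not taken from the repository:

import torch
from ring_attention_pytorch.triton_flash_attn import flash_attn_backward

# assumed (batch, seqlen, heads, dim_head) layout in half precision on a CUDA device
batch, seqlen, heads, dim_head = 1, 1024, 8, 64

q = torch.randn(batch, seqlen, heads, dim_head, device = 'cuda', dtype = torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

o = torch.randn_like(q)   # output saved from the forward pass
do = torch.randn_like(q)  # incoming gradient w.r.t. the output

# assumed row logsumexp from the forward pass, shape (batch, heads, seqlen)
lse = torch.randn(batch, heads, seqlen, device = 'cuda', dtype = torch.float32)

# gradients are written into preallocated buffers, as in the ring backward above
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)

with torch.inference_mode():
    flash_attn_backward(
        do, q, k, v, o, lse, dq, dk, dv,
        bias = None,
        causal = False,
        causal_mask_diagonal = False,
        softmax_scale = dim_head ** -0.5
    )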