retest triton backwards on runpod
lucidrains committed Apr 10, 2024
1 parent de47575 commit 7e81a42
Showing 3 changed files with 14 additions and 10 deletions.
10 changes: 2 additions & 8 deletions ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -44,9 +44,6 @@ def pad_at_dim(t, pad: Tuple[int, int], *, dim = -1, value = 0.):
 def is_empty(t: Tensor):
     return t.numel() == 0
 
-def is_contiguous(x: Tensor):
-    return x.stride(-1) == 1
-
 def padded_false_on_right_side(t: Tensor):
     if t.shape[-1] <= 1:
         return True
@@ -65,8 +62,7 @@ def padded_false_on_right_side(t):
 assert pkg_version.parse(flash_attn_version) >= pkg_version.parse('2.5.1')
 
 from flash_attn.flash_attn_interface import (
-    _flash_attn_varlen_backward,
-    _flash_attn_backward
+    _flash_attn_varlen_backward
 )
 
 from flash_attn.bert_padding import (
@@ -369,15 +365,13 @@ def backward(ctx, do):
 
 # use flash attention backwards kernel to calculate dq, dk, dv and accumulate
 
-from ring_attention_pytorch.triton_flash_attn import _flash_attn_backward
-
 if need_accum:
     ring_dq = torch.empty_like(q)
     ring_dk = torch.empty_like(k)
     ring_dv = torch.empty_like(v)
 
     with torch.inference_mode():
-        _flash_attn_backward(
+        flash_attn_backward(
             do,
             q,
             k,
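Note: with this change the ring backward no longer imports flash-attn's CUDA _flash_attn_backward; it calls the repository's own Triton kernel, renamed to flash_attn_backward in triton_flash_attn.py below. The following is a minimal standalone sketch of that call, assuming a CUDA device with Triton available and a (batch, seq, heads, dim) fp16 layout; only the signature comes from this commit, while the shapes and the placeholder o / lse are illustrative assumptions:

# Hypothetical usage sketch (not part of this commit). Assumes a CUDA GPU with
# triton installed and a (batch, seq, heads, dim) tensor layout.
import torch
from ring_attention_pytorch.triton_flash_attn import flash_attn_backward

batch, heads, seq, dim = 1, 8, 1024, 64

q = torch.randn(batch, seq, heads, dim, device = 'cuda', dtype = torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

do = torch.randn_like(q)    # upstream gradient w.r.t. the attention output
o = torch.zeros_like(q)     # placeholder - in real use, the output of the matching forward pass
lse = torch.zeros(batch, heads, seq, device = 'cuda', dtype = torch.float32)  # placeholder logsumexp from the forward pass

# gradient buffers the kernel writes into, mirroring ring_dq / ring_dk / ring_dv above
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)

with torch.inference_mode():
    flash_attn_backward(
        do, q, k, v, o, lse, dq, dk, dv,
        causal = True,
        softmax_scale = dim ** -0.5
    )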
12 changes: 11 additions & 1 deletion ring_attention_pytorch/triton_flash_attn.py
@@ -5,9 +5,19 @@
 import math
 
 import torch
+from torch import Tensor
 import triton
 import triton.language as tl
 
+def exists(v):
+    return v is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def is_contiguous(x: Tensor):
+    return x.stride(-1) == 1
+
 @triton.heuristics(
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
@@ -901,7 +911,7 @@ def _bwd_kernel(
         BLOCK_N=BLOCK_N,
     )
 
-def _flash_attn_backward(
+def flash_attn_backward(
     do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, causal_mask_diagonal=False, softmax_scale=None
 ):
     # Make sure that the last dimension is contiguous
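The exists / default / is_contiguous helpers added above follow the usual pattern in this codebase. Below is a small self-contained illustration of how such helpers are typically used; the exact call sites inside flash_attn_backward are not shown in this diff, so treat the usage as an assumption:

# Illustration only - copies of the helpers added above, plus a typical usage
# pattern; how flash_attn_backward actually uses them is not visible in this diff.
import torch
from torch import Tensor

def exists(v):
    return v is not None

def default(val, d):
    return val if exists(val) else d

def is_contiguous(x: Tensor):
    return x.stride(-1) == 1

softmax_scale = None
q = torch.randn(2, 8, 1024, 64).transpose(1, 2)   # strided view, but the last dim is still contiguous

softmax_scale = default(softmax_scale, q.shape[-1] ** -0.5)   # fall back to 1 / sqrt(head dim)

if not is_contiguous(q):
    q = q.contiguous()   # the Triton kernels index by stride but expect a contiguous last dim

assert exists(softmax_scale)
assert is_contiguous(q)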
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.14',
+  version = '0.3.15',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
