From 0b5d5af4d8c146a4f32141bca4dd5de4ad948607 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 8 Apr 2024 03:46:28 +0000
Subject: [PATCH] backwards cuda causal mask hack now relies on padding once
 and indexing out

---
 ring_attention_pytorch/ring_flash_attention_cuda.py | 6 ------
 setup.py                                            | 2 +-
 2 files changed, 1 insertion(+), 7 deletions(-)

diff --git a/ring_attention_pytorch/ring_flash_attention_cuda.py b/ring_attention_pytorch/ring_flash_attention_cuda.py
index ce71480..8dcf876 100644
--- a/ring_attention_pytorch/ring_flash_attention_cuda.py
+++ b/ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -767,12 +767,6 @@ def backward(ctx, do):
 
             else:
                 ring_dq, ring_dk, ring_dv = 0., 0., 0.
-
-            q = q[:, :row_length]
-            o = o[:, :row_length]
-            do = do[:, :row_length]
-            lse = lse[..., :row_length]
-
         else:
 
             (
diff --git a/setup.py b/setup.py
index c347516..94bbe45 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.6',
+  version = '0.3.7',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
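
Note on the change above: the CUDA causal backward path no longer re-slices q, o, do, and lse back to row_length inside the ring loop; per the subject line, the tensors are padded once up front and indexed back out once at the end. The sketch below only illustrates that pad-once-and-index-out pattern; the names (bucket_size, row_length) and shapes are illustrative assumptions, not the library's actual code.

    import torch
    import torch.nn.functional as F

    # toy shapes: (batch, seq, dim); bucket_size stands in for the flash attention block size
    bucket_size = 128
    q = torch.randn(2, 1000, 64)

    row_length = q.shape[1]                  # remember the original sequence length
    padding = (-row_length) % bucket_size    # amount needed to reach a multiple of bucket_size

    # pad the sequence dimension once, before iterating over ring / flash attention blocks
    q_padded = F.pad(q, (0, 0, 0, padding))

    # ... run every backward ring iteration on q_padded (and similarly padded o, do, lse) ...

    # index out once at the end, instead of re-slicing inside each iteration
    q_out = q_padded[:, :row_length]
    assert q_out.shape == q.shape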