From 68237946e26752becfbe46374bbc8e31b82beeb8 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Wed, 10 Apr 2024 19:38:20 +0000
Subject: [PATCH] cache the preprocess backwards kernel output across the ring
 reduce

---
 README.md                                   |  2 +-
 .../ring_flash_attention_cuda.py            |  9 ++-
 ring_attention_pytorch/triton_flash_attn.py | 58 ++++++++++++-------
 setup.py                                    |  2 +-
 4 files changed, 47 insertions(+), 24 deletions(-)

diff --git a/README.md b/README.md
index 1495a00..88d5ad4 100644
--- a/README.md
+++ b/README.md
@@ -98,8 +98,8 @@ $ python assert.py --use-cuda --causal --striped-ring-attn
 - [x] make sure cuda striped attention works for multiple buckets, otherwise flash attention is ineffective
 - [x] for cuda striped attention, for backwards hack, pad the extra token once and index out when passing into Tri's cuda kernel
 - [x] find a machine with 8 GPUs and test with a quarter million tokens first
+- [x] see for cuda version whether softmax_D can be computed once and cached over the ring reduce. go for modified triton backwards if not
 
-- [ ] see for cuda version whether softmax_D can be computed once and cached over the ring reduce. go for modified triton backwards if not
 - [ ] think about how to craft a special `Dataset` that shards across sequence length (take into account labels for cross entropy loss) for ring transformer training
 - [ ] add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
 - [ ] figure out how to pytest distributed pytorch
diff --git a/ring_attention_pytorch/ring_flash_attention_cuda.py b/ring_attention_pytorch/ring_flash_attention_cuda.py
index 7c33f0f..314d8f7 100644
--- a/ring_attention_pytorch/ring_flash_attention_cuda.py
+++ b/ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -94,7 +94,7 @@ def inverse_fn(y):
 assert exists(importlib.util.find_spec('triton')), 'latest triton must be installed. `pip install triton -U` first'
 
 triton_version = version('triton')
-assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1')
+assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1'), 'triton must be version 2.1 or above. `pip install triton -U` to upgrade'
 
 import triton
 import triton.language as tl
@@ -320,6 +320,10 @@ def backward(ctx, do):
         receive_kv_and_dkv = None
         receive_mask = None
 
+        # caching the delta (do * o for backwards pass) across ring reduce
+
+        delta = None
+
         # if not causal and has key padding mask
         # prepare row related tensors with unpad_input
 
@@ -371,7 +375,7 @@ def backward(ctx, do):
             ring_dv = torch.empty_like(v)
 
             with torch.inference_mode():
-                flash_attn_backward(
+                delta = flash_attn_backward(
                     do,
                     q,
                     k,
@@ -381,6 +385,7 @@ def backward(ctx, do):
                     ring_dq,
                     ring_dk,
                     ring_dv,
+                    delta = delta,
                     causal = block_causal,
                     causal_mask_diagonal = causal_mask_diagonal,
                     softmax_scale = softmax_scale
diff --git a/ring_attention_pytorch/triton_flash_attn.py b/ring_attention_pytorch/triton_flash_attn.py
index 112e12f..1c5dc94 100644
--- a/ring_attention_pytorch/triton_flash_attn.py
+++ b/ring_attention_pytorch/triton_flash_attn.py
@@ -912,7 +912,20 @@ def _bwd_kernel(
     )
 
 def flash_attn_backward(
-    do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, causal_mask_diagonal=False, softmax_scale=None
+    do,
+    q,
+    k,
+    v,
+    o,
+    lse,
+    dq,
+    dk,
+    dv,
+    delta=None,
+    bias=None,
+    causal=False,
+    causal_mask_diagonal=False,
+    softmax_scale=None
 ):
     # Make sure that the last dimension is contiguous
     if do.stride(-1) != 1:
@@ -929,28 +942,31 @@ def flash_attn_backward(
     softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
     dq_accum = torch.empty_like(q, dtype=torch.float32)
 
-    delta = torch.empty_like(lse)
+    # delta = torch.zeros_like(lse)
 
     BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
-    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
-    _bwd_preprocess_do_o_dot[grid](
-        o,
-        do,
-        delta,
-        o.stride(0),
-        o.stride(2),
-        o.stride(1),
-        do.stride(0),
-        do.stride(2),
-        do.stride(1),
-        nheads,
-        seqlen_q,
-        seqlen_q_rounded,
-        d,
-        BLOCK_M=128,
-        BLOCK_HEADDIM=BLOCK_HEADDIM,
-    )
+
+    if not exists(delta):
+        delta = torch.empty_like(lse)
+        grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
+        _bwd_preprocess_do_o_dot[grid](
+            o,
+            do,
+            delta,
+            o.stride(0),
+            o.stride(2),
+            o.stride(1),
+            do.stride(0),
+            do.stride(2),
+            do.stride(1),
+            nheads,
+            seqlen_q,
+            seqlen_q_rounded,
+            d,
+            BLOCK_M=128,
+            BLOCK_HEADDIM=BLOCK_HEADDIM,
+        )
 
     has_bias = bias is not None
     bias_type = "none"
@@ -1030,3 +1046,5 @@ def flash_attn_backward(
         # num_stages=1,
     )
     dq.copy_(dq_accum)
+
+    return delta
diff --git a/setup.py b/setup.py
index 57c8dd7..eb53330 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.15',
+  version = '0.3.16',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
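
The quantity being cached is the `delta` term (the "softmax_D" of the README item), i.e. the per-row dot product of `do` and `o` computed by the `_bwd_preprocess_do_o_dot` kernel. Since neither `do` nor `o` changes while the key / value blocks travel around the ring, the preprocess kernel only needs to run on the first ring step and its output can be reused on every subsequent call. A minimal sketch of the resulting calling pattern follows; the driver function, its argument names, and the shape of the k / v iterable are illustrative assumptions, and only the threading of `delta` through `flash_attn_backward` mirrors the patch above.

import torch
from ring_attention_pytorch.triton_flash_attn import flash_attn_backward

def ring_backward_sketch(do, q, o, lse, kv_blocks, softmax_scale):
    # hypothetical driver loop standing in for the ring reduce in
    # ring_flash_attention_cuda.py; kv_blocks yields one (k, v) pair per ring step
    delta = None            # filled by the preprocess kernel on the first call only

    dq = torch.zeros_like(q)

    for k, v in kv_blocks:
        ring_dq = torch.empty_like(q)
        ring_dk = torch.empty_like(k)
        ring_dv = torch.empty_like(v)

        with torch.inference_mode():
            # when delta is None the kernel computes and returns it,
            # otherwise the cached tensor is passed straight through
            delta = flash_attn_backward(
                do, q, k, v, o, lse,
                ring_dq, ring_dk, ring_dv,
                delta = delta,
                softmax_scale = softmax_scale
            )

        dq += ring_dq
        # accumulation and ring communication of ring_dk / ring_dv omitted for brevity

    return dq

With this change the preprocess kernel launch is paid once per backward pass instead of once per ring step, which is what the newly checked-off README item refers to.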