diff --git a/assert.py b/assert.py
index 0c65eda..3190fb7 100644
--- a/assert.py
+++ b/assert.py
@@ -38,6 +38,7 @@ def start(
     causal,
     striped_ring_attn,
     dim,
+    dim_head,
     use_cuda,
     compare_regular_attn
 ):
@@ -51,7 +52,7 @@ def start(
         dim = dim,
         causal = causal,
         depth = 2,
-        dim_head = 64,
+        dim_head = dim_head,
         ring_attn = True,
         striped_ring_attn = striped_ring_attn,
         ring_seq_size = ring_seq_size,
@@ -63,7 +64,7 @@ def start(
         dim = dim,
         causal = causal,
         depth = 2,
-        dim_head = 64,
+        dim_head = dim_head,
         ring_attn = False,
         ring_seq_size = ring_seq_size,
         bucket_size = bucket_size,
@@ -142,6 +143,7 @@ def start(
 @click.option('--num-buckets', default = 2, help = 'number of buckets per machine (each sharded sequence is further windowed for flash attention to achieve even greater context lengths)')
 @click.option('--seq-len', default = 31, help = 'sequence length to test')
 @click.option('--model-dim', default = 8, help = 'model dimensions for testing')
+@click.option('--dim-head', default = 16, help = 'attention head dimension')
 @click.option('--compare-regular-attn', is_flag = True, help = 'compare ring to regular attention')
 def test(
     world_size: int,
@@ -154,6 +156,7 @@ def test(
     num_buckets: int,
     seq_len: int,
     model_dim: int,
+    dim_head: int,
     compare_regular_attn: bool
 ):
     assert not use_cuda or world_size <= torch.cuda.device_count(), f'world size {world_size} must be less than the number of cuda devices {torch.cuda.device_count()}'
@@ -170,6 +173,7 @@ def test(
             causal,
             striped_ring_attn,
             model_dim,
+            dim_head,
             use_cuda,
             compare_regular_attn
         ),
diff --git a/assert_attn.py b/assert_attn.py
index 289a187..525d1a9 100644
--- a/assert_attn.py
+++ b/assert_attn.py
@@ -38,6 +38,7 @@ def start(
     causal,
     striped_ring_attn,
     dim,
+    dim_head,
     use_cuda,
     compare_regular_attn
 ):
@@ -49,7 +50,7 @@ def start(
     ring_attention = RingAttention(
         dim = dim,
         causal = causal,
-        dim_head = 8,
+        dim_head = dim_head,
         ring_attn = True,
         striped_ring_attn = striped_ring_attn,
         ring_seq_size = ring_seq_size,
@@ -61,7 +62,7 @@ def start(
     flash_attention = RingAttention(
         dim = dim,
         causal = causal,
-        dim_head = 8,
+        dim_head = dim_head,
         ring_attn = False,
         ring_seq_size = ring_seq_size,
         bucket_size = bucket_size,
@@ -80,8 +81,8 @@ def start(

     if use_cuda:
         seq = seq.cuda(rank)
-        flash_attention_net.cuda(rank)
-        ring_attention_net.cuda(rank)
+        flash_attention.cuda(rank)
+        ring_attention.cuda(rank)

     # separate inputs for ring vs flash

@@ -144,6 +145,7 @@ def start(
 @click.option('--num-buckets', default = 2, help = 'number of buckets per machine (each sharded sequence is further windowed for flash attention to achieve even greater context lengths)')
 @click.option('--seq-len', default = 31, help = 'sequence length to test')
 @click.option('--model-dim', default = 8, help = 'model dimensions for testing')
+@click.option('--dim-head', default = 16, help = 'attention head dimension')
 @click.option('--compare-regular-attn', is_flag = True, help = 'compare ring to regular attention')
 def test(
     world_size: int,
@@ -156,6 +158,7 @@ def test(
     num_buckets: int,
     seq_len: int,
     model_dim: int,
+    dim_head: int,
     compare_regular_attn: bool
 ):
     assert not use_cuda or world_size <= torch.cuda.device_count(), f'world size {world_size} must be less than the number of cuda devices {torch.cuda.device_count()}'
@@ -172,6 +175,7 @@ def test(
             causal,
             striped_ring_attn,
             model_dim,
+            dim_head,
             use_cuda,
             compare_regular_attn
         ),
diff --git a/assert_flash.py b/assert_flash.py
index 44e186d..a9d569b 100644
--- a/assert_flash.py
+++ b/assert_flash.py
@@ -12,19 +12,23 @@
 @click.command()
 @click.option('--causal', is_flag = True)
 @click.option('--seq-len', default = 62)
+@click.option('--dim-head', default = 16)
+@click.option('--heads', default = 2)
 @click.option('--bucket_size', default = 4)
 @click.option('--flash-cuda-kernel', is_flag = True)
 def test(
     causal: bool,
     seq_len: int,
+    dim_head: int,
+    heads: int,
     bucket_size: int,
     flash_cuda_kernel: bool
 ):
     # base qkv

-    q = torch.randn(2, seq_len, 2, 16)
-    k = torch.randn(2, seq_len, 2, 16)
-    v = torch.randn(2, seq_len, 2, 16)
+    q = torch.randn(2, seq_len, heads, dim_head)
+    k = torch.randn(2, seq_len, heads, dim_head)
+    v = torch.randn(2, seq_len, heads, dim_head)

     # flash and regular qkv's

diff --git a/ring_attention_pytorch/triton_flash_attn.py b/ring_attention_pytorch/triton_flash_attn.py
index 8e6ae5d..6611876 100644
--- a/ring_attention_pytorch/triton_flash_attn.py
+++ b/ring_attention_pytorch/triton_flash_attn.py
@@ -540,89 +540,6 @@ def _bwd_kernel(
             BLOCK_N=BLOCK_N,
         )

-
-def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
-    # shape constraints
-    batch, seqlen_q, nheads, d = q.shape
-    _, seqlen_k, _, _ = k.shape
-    assert k.shape == (batch, seqlen_k, nheads, d)
-    assert v.shape == (batch, seqlen_k, nheads, d)
-    assert d <= 128, "FlashAttention only support head dimensions up to 128"
-    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
-    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
-    assert q.is_cuda and k.is_cuda and v.is_cuda
-    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
-
-    has_bias = bias is not None
-    bias_type = "none"
-    if has_bias:
-        assert bias.dtype in [q.dtype, torch.float]
-        assert bias.is_cuda
-        assert bias.dim() == 4
-        if bias.stride(-1) != 1:
-            bias = bias.contiguous()
-        if bias.shape[2:] == (1, seqlen_k):
-            bias_type = "vector"
-        elif bias.shape[2:] == (seqlen_q, seqlen_k):
-            bias_type = "matrix"
-        else:
-            raise RuntimeError(
-                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
-            )
-        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
-    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
-
-    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
-    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
-    o = torch.empty_like(q)
-
-    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
-    BLOCK = 128
-    num_warps = 4 if d <= 64 else 8
-    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
-    _fwd_kernel[grid](
-        q,
-        k,
-        v,
-        bias,
-        o,
-        lse,
-        tmp,
-        softmax_scale,
-        q.stride(0),
-        q.stride(2),
-        q.stride(1),
-        k.stride(0),
-        k.stride(2),
-        k.stride(1),
-        v.stride(0),
-        v.stride(2),
-        v.stride(1),
-        *bias_strides,
-        o.stride(0),
-        o.stride(2),
-        o.stride(1),
-        nheads,
-        seqlen_q,
-        seqlen_k,
-        seqlen_q_rounded,
-        d,
-        seqlen_q // 32,
-        seqlen_k // 32,  # key for triton cache (limit number of compilations)
-        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
-        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        bias_type,
-        causal,
-        BLOCK_HEADDIM,
-        BLOCK_M=BLOCK,
-        BLOCK_N=BLOCK,
-        num_warps=num_warps,
-        num_stages=1,
-    )
-    return o, lse, softmax_scale  # softmax_scale could have been updated
-
-
 def _flash_attn_backward(
     do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, causal_mask_diagonal=False, softmax_scale=None
 ):
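
Note (not part of the diff above): the new --dim-head and --heads options in assert_flash.py only parameterize the shape of the test qkv tensors, which use a (batch, seq_len, heads, dim_head) layout. The standalone sketch below illustrates that layout with plain PyTorch scaled_dot_product_attention as a stand-in reference; it does not invoke this repository's flash or ring kernels, and the values are just the click defaults shown in the diff.

# Standalone sketch, assuming only the (batch, seq_len, heads, dim_head) qkv
# layout visible in the diff; uses stock PyTorch SDPA, not the repo's kernels.
import torch
import torch.nn.functional as F

seq_len, heads, dim_head = 62, 2, 16  # defaults of the click options above

q = torch.randn(2, seq_len, heads, dim_head)
k = torch.randn(2, seq_len, heads, dim_head)
v = torch.randn(2, seq_len, heads, dim_head)

# SDPA expects (batch, heads, seq_len, dim_head), so move the head dim forward
out = F.scaled_dot_product_attention(
    q.transpose(1, 2),
    k.transpose(1, 2),
    v.transpose(1, 2),
    is_causal = True
)

out = out.transpose(1, 2)  # back to (batch, seq_len, heads, dim_head)
assert out.shape == (2, seq_len, heads, dim_head)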