From 51d0d9201bbcd7479468958e006ff22090eec5a2 Mon Sep 17 00:00:00 2001 From: Joseph Groenenboom Date: Thu, 8 Aug 2024 15:10:22 +0000 Subject: [PATCH 1/3] Add support for causal masking as a toggle and more datatype support --- python/perf-kernels/flash-attention.py | 34 +++++++++++++++++--------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 988438340abe..faac1fe7d123 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -632,14 +632,14 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, do = tl.load(DO_block_ptr) # Compute dV. ppT = pT - ppT = ppT.to(tl.float16) + ppT = ppT.to(do.dtype) dv += tl.dot(ppT, do) # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # Compute dP and dS. dpT = tl.dot(v, tl.trans(do)) dsT = pT * (dpT - Di[None, :]) - dsT = dsT.to(tl.float16) + dsT = dsT.to(qT.dype) dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m @@ -685,7 +685,7 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, vT = tl.load(VT_block_ptr) dp = tl.dot(do, vT).to(tl.float32) ds = p * (dp - Di[:, None]) - ds = ds.to(tl.float16) + ds = ds.to(kT.dtype) # Compute dQ.0. # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. dq += tl.dot(ds, tl.trans(kT)) @@ -765,14 +765,14 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, # compute dK and dV for blocks close to the diagonal that need to be masked num_steps = BLOCK_N1 // MASK_BLOCK_M1 dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, - MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True) + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=CAUSAL) # compute dK and dV for blocks that don't need masking further from the diagonal start_m += num_steps * MASK_BLOCK_M1 num_steps = (N_CTX - start_m) // BLOCK_M1 dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, - BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False) + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=CAUSAL) DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) @@ -999,6 +999,7 @@ def backward(ctx, do, _): q.stride(3), N_HEAD, N_CTX, + CAUSAL=ctx.causal, BLOCK_DMODEL=ctx.BLOCK_DMODEL, BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, @@ -1261,10 +1262,11 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 ]) @pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) @pytest.mark.parametrize('torch_sdpa_test', [False, True]) -@pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize('causal', [False,True]) @pytest.mark.parametrize('use_alibi', [False, True]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, - dtype=torch.float16): + dtype): pytest.skip() torch.manual_seed(20) if qseqlen_not_equal_kseqlen is not None: @@ -1396,6 +1398,15 @@ def varlen_benchmark_configs(): ] return configs +def nonvarlen_backward_benchmark_configs(): + configs=[(16, 16, 16, 1024, 1024), + (8, 16, 16, 2048, 2048), + (4, 16, 16, 4096, 4096), + (2, 16, 16, 8192, 8192), + (1, 16, 16, 16384, 16384), 
+ (2, 48, 48, 1024, 1024), + ] + return configs def run_benchmark(custom, args): @@ -1403,7 +1414,7 @@ def run_benchmark(custom, args): hk = args.hq if not args.hk else args.hk sk = args.sq if not args.sk else args.sk head_size = 128 if not args.d else args.d - mode = 'fwd' + mode = args.direction x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] causal = args.causal varlen = args.layout == 'thd' @@ -1413,6 +1424,8 @@ def run_benchmark(custom, args): else: if varlen: x_vals_list = varlen_benchmark_configs() + elif mode == 'bwd': + x_vals_list = nonvarlen_backward_benchmark_configs() else: x_vals_list = nonvarlen_benchmark_configs() print_time = args.return_time @@ -1436,10 +1449,6 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal # bias = None # bias = None - # Bwd pass only supports causal=True right now - if mode == 'bwd': - causal = True - flops_per_matmul = 0 if varlen: q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, @@ -1502,6 +1511,7 @@ def parse_args(): parser.add_argument("-dtype", default='fp16') parser.add_argument("-return_time", action='store_true', default=False) parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts()) + parser.add_argument("-direction", default='fwd') return parser.parse_args() From ae4633c4e7e12bcd17710635b0516958644c2c50 Mon Sep 17 00:00:00 2001 From: Joseph Groenenboom Date: Thu, 8 Aug 2024 12:32:57 -0500 Subject: [PATCH 2/3] Unify with new forward tests and set num_stages --- python/perf-kernels/flash-attention.py | 132 +++-- python/tutorials/06-fused-attention.py | 753 +++++++++++++++---------- 2 files changed, 525 insertions(+), 360 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index faac1fe7d123..58cf45024dcd 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -316,7 +316,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri num_warps=4), ], key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'], - use_cuda_graph=True, + #use_cuda_graph=True, ) @triton.jit def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, @@ -639,7 +639,7 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, # Compute dP and dS. dpT = tl.dot(v, tl.trans(do)) dsT = pT * (dpT - Di[None, :]) - dsT = dsT.to(qT.dype) + dsT = dsT.to(qT.dtype) dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m @@ -695,13 +695,12 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) return dq - @triton.jit def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, # shared by Q/K/V/DO. 
stride_z, stride_h, stride_tok, stride_d, # H = 16, N_CTX = 1024 - H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + H, N_CTX, CAUSAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr): LN2: tl.constexpr = 0.6931471824645996 # = ln(2) @@ -943,6 +942,7 @@ def backward(ctx, do, _): BLOCK = 64 else: BLOCK = 128 + num_stages = 1 q, k, v, o, M = ctx.saved_tensors assert do.is_contiguous() assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() @@ -1007,6 +1007,7 @@ def backward(ctx, do, _): BLOCK_N2=BLOCK_N2, BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, USE_ALIBI=False if ctx.alibi_slopes is None else True, + num_stages = 1, ) return dq, dk, dv, None, None @@ -1260,92 +1261,86 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 #(1, 16, 8192, 63), #(1, 16, 1022, 64), ]) -@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) -@pytest.mark.parametrize('torch_sdpa_test', [False, True]) -@pytest.mark.parametrize('causal', [False,True]) +@pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize('use_alibi', [False, True]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) -def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, - dtype): - pytest.skip() - torch.manual_seed(20) - if qseqlen_not_equal_kseqlen is not None: - seqlen_q = qseqlen_not_equal_kseqlen - else: - seqlen_q = N_CTX - seqlen_k = N_CTX - - if causal and ((N_CTX - 1) & N_CTX): - pytest.skip() - if causal and seqlen_q != seqlen_k: - pytest.skip() - - sm_scale = D_HEAD**-0.5 - input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.max_seqlens_q = seqlen_q - input_metadata.max_seqlens_k = seqlen_k - - dropout_p = 0 - q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - o = torch.empty_like(q) - +@pytest.mark.parametrize('layout', ['bhsd']) +def test_op_bwd(Z, H, N_CTX, D_HEAD, causal, use_alibi, + layout, dtype): + torch.manual_seed(20) + + N_CTX_Q = N_CTX_K = N_CTX + HQ = HK = H + + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + dout = torch.randn_like(q) + if causal: input_metadata.need_causal() - if use_alibi and not torch_sdpa_test: + if use_alibi: # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) - alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, device="cuda").repeat(Z, 1) - input_metadata.need_alibi(alibi_slopes, Z, H) - dout = torch.randn_like(q) - # reference implementation - if torch_sdpa_test: - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, - is_causal=causal, scale=sm_scale, - dropout_mask=None) - ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None + input_metadata.need_alibi(alibi_slopes, Z, HQ) else: - M = 
torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - if use_alibi: - p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX) - if causal: - p[:, :, M == 0] = float("-inf") + alibi_slopes = None - p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) - ref_out = torch.matmul(p, v) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None + o = torch.empty_like(q) - # # triton implementation + # triton implementation tri_out, _ = attention(q, k, v, o, input_metadata) tri_out.backward(dout) tri_dv, v.grad = v.grad.clone(), None tri_dk, k.grad = k.grad.clone(), None tri_dq, q.grad = q.grad.clone(), None - # test - #print("reference") - #print(ref_dv) - #print("tri") - #print(tri_dv) + + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p = torch.where(nan_mask == 1,0,p) + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v) # compare - torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + # The current block size for MI200 series is 64x64. This results in # larger differences in float results due to rounding. 
if dtype == torch.bfloat16: - ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0) if dtype == torch.float32: - ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + ATOL = 1e-3 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0) else: - ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0) RTOL = 0 @@ -1353,7 +1348,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) - def nonvarlen_benchmark_configs(): configs = [ (16, 16, 16, 1024, 1024), diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index e533576d467b..c661510f6b2f 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -17,57 +17,60 @@ import triton import triton.language as tl +# Pick the fp8 data type -def is_hip(): - return triton.runtime.driver.active.get_current_target().backend == "hip" +# AMD E4M3B8 +# Note: When picking this f8 data type, scaling is required when using f8 +# for the second gemm +#TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') + +# AMD E5M2B16 +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, # - K_block_ptr, V_block_ptr, # - start_m, qk_scale, # - BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # - N_CTX: tl.constexpr, fp8_v: tl.constexpr): +def _attn_fwd_inner(acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, + start_m, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, + N_CTX, + pre_load_v: tl.constexpr): # range of values handled by this stage if STAGE == 1: lo, hi = 0, start_m * BLOCK_M elif STAGE == 2: lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M lo = tl.multiple_of(lo, BLOCK_M) + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # causal = False else: lo, hi = 0, N_CTX - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(K_block_ptr) - qk = tl.dot(q, k) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) if STAGE == 2: mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - else: - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_ij[:, None] + qk = tl.where(mask, qk, float("-inf")) + qk += tl.dot(q, k) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] p = tl.math.exp2(qk) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - alpha = tl.math.exp2(m_i - m_ij) - l_i = l_i * alpha + l_ij # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] - # update acc - v = tl.load(V_block_ptr) - if fp8_v: - p = p.to(tl.float8e5) - else: - p = p.to(tl.float16) - acc = tl.dot(p, v, acc) + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij 
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) @@ -75,78 +78,72 @@ def _attn_fwd_inner(acc, l_i, m_i, q, # return acc, l_i, m_i -# We don't run auto-tuning every time to keep the tutorial fast. Keeping +# We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting # the code below and commenting out the equivalent parameters is convenient for # re-tuning. -configs = [ - triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ - for BM in [64, 128]\ - for BN in [32, 64]\ - for s in ([1] if is_hip() else [3, 4, 7])\ - for w in [4, 8]\ -] - +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), + ], + key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'], +) -def keep(conf): - BLOCK_M = conf.kwargs["BLOCK_M"] - BLOCK_N = conf.kwargs["BLOCK_N"] - if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: - return False - return True - -@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) @triton.jit -def _attn_fwd(Q, K, V, sm_scale, M, Out, # - stride_qz, stride_qh, stride_qm, stride_qk, # - stride_kz, stride_kh, stride_kn, stride_kk, # - stride_vz, stride_vh, stride_vk, stride_vn, # - stride_oz, stride_oh, stride_om, stride_on, # - Z, H, N_CTX, # - HEAD_DIM: tl.constexpr, # - BLOCK_M: tl.constexpr, # - BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr # +def _attn_fwd(Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + STAGE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, ): - tl.static_assert(BLOCK_N <= HEAD_DIM) start_m = tl.program_id(0) off_hz = tl.program_id(1) - off_z = off_hz // H - off_h = off_hz % H - qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + qvk_offset = off_hz * stride_qh # block pointers Q_block_ptr = tl.make_block_ptr( base=Q + qvk_offset, - shape=(N_CTX, HEAD_DIM), + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0), ) - v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) V_block_ptr = tl.make_block_ptr( base=V + qvk_offset, - shape=(N_CTX, HEAD_DIM), + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_vk, stride_vn), offsets=(0, 0), - block_shape=(BLOCK_N, HEAD_DIM), - order=v_order, + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), ) K_block_ptr = tl.make_block_ptr( base=K + qvk_offset, - shape=(HEAD_DIM, N_CTX), + shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), - 
block_shape=(HEAD_DIM, BLOCK_N), + block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1), ) O_block_ptr = tl.make_block_ptr( base=Out + qvk_offset, - shape=(N_CTX, HEAD_DIM), + shape=(N_CTX, BLOCK_DMODEL), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, HEAD_DIM), + block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0), ) # initialize offsets @@ -155,80 +152,96 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - # load scales - qk_scale = sm_scale - qk_scale *= 1.44269504 # 1/log(2) - # load q: it will stay in SRAM throughout + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(q.dtype) # stage 1: off-band # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # - start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 4 - STAGE, offs_m, offs_n, N_CTX, + pre_load_v, ) # stage 2: on-band if STAGE & 2: # barrier makes it easier for compielr to schedule the # two loops independently - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # - start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # + tl.debug_barrier() + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 2, offs_m, offs_n, N_CTX, + pre_load_v, ) # epilogue - m_i += tl.math.log2(l_i) + # write back m acc = acc / l_i[:, None] m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) + tl.store(m_ptrs, m_i + tl.math.log2(l_i)) tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - @triton.jit -def _attn_bwd_preprocess(O, DO, # - Delta, # - Z, H, N_CTX, # - BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # +def _attn_bwd_preprocess(O, DO, + Delta, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_hz = tl.program_id(1) - off_n = tl.arange(0, HEAD_DIM) - # load - o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) - do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) + off_n = tl.arange(0, D_HEAD) + o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]) + do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) delta = tl.sum(o * do, axis=1) - # write-back tl.store(Delta + off_hz * N_CTX + off_m, delta) # The main inner-loop logic for computing dK and dV. @triton.jit -def _attn_bwd_dkdv(dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # +def _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, # shared by Q/K/V/DO. 
- stride_tok, stride_d, # - H, N_CTX, BLOCK_M1: tl.constexpr, # - BLOCK_N1: tl.constexpr, # - HEAD_DIM: tl.constexpr, # + stride_tok, stride_d, + H, N_CTX, BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # Filled in by the wrapper. - start_n, start_m, num_steps, # + start_n, start_m, num_steps, MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M1) offs_n = start_n + tl.arange(0, BLOCK_N1) - offs_k = tl.arange(0, HEAD_DIM) - qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d - do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + offs_k = tl.arange(0, BLOCK_DMODEL) + QT_block_ptr = tl.make_block_ptr( + base=Q, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_m), + block_shape=(BLOCK_DMODEL, BLOCK_M1), + order=(0,1) + ) + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M1, BLOCK_DMODEL), + order=(1,0) + ) # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) curr_m = start_m step_m = BLOCK_M1 for blk_idx in range(num_steps): - qT = tl.load(qT_ptrs) + qT = tl.load(QT_block_ptr) # Load m before computing qk to reduce pipeline stall. offs_m = curr_m + tl.arange(0, BLOCK_M1) m = tl.load(M + offs_m) @@ -238,7 +251,7 @@ def _attn_bwd_dkdv(dk, dv, # if MASK: mask = (offs_m[None, :] >= offs_n[:, None]) pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs) + do = tl.load(DO_block_ptr) # Compute dV. ppT = pT ppT = ppT.to(tl.float16) @@ -246,35 +259,49 @@ def _attn_bwd_dkdv(dk, dv, # # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dpT = tl.dot(v, tl.trans(do)) dsT = pT * (dpT - Di[None, :]) dsT = dsT.to(tl.float16) dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m - qT_ptrs += step_m * stride_tok - do_ptrs += step_m * stride_tok + QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) + DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) return dk, dv # the main inner-loop logic for computing dQ @triton.jit -def _attn_bwd_dq(dq, q, K, V, # +def _attn_bwd_dq(dq, q, K, V, do, m, D, # shared by Q/K/V/DO. - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # Filled in by the wrapper. - start_m, start_n, num_steps, # + start_m, start_n, num_steps, MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) - kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d - vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + offs_k = tl.arange(0, BLOCK_DMODEL) + KT_block_ptr = tl.make_block_ptr( + base=K, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_n), + block_shape=(BLOCK_DMODEL, BLOCK_N2), + order=(0, 1) + ) + VT_block_ptr = tl.make_block_ptr( + base=V, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_n), + block_shape=(BLOCK_DMODEL, BLOCK_N2), + order=(0, 1) + ) # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
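
A minimal PyTorch sketch (illustrative only, not part of the patch) of the base-2 softmax identity these kernels rely on — exp(x) = 2^(x·log2 e) — which is why the forward pass folds log2(e) into qk_scale and uses tl.math.exp2, and why the backward pass rescales dq by LN2 at the end:

    import torch

    log2e = 1.44269504        # log2(e), the constant folded into qk_scale
    ln2 = 0.6931471824645996  # LN2 used when writing back dq

    scores = torch.randn(4, 8)           # stand-in for sm_scale * (q @ k^T)
    ref = torch.softmax(scores, dim=-1)  # natural-base softmax

    # Base-2 variant with running-max subtraction, mirroring the kernel's inner loop.
    m = (scores * log2e).max(dim=-1, keepdim=True).values
    p = torch.exp2(scores * log2e - m)
    base2 = p / p.sum(dim=-1, keepdim=True)
    assert torch.allclose(ref, base2, atol=1e-6)

    # d/dx 2**x = ln(2) * 2**x, hence the final `dq *= LN2` in _attn_bwd.
    x = torch.tensor(0.3, requires_grad=True)
    torch.exp2(x).backward()
    assert torch.allclose(x.grad, ln2 * torch.exp2(x.detach()))
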
@@ -282,8 +309,7 @@ def _attn_bwd_dq(dq, q, K, V, # curr_n = start_n step_n = BLOCK_N2 for blk_idx in range(num_steps): - kT = tl.load(kT_ptrs) - vT = tl.load(vT_ptrs) + kT = tl.load(KT_block_ptr) qk = tl.dot(q, kT) p = tl.math.exp2(qk - m) # Autoregressive masking. @@ -292,6 +318,7 @@ def _attn_bwd_dq(dq, q, K, V, # mask = (offs_m[:, None] >= offs_n[None, :]) p = tl.where(mask, p, 0.0) # Compute dP and dS. + vT = tl.load(VT_block_ptr) dp = tl.dot(do, vT).to(tl.float32) ds = p * (dp - Di[:, None]) ds = ds.to(tl.float16) @@ -300,25 +327,50 @@ def _attn_bwd_dq(dq, q, K, V, # dq += tl.dot(ds, tl.trans(kT)) # Increment pointers. curr_n += step_n - kT_ptrs += step_n * stride_tok - vT_ptrs += step_n * stride_tok + KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) + VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) return dq +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=4), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, + num_stages=1, num_warps=8), + ], + key=['H', 'N_CTX', 'BLOCK_DMODEL'], +) + @triton.jit -def _attn_bwd(Q, K, V, sm_scale, # - DO, # - DQ, DK, DV, # +def _attn_bwd(Q, K, V, sm_scale, + DO, + DQ, DK, DV, M, D, # shared by Q/K/V/DO. - stride_z, stride_h, stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M1: tl.constexpr, # - BLOCK_N1: tl.constexpr, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - BLK_SLICE_FACTOR: tl.constexpr, # - HEAD_DIM: tl.constexpr): + stride_z, stride_h, stride_tok, stride_d, + # H = 16, N_CTX = 1024 + H, N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr): LN2: tl.constexpr = 0.6931471824645996 # = ln(2) bhid = tl.program_id(2) @@ -337,58 +389,91 @@ def _attn_bwd(Q, K, V, sm_scale, # M += off_chz D += off_chz - # load scales - offs_k = tl.arange(0, HEAD_DIM) + offs_k = tl.arange(0, BLOCK_DMODEL) start_n = pid * BLOCK_N1 + # This assignment is important. It is what allows us to pick the diagonal + # blocks. Later, when we want to do the lower triangular, we update start_m + # after the first dkdv call. 
start_m = start_n MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR offs_n = start_n + tl.arange(0, BLOCK_N1) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) - v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + + # load K and V: they stay in SRAM throughout the inner loop for dkdv. + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) num_steps = BLOCK_N1 // MASK_BLOCK_M1 - dk, dv = _attn_bwd_dkdv(dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - stride_tok, stride_d, # - H, N_CTX, # - MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # - start_n, start_m, num_steps, # - MASK=True # + dk, dv = _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=True ) start_m += num_steps * MASK_BLOCK_M1 num_steps = (N_CTX - start_m) // BLOCK_M1 # Compute dK and dV for non-masked blocks. - dk, dv = _attn_bwd_dkdv( # - dk, dv, # - Q, k, v, sm_scale, # - DO, # - M, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M1, BLOCK_N1, HEAD_DIM, # - start_n, start_m, num_steps, # - MASK=False # + dk, dv = _attn_bwd_dkdv( + dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=False ) - dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dv_ptrs, dv) + DV_block_ptrs = tl.make_block_ptr( + base=DV, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1,0) + ) + tl.store(DV_block_ptrs, dv.to(tl.float16)) # Write back dK. 
dk *= sm_scale - dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d - tl.store(dk_ptrs, dk) + DK_block_ptrs = tl.make_block_ptr( + base=DK, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1,0) + ) + tl.store(DK_block_ptrs, dk.to(tl.float16)) # THIS BLOCK DOES DQ: start_m = pid * BLOCK_M2 @@ -397,9 +482,26 @@ def _attn_bwd(Q, K, V, sm_scale, # MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR offs_m = start_m + tl.arange(0, BLOCK_M2) - q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) + + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) m = tl.load(M + offs_m) m = m[:, None] @@ -410,29 +512,39 @@ def _attn_bwd(Q, K, V, sm_scale, # # not due to anything important. I just wanted to reuse the loop # structure for dK & dV above as much as possible. num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, # - do, m, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # - start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # - MASK=True # + dq = _attn_bwd_dq(dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, + MASK=True ) end_n -= num_steps * MASK_BLOCK_N2 # stage 2 num_steps = end_n // BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, # - do, m, D, # - stride_tok, stride_d, # - H, N_CTX, # - BLOCK_M2, BLOCK_N2, HEAD_DIM, # - start_m, end_n - num_steps * BLOCK_N2, num_steps, # - MASK=False # + dq = _attn_bwd_dq(dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * BLOCK_N2, num_steps, + MASK=False ) # Write back dQ. - dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) dq *= LN2 - tl.store(dq_ptrs, dq) + tl.store(DQ_block_ptr, dq.to(tl.float16)) + + +empty = torch.empty(128, device="cuda") class _attention(torch.autograd.Function): @@ -440,42 +552,56 @@ class _attention(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, causal, sm_scale): # shape constraints - HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] - # when v is in float8_e5m2 it is transposed. 
- HEAD_DIM_V = v.shape[-1] - assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V - assert HEAD_DIM_K in {16, 32, 64, 128, 256} - o = torch.empty_like(q) + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q, dtype=v.dtype) + if torch.version.hip is None: + BLOCK_M = 128 + BLOCK_N = 64 if Lk <= 64 else 32 + num_stages = 4 if Lk <= 64 else 3 + num_warps = 4 if Lk <= 64 else 8 + # Tuning for H100 + if torch.cuda.get_device_capability()[0] == 9: + num_warps = 8 + num_stages = 7 if Lk >= 64 else 3 stage = 3 if causal else 1 - extra_kern_args = {} - # Tuning for AMD target - if is_hip(): - waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 - extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} - - grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) - M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + grid = lambda META: ( + triton.cdiv(q.shape[2], META['BLOCK_M']), + q.shape[0] * q.shape[1], + 1 + ) + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) _attn_fwd[grid]( - q, k, v, sm_scale, M, o, # - q.stride(0), q.stride(1), q.stride(2), q.stride(3), # - k.stride(0), k.stride(1), k.stride(2), k.stride(3), # - v.stride(0), v.stride(1), v.stride(2), v.stride(3), # - o.stride(0), o.stride(1), o.stride(2), o.stride(3), # - q.shape[0], q.shape[1], # - N_CTX=q.shape[2], # - HEAD_DIM=HEAD_DIM_K, # - STAGE=stage, # - **extra_kern_args) + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + STAGE=stage, + ) + + ## restore the grid for bwd kernel + #best_config = _attn_fwd.get_best_config() + block_m = 64#int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) + grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale - ctx.HEAD_DIM = HEAD_DIM_K + ctx.BLOCK_DMODEL = Lk ctx.causal = causal return o @staticmethod def backward(ctx, do): + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 q, k, v, o, M = ctx.saved_tensors assert do.is_contiguous() assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() @@ -484,49 +610,96 @@ def backward(ctx, do): dv = torch.empty_like(v) BATCH, N_HEAD, N_CTX = q.shape[:3] PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 5 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) arg_k = k arg_k = arg_k * (ctx.sm_scale * RCP_LN2) - PRE_BLOCK = 128 assert N_CTX % PRE_BLOCK == 0 pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) delta = torch.empty_like(M) _attn_bwd_preprocess[pre_grid]( - o, do, # - delta, # - BATCH, N_HEAD, N_CTX, # - BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # + o, do, + delta, + BATCH, N_HEAD, N_CTX, + BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL + ) + grid = lambda META: ( + triton.cdiv(N_CTX, META['BLOCK_N1']), + 1, + BATCH * N_HEAD ) - grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) _attn_bwd[grid]( - q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # - M, delta, # - q.stride(0), q.stride(1), q.stride(2), 
q.stride(3), # - N_HEAD, N_CTX, # - BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # - BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # - HEAD_DIM=ctx.HEAD_DIM, # - num_warps=NUM_WARPS, # - num_stages=NUM_STAGES # + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, + M, delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + N_HEAD, N_CTX, + BLOCK_DMODEL=ctx.BLOCK_DMODEL ) return dq, dk, dv, None, None - attention = _attention.apply +name_to_torch_types = { + 'fp16': torch.float16, +} + +if TORCH_HAS_FP8E5B16: + name_to_torch_types['fp8'] = torch.float8_e5m2fnuz + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype', +[ (*shape, dtype) + for shape in [(4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (4, 48, 1024, 128), + (4, 48, 2048, 128), + (4, 48, 4096, 128)] + for dtype in ['fp16', 'fp8']]) +@pytest.mark.parametrize('causal', [False, True]) +def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype): + if dtype == 'fp8' and not TORCH_HAS_FP8E5B16: + pytest.skip("fp8 not supported") + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() -@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)]) -@pytest.mark.parametrize("causal", [True]) -def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): + q = q.to(name_to_torch_types[dtype]) + k = k.to(name_to_torch_types[dtype]) + sm_scale = 0.5 + dout = torch.randn_like(q, dtype=torch.float16) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + # triton implementation + tri_out = attention(q, k, v, causal, sm_scale) + # compare + atol = 1.4e-1 if dtype == 'fp8' else 1e-2 + rtol = 1e-2 if dtype == 'fp8' else 0 + torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', + [(4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + (1, 16, 8192, 64), + (1, 16, 1024, 64), + ]) +def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) - q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + causal = True + q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation @@ -535,28 +708,27 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): if causal: p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).half() - # p = torch.exp(p) ref_out = torch.matmul(p, v) ref_out.backward(dout) ref_dv, v.grad = 
v.grad.clone(), None ref_dk, k.grad = k.grad.clone(), None ref_dq, q.grad = q.grad.clone(), None - # triton implementation - tri_out = attention(q, k, v, causal, sm_scale).half() + # # triton implementation + tri_out = attention(q, k, v, causal, sm_scale) tri_out.backward(dout) tri_dv, v.grad = v.grad.clone(), None tri_dk, k.grad = k.grad.clone(), None tri_dq, q.grad = q.grad.clone(), None # compare - assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) - rtol = 0.0 - # Relative tolerance workaround for known hardware limitation of MI200 GPU. - # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": - rtol = 1e-2 - assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) - assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) - assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + if torch.version.hip is None: + torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0) + # The current block size for MI200 series is 64x64. This results in + # larger differences in float results due to rounding. + else: + torch.testing.assert_close(ref_dv, tri_dv, atol=5e-2, rtol=0) + torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2) + torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2) try: @@ -566,68 +738,69 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): except BaseException: HAS_FLASH = False -TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') -BATCH, N_HEADS, HEAD_DIM = 4, 32, 64 # vary seq length for fixed head and batch=4 configs = [] -for mode in ["fwd", "bwd"]: - for causal in [True, False]: - if mode == "bwd" and not causal: - continue - configs.append( - triton.testing.Benchmark( - x_names=["N_CTX"], - x_vals=[2**i for i in range(10, 15)], - line_arg="provider", - line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + - (["flash"] if HAS_FLASH else []), - line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + - (["Flash-2"] if HAS_FLASH else []), - styles=[("red", "-"), ("blue", "-"), ("green", "-")], - ylabel="ms", - plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", +for mode in ['fwd', 'bwd']: + for D_HEAD in [128, 64]: + for causal in [False, True]: + if mode == 'bwd' and causal == False: + continue + configs.append(triton.testing.Benchmark( + x_names=['BATCH', 'H', 'N_CTX'], + x_vals=[(4, 16, 1024), + (8, 16, 2048), + (4, 16, 4096), + (2, 16, 8192), + (1, 16, 16384), + (4, 48, 1024), + (4, 48, 2048), + (4, 48, 4096), + (4, 48, 8192), + (4, 48, 16384), + ], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}', args={ - "H": N_HEADS, - "BATCH": BATCH, - "HEAD_DIM": HEAD_DIM, - "mode": mode, - "causal": causal, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'causal': causal, }, )) @triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"): +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): assert mode in 
["fwd", "bwd"] warmup = 25 - rep = 100 - dtype = torch.float16 - if "triton" in provider: - q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) - if mode == "fwd" and "fp8" in provider: - q = q.to(torch.float8_e5m2) - k = k.to(torch.float8_e5m2) - v = v.permute(0, 1, 3, 2).contiguous() - v = v.permute(0, 1, 3, 2) - v = v.to(torch.float8_e5m2) - sm_scale = 1.3 + rep = 10 + # Bwd pass only supports causal=True right now + if mode == 'bwd': + causal = True + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = D_HEAD ** -0.5 fn = lambda: attention(q, k, v, causal, sm_scale) - if mode == "bwd": + if mode == 'bwd': o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, causal=causal) if mode == "bwd": o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 2 * flops_per_matmul if causal: total_flops *= 0.5 @@ -635,7 +808,5 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) return total_flops / ms * 1e-9 - -if __name__ == "__main__": - # only works on post-Ampere GPUs right now - bench_flash_attention.run(save_path=".", print_data=True) +# only works on post-Ampere GPUs right now +bench_flash_attention.run(save_path=".", print_data=True) From 550f3954f6570fdf00470c49c0c81529b9a72816 Mon Sep 17 00:00:00 2001 From: Joseph Groenenboom Date: Mon, 12 Aug 2024 11:45:59 -0500 Subject: [PATCH 3/3] revert changes to tutorial kernel --- python/tutorials/06-fused-attention.py | 753 ++++++++++--------------- 1 file changed, 291 insertions(+), 462 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index c661510f6b2f..e533576d467b 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -17,60 +17,57 @@ import triton import triton.language as tl -# Pick the fp8 data type -# AMD E4M3B8 -# Note: When picking this f8 data type, scaling is required when using f8 -# for the second gemm -#TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') - -# AMD E5M2B16 -TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, - K_block_ptr, V_block_ptr, - start_m, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, - N_CTX, - pre_load_v: tl.constexpr): +def _attn_fwd_inner(acc, l_i, m_i, 
q, # + K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr, fp8_v: tl.constexpr): # range of values handled by this stage if STAGE == 1: lo, hi = 0, start_m * BLOCK_M elif STAGE == 2: lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M lo = tl.multiple_of(lo, BLOCK_M) - K_block_ptr = tl.advance(K_block_ptr, (0, lo)) - V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # causal = False else: lo, hi = 0, N_CTX + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k) if STAGE == 2: mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = tl.where(mask, qk, float("-inf")) - qk += tl.dot(q, k) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load(V_block_ptr) + if fp8_v: + p = p.to(tl.float8e5) + else: + p = p.to(tl.float16) + acc = tl.dot(p, v, acc) # update m_i and l_i m_i = m_ij V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) @@ -78,72 +75,78 @@ def _attn_fwd_inner(acc, l_i, m_i, q, return acc, l_i, m_i -# We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting +# We don't run auto-tuning every time to keep the tutorial fast. Keeping # the code below and commenting out the equivalent parameters is convenient for # re-tuning. 
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4), - ], - key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'], -) +configs = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ + for BM in [64, 128]\ + for BN in [32, 64]\ + for s in ([1] if is_hip() else [3, 4, 7])\ + for w in [4, 8]\ +] + +def keep(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8: + return False + return True + +@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) @triton.jit -def _attn_fwd(Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - STAGE: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, +def _attn_fwd(Q, K, V, sm_scale, M, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr # ): + tl.static_assert(BLOCK_N <= HEAD_DIM) start_m = tl.program_id(0) off_hz = tl.program_id(1) - qvk_offset = off_hz * stride_qh + off_z = off_hz // H + off_h = off_hz % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh # block pointers Q_block_ptr = tl.make_block_ptr( base=Q + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(N_CTX, HEAD_DIM), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), + block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0), ) + v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) V_block_ptr = tl.make_block_ptr( base=V + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(N_CTX, HEAD_DIM), strides=(stride_vk, stride_vn), offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=v_order, ) K_block_ptr = tl.make_block_ptr( base=K + qvk_offset, - shape=(BLOCK_DMODEL, N_CTX), + shape=(HEAD_DIM, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), + block_shape=(HEAD_DIM, BLOCK_N), order=(0, 1), ) O_block_ptr = tl.make_block_ptr( base=Out + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), + shape=(N_CTX, HEAD_DIM), strides=(stride_om, stride_on), offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), + block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0), ) # initialize offsets @@ -152,96 
+155,80 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout q = tl.load(Q_block_ptr) - q = (q * qk_scale).to(q.dtype) # stage 1: off-band # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 4 - STAGE, offs_m, offs_n, N_CTX, - pre_load_v, + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # ) # stage 2: on-band if STAGE & 2: # barrier makes it easier for compielr to schedule the # two loops independently - tl.debug_barrier() - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 2, offs_m, offs_n, N_CTX, - pre_load_v, + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 # ) # epilogue - # write back m + m_i += tl.math.log2(l_i) acc = acc / l_i[:, None] m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i + tl.math.log2(l_i)) + tl.store(m_ptrs, m_i) tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + @triton.jit -def _attn_bwd_preprocess(O, DO, - Delta, - Z, H, N_CTX, - BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr +def _attn_bwd_preprocess(O, DO, # + Delta, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_hz = tl.program_id(1) - off_n = tl.arange(0, D_HEAD) - o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]) - do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + off_n = tl.arange(0, HEAD_DIM) + # load + o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) + do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) delta = tl.sum(o * do, axis=1) + # write-back tl.store(Delta + off_hz * N_CTX + off_m, delta) # The main inner-loop logic for computing dK and dV. @triton.jit -def _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, +def _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # # shared by Q/K/V/DO. - stride_tok, stride_d, - H, N_CTX, BLOCK_M1: tl.constexpr, - BLOCK_N1: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + stride_tok, stride_d, # + H, N_CTX, BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + HEAD_DIM: tl.constexpr, # # Filled in by the wrapper. 
- start_n, start_m, num_steps, + start_n, start_m, num_steps, # MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M1) offs_n = start_n + tl.arange(0, BLOCK_N1) - offs_k = tl.arange(0, BLOCK_DMODEL) - QT_block_ptr = tl.make_block_ptr( - base=Q, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_d, stride_tok), - offsets=(0, start_m), - block_shape=(BLOCK_DMODEL, BLOCK_M1), - order=(0,1) - ) - DO_block_ptr = tl.make_block_ptr( - base=DO, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_m, 0), - block_shape=(BLOCK_M1, BLOCK_DMODEL), - order=(1,0) - ) + offs_k = tl.arange(0, HEAD_DIM) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) curr_m = start_m step_m = BLOCK_M1 for blk_idx in range(num_steps): - qT = tl.load(QT_block_ptr) + qT = tl.load(qT_ptrs) # Load m before computing qk to reduce pipeline stall. offs_m = curr_m + tl.arange(0, BLOCK_M1) m = tl.load(M + offs_m) @@ -251,7 +238,7 @@ def _attn_bwd_dkdv(dk, dv, if MASK: mask = (offs_m[None, :] >= offs_n[:, None]) pT = tl.where(mask, pT, 0.0) - do = tl.load(DO_block_ptr) + do = tl.load(do_ptrs) # Compute dV. ppT = pT ppT = ppT.to(tl.float16) @@ -259,49 +246,35 @@ def _attn_bwd_dkdv(dk, dv, # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)) + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) dsT = pT * (dpT - Di[None, :]) dsT = dsT.to(tl.float16) dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m - QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) - DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok return dk, dv # the main inner-loop logic for computing dQ @triton.jit -def _attn_bwd_dq(dq, q, K, V, +def _attn_bwd_dq(dq, q, K, V, # do, m, D, # shared by Q/K/V/DO. - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2: tl.constexpr, - BLOCK_N2: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, # Filled in by the wrapper. - start_m, start_n, num_steps, + start_m, start_n, num_steps, # MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, BLOCK_DMODEL) - KT_block_ptr = tl.make_block_ptr( - base=K, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_d, stride_tok), - offsets=(0, start_n), - block_shape=(BLOCK_DMODEL, BLOCK_N2), - order=(0, 1) - ) - VT_block_ptr = tl.make_block_ptr( - base=V, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_d, stride_tok), - offsets=(0, start_n), - block_shape=(BLOCK_DMODEL, BLOCK_N2), - order=(0, 1) - ) + offs_k = tl.arange(0, HEAD_DIM) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d # D (= delta) is pre-divided by ds_scale. Di = tl.load(D + offs_m) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. @@ -309,7 +282,8 @@ def _attn_bwd_dq(dq, q, K, V, curr_n = start_n step_n = BLOCK_N2 for blk_idx in range(num_steps): - kT = tl.load(KT_block_ptr) + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) qk = tl.dot(q, kT) p = tl.math.exp2(qk - m) # Autoregressive masking. 
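The dkdv and dq inner loops implement the usual attention backward identities: Delta = rowsum(O * dO), dV = P^T dO, dP = dO V^T, dS = P * (dP - Delta), and dQ/dK follow from S = sm_scale * Q K^T. A small PyTorch check of those identities against autograd for one non-causal head; illustrative sketch, not part of the patch:

import torch

torch.manual_seed(0)
N, d = 8, 4
q, k, v = (torch.randn(N, d, requires_grad=True) for _ in range(3))
sm_scale = d ** -0.5

p = torch.softmax(sm_scale * (q @ k.T), dim=-1)
o = p @ v
do = torch.randn_like(o)
o.backward(do)

delta = (o * do).sum(dim=-1, keepdim=True)  # Delta = rowsum(O * dO)
dv = p.T @ do                               # dV = P^T dO
ds = p * (do @ v.T - delta)                 # dS = P * (dP - Delta)
dq = ds @ k * sm_scale
dk = ds.T @ q * sm_scale

for manual, ref in [(dv, v.grad), (dq, q.grad), (dk, k.grad)]:
    assert torch.allclose(manual, ref, atol=1e-5, rtol=1e-4)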
@@ -318,7 +292,6 @@ def _attn_bwd_dq(dq, q, K, V, mask = (offs_m[:, None] >= offs_n[None, :]) p = tl.where(mask, p, 0.0) # Compute dP and dS. - vT = tl.load(VT_block_ptr) dp = tl.dot(do, vT).to(tl.float32) ds = p * (dp - Di[:, None]) ds = ds.to(tl.float16) @@ -327,50 +300,25 @@ def _attn_bwd_dq(dq, q, K, V, dq += tl.dot(ds, tl.trans(kT)) # Increment pointers. curr_n += step_n - KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) - VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok return dq -@triton.autotune( - configs=[ - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=4), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2}, - num_stages=1, num_warps=8), - ], - key=['H', 'N_CTX', 'BLOCK_DMODEL'], -) - @triton.jit -def _attn_bwd(Q, K, V, sm_scale, - DO, - DQ, DK, DV, +def _attn_bwd(Q, K, V, sm_scale, # + DO, # + DQ, DK, DV, # M, D, # shared by Q/K/V/DO. - stride_z, stride_h, stride_tok, stride_d, - # H = 16, N_CTX = 1024 - H, N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M1: tl.constexpr, - BLOCK_N1: tl.constexpr, - BLOCK_M2: tl.constexpr, - BLOCK_N2: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr): + stride_z, stride_h, stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + HEAD_DIM: tl.constexpr): LN2: tl.constexpr = 0.6931471824645996 # = ln(2) bhid = tl.program_id(2) @@ -389,91 +337,58 @@ def _attn_bwd(Q, K, V, sm_scale, M += off_chz D += off_chz - offs_k = tl.arange(0, BLOCK_DMODEL) + # load scales + offs_k = tl.arange(0, HEAD_DIM) start_n = pid * BLOCK_N1 - # This assignment is important. It is what allows us to pick the diagonal - # blocks. Later, when we want to do the lower triangular, we update start_m - # after the first dkdv call. 
start_m = start_n MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR offs_n = start_n + tl.arange(0, BLOCK_N1) - dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - K_block_ptr = tl.make_block_ptr( - base=K, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_n, 0), - block_shape=(BLOCK_N1, BLOCK_DMODEL), - order=(1, 0), - ) - V_block_ptr = tl.make_block_ptr( - base=V, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_n, 0), - block_shape=(BLOCK_N1, BLOCK_DMODEL), - order=(1, 0), - ) - - # load K and V: they stay in SRAM throughout the inner loop for dkdv. - k = tl.load(K_block_ptr) - v = tl.load(V_block_ptr) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) num_steps = BLOCK_N1 // MASK_BLOCK_M1 - dk, dv = _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=True + dk, dv = _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=True # ) start_m += num_steps * MASK_BLOCK_M1 num_steps = (N_CTX - start_m) // BLOCK_M1 # Compute dK and dV for non-masked blocks. - dk, dv = _attn_bwd_dkdv( - dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=False + dk, dv = _attn_bwd_dkdv( # + dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1, BLOCK_N1, HEAD_DIM, # + start_n, start_m, num_steps, # + MASK=False # ) - DV_block_ptrs = tl.make_block_ptr( - base=DV, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_n, 0), - block_shape=(BLOCK_N1, BLOCK_DMODEL), - order=(1,0) - ) - tl.store(DV_block_ptrs, dv.to(tl.float16)) + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) # Write back dK. 
dk *= sm_scale - DK_block_ptrs = tl.make_block_ptr( - base=DK, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_n, 0), - block_shape=(BLOCK_N1, BLOCK_DMODEL), - order=(1,0) - ) - tl.store(DK_block_ptrs, dk.to(tl.float16)) + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) # THIS BLOCK DOES DQ: start_m = pid * BLOCK_M2 @@ -482,26 +397,9 @@ def _attn_bwd(Q, K, V, sm_scale, MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR offs_m = start_m + tl.arange(0, BLOCK_M2) - Q_block_ptr = tl.make_block_ptr( - base=Q, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_m, 0), - block_shape=(BLOCK_M2, BLOCK_DMODEL), - order=(1, 0) - ) - - DO_block_ptr = tl.make_block_ptr( - base=DO, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_m, 0), - block_shape=(BLOCK_M2, BLOCK_DMODEL), - order=(1, 0) - ) - q = tl.load(Q_block_ptr) - do = tl.load(DO_block_ptr) - dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) m = tl.load(M + offs_m) m = m[:, None] @@ -512,39 +410,29 @@ def _attn_bwd(Q, K, V, sm_scale, # not due to anything important. I just wanted to reuse the loop # structure for dK & dV above as much as possible. num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, - MASK=True + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # + MASK=True # ) end_n -= num_steps * MASK_BLOCK_N2 # stage 2 num_steps = end_n // BLOCK_N2 - dq = _attn_bwd_dq(dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * BLOCK_N2, num_steps, - MASK=False + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, BLOCK_N2, HEAD_DIM, # + start_m, end_n - num_steps * BLOCK_N2, num_steps, # + MASK=False # ) # Write back dQ. - DQ_block_ptr = tl.make_block_ptr( - base=DQ, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_tok, stride_d), - offsets=(start_m, 0), - block_shape=(BLOCK_M2, BLOCK_DMODEL), - order=(1, 0) - ) + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d dq *= LN2 - tl.store(DQ_block_ptr, dq.to(tl.float16)) - - -empty = torch.empty(128, device="cuda") + tl.store(dq_ptrs, dq) class _attention(torch.autograd.Function): @@ -552,56 +440,42 @@ class _attention(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, causal, sm_scale): # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - o = torch.empty_like(q, dtype=v.dtype) - if torch.version.hip is None: - BLOCK_M = 128 - BLOCK_N = 64 if Lk <= 64 else 32 - num_stages = 4 if Lk <= 64 else 3 - num_warps = 4 if Lk <= 64 else 8 - # Tuning for H100 - if torch.cuda.get_device_capability()[0] == 9: - num_warps = 8 - num_stages = 7 if Lk >= 64 else 3 + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. 
+ HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + o = torch.empty_like(q) stage = 3 if causal else 1 - grid = lambda META: ( - triton.cdiv(q.shape[2], META['BLOCK_M']), - q.shape[0] * q.shape[1], - 1 - ) - M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + extra_kern_args = {} + # Tuning for AMD target + if is_hip(): + waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 + extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} + + grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) _attn_fwd[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_DMODEL=Lk, - STAGE=stage, - ) - - ## restore the grid for bwd kernel - #best_config = _attn_fwd.get_best_config() - block_m = 64#int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) - grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) + q, k, v, sm_scale, M, o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + **extra_kern_args) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale - ctx.BLOCK_DMODEL = Lk + ctx.HEAD_DIM = HEAD_DIM_K ctx.causal = causal return o @staticmethod def backward(ctx, do): - if torch.version.hip is not None: - BLOCK = 64 - else: - BLOCK = 128 q, k, v, o, M = ctx.saved_tensors assert do.is_contiguous() assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() @@ -610,96 +484,49 @@ def backward(ctx, do): dv = torch.empty_like(v) BATCH, N_HEAD, N_CTX = q.shape[:3] PRE_BLOCK = 128 - NUM_WARPS, NUM_STAGES = 4, 1 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 + NUM_WARPS, NUM_STAGES = 4, 5 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 BLK_SLICE_FACTOR = 2 RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) arg_k = k arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 assert N_CTX % PRE_BLOCK == 0 pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) delta = torch.empty_like(M) _attn_bwd_preprocess[pre_grid]( - o, do, - delta, - BATCH, N_HEAD, N_CTX, - BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL - ) - grid = lambda META: ( - triton.cdiv(N_CTX, META['BLOCK_N1']), - 1, - BATCH * N_HEAD + o, do, # + delta, # + BATCH, N_HEAD, N_CTX, # + BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) _attn_bwd[grid]( - q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, - M, delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - N_HEAD, N_CTX, - BLOCK_DMODEL=ctx.BLOCK_DMODEL + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # + M, delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + N_HEAD, N_CTX, # + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + HEAD_DIM=ctx.HEAD_DIM, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES # ) return dq, dk, dv, None, None + 
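backward() pre-scales K by sm_scale * RCP_LN2 so the kernels can work entirely in base 2, and the epilogues compensate with dq *= LN2 and dk *= sm_scale. A NumPy sketch of that bookkeeping for a single head; illustrative only, not part of the patch (RCP_LN2/LN2 mirror the kernel constants):

import numpy as np

rng = np.random.default_rng(0)
N, d = 8, 4
q, k, v = rng.standard_normal((3, N, d))
do = rng.standard_normal((N, d))
sm_scale = d ** -0.5
RCP_LN2, LN2 = 1.4426950408889634, 0.6931471824645996

# reference gradients in base e
p = np.exp(sm_scale * (q @ k.T))
p /= p.sum(axis=1, keepdims=True)
delta = ((p @ v) * do).sum(axis=1, keepdims=True)
ds = p * (do @ v.T - delta)
dq_ref, dk_ref = ds @ k * sm_scale, ds.T @ q * sm_scale

# kernel-style bookkeeping: pre-scaled K, base-2 probabilities, post-scaled outputs
arg_k = k * (sm_scale * RCP_LN2)
qk = q @ arg_k.T                                      # base-2 logits
M = np.log2(np.exp2(qk).sum(axis=1, keepdims=True))   # m_i + log2(l_i) from the forward pass
p2 = np.exp2(qk - M)                                  # already-normalized probabilities
ds2 = p2 * (do @ v.T - delta)
dq = ds2 @ arg_k * LN2        # matches `dq *= LN2`
dk = ds2.T @ q * sm_scale     # matches `dk *= sm_scale`

assert np.allclose(dq, dq_ref) and np.allclose(dk, dk_ref)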
attention = _attention.apply -name_to_torch_types = { - 'fp16': torch.float16, -} - -if TORCH_HAS_FP8E5B16: - name_to_torch_types['fp8'] = torch.float8_e5m2fnuz - -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype', -[ (*shape, dtype) - for shape in [(4, 48, 1024, 64), - (4, 48, 2048, 64), - (4, 48, 4096, 64), - (4, 48, 1024, 128), - (4, 48, 2048, 128), - (4, 48, 4096, 128)] - for dtype in ['fp16', 'fp8']]) -@pytest.mark.parametrize('causal', [False, True]) -def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype): - if dtype == 'fp8' and not TORCH_HAS_FP8E5B16: - pytest.skip("fp8 not supported") - torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - q = q.to(name_to_torch_types[dtype]) - k = k.to(name_to_torch_types[dtype]) - sm_scale = 0.5 - dout = torch.randn_like(q, dtype=torch.float16) - # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale - if causal: - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, v) - # triton implementation - tri_out = attention(q, k, v, causal, sm_scale) - # compare - atol = 1.4e-1 if dtype == 'fp8' else 1e-2 - rtol = 1e-2 if dtype == 'fp8' else 0 - torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol) - - -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', - [(4, 48, 1024, 64), - (4, 48, 2048, 64), - (4, 48, 4096, 64), - (1, 16, 8192, 64), - (1, 16, 1024, 64), - ]) -def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): +@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)]) +@pytest.mark.parametrize("causal", [True]) +def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): torch.manual_seed(20) - causal = True - q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation @@ -708,27 +535,28 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): if causal: p[:, :, M == 0] = float("-inf") p = torch.softmax(p.float(), dim=-1).half() + # p = torch.exp(p) ref_out = torch.matmul(p, v) ref_out.backward(dout) ref_dv, v.grad = v.grad.clone(), None ref_dk, k.grad = k.grad.clone(), None ref_dq, q.grad = q.grad.clone(), None - # # triton implementation - tri_out = attention(q, k, v, causal, sm_scale) + # triton implementation + tri_out = attention(q, k, v, causal, sm_scale).half() tri_out.backward(dout) tri_dv, v.grad = v.grad.clone(), None tri_dk, k.grad = k.grad.clone(), None tri_dq, q.grad = q.grad.clone(), None # compare - 
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) - if torch.version.hip is None: - torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0) - # The current block size for MI200 series is 64x64. This results in - # larger differences in float results due to rounding. - else: - torch.testing.assert_close(ref_dv, tri_dv, atol=5e-2, rtol=0) - torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2) - torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2) + assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) + rtol = 0.0 + # Relative tolerance workaround for known hardware limitation of MI200 GPU. + # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": + rtol = 1e-2 + assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) + assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol) + assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol) try: @@ -738,69 +566,68 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): except BaseException: HAS_FLASH = False +TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') +BATCH, N_HEADS, HEAD_DIM = 4, 32, 64 # vary seq length for fixed head and batch=4 configs = [] -for mode in ['fwd', 'bwd']: - for D_HEAD in [128, 64]: - for causal in [False, True]: - if mode == 'bwd' and causal == False: - continue - configs.append(triton.testing.Benchmark( - x_names=['BATCH', 'H', 'N_CTX'], - x_vals=[(4, 16, 1024), - (8, 16, 2048), - (4, 16, 4096), - (2, 16, 8192), - (1, 16, 16384), - (4, 48, 1024), - (4, 48, 2048), - (4, 48, 4096), - (4, 48, 8192), - (4, 48, 16384), - ], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}', +for mode in ["fwd", "bwd"]: + for causal in [True, False]: + if mode == "bwd" and not causal: + continue + configs.append( + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 15)], + line_arg="provider", + line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + + (["flash"] if HAS_FLASH else []), + line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", args={ - 'D_HEAD': D_HEAD, - 'dtype': torch.float16, - 'mode': mode, - 'causal': causal, + "H": N_HEADS, + "BATCH": BATCH, + "HEAD_DIM": HEAD_DIM, + "mode": mode, + "causal": causal, }, )) @triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): +def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"): assert mode in ["fwd", "bwd"] warmup = 25 - rep = 10 - # Bwd pass only supports causal=True right now - if mode == 'bwd': - causal = True - if provider == "triton": - q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", 
requires_grad=True) - sm_scale = D_HEAD ** -0.5 + rep = 100 + dtype = torch.float16 + if "triton" in provider: + q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + if mode == "fwd" and "fp8" in provider: + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2).contiguous() + v = v.permute(0, 1, 3, 2) + v = v.to(torch.float8_e5m2) + sm_scale = 1.3 fn = lambda: attention(q, k, v, causal, sm_scale) - if mode == 'bwd': + if mode == "bwd": o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, causal=causal) if mode == "bwd": o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM total_flops = 2 * flops_per_matmul if causal: total_flops *= 0.5 @@ -808,5 +635,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) return total_flops / ms * 1e-9 -# only works on post-Ampere GPUs right now -bench_flash_attention.run(save_path=".", print_data=True) + +if __name__ == "__main__": + # only works on post-Ampere GPUs right now + bench_flash_attention.run(save_path=".", print_data=True)
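The benchmark converts the measured time into TFLOP/s: each of the two forward GEMMs costs 2 * BATCH * H * N_CTX^2 * HEAD_DIM FLOPs, causal masking halves the useful work, and backward is counted as 2.5x forward (2.0 for the backward GEMMs plus 0.5 for recomputation). A small helper reproducing that arithmetic; illustrative only, not part of the patch, and the 1.0 ms timing below is purely hypothetical:

def attention_tflops(batch, heads, n_ctx, head_dim, ms, causal=True, mode="fwd"):
    flops_per_matmul = 2.0 * batch * heads * n_ctx * n_ctx * head_dim
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == "bwd":
        total_flops *= 2.5  # 2.0 (bwd) + 0.5 (recompute)
    return total_flops / ms * 1e-9  # ms is milliseconds, so this is TFLOP/s

# e.g. a causal forward at BATCH=4, H=32, HEAD_DIM=64, N_CTX=4096 timed at 1.0 ms:
print(attention_tflops(4, 32, 4096, 64, 1.0))  # ~274.9 TFLOP/s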