diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index 988438340abe..58cf45024dcd 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -316,7 +316,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
                       num_warps=4),
     ],
     key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'],
-    use_cuda_graph=True,
+    #use_cuda_graph=True,
 )
 @triton.jit
 def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz,
@@ -632,14 +632,14 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D,
         do = tl.load(DO_block_ptr)
         # Compute dV.
         ppT = pT
-        ppT = ppT.to(tl.float16)
+        ppT = ppT.to(do.dtype)
         dv += tl.dot(ppT, do)
         # D (= delta) is pre-divided by ds_scale.
         Di = tl.load(D + offs_m)
         # Compute dP and dS.
         dpT = tl.dot(v, tl.trans(do))
         dsT = pT * (dpT - Di[None, :])
-        dsT = dsT.to(tl.float16)
+        dsT = dsT.to(qT.dtype)
         dk += tl.dot(dsT, tl.trans(qT))
         # Increment pointers.
         curr_m += step_m
@@ -685,7 +685,7 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope,
         vT = tl.load(VT_block_ptr)
         dp = tl.dot(do, vT).to(tl.float32)
         ds = p * (dp - Di[:, None])
-        ds = ds.to(tl.float16)
+        ds = ds.to(kT.dtype)
         # Compute dQ.0.
         # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
         dq += tl.dot(ds, tl.trans(kT))
@@ -695,13 +695,12 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope,
         VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
     return dq
 
-
 @triton.jit
 def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D,
               # shared by Q/K/V/DO.
              stride_z, stride_h, stride_tok, stride_d,
               # H = 16, N_CTX = 1024
-              H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr,
+              H, N_CTX, CAUSAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr,
               BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr,
               USE_ALIBI: tl.constexpr):
     LN2: tl.constexpr = 0.6931471824645996  # = ln(2)
@@ -765,14 +764,14 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D,
     # compute dK and dV for blocks close to the diagonal that need to be masked
     num_steps = BLOCK_N1 // MASK_BLOCK_M1
     dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX,
-                               MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True)
+                               MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=CAUSAL)
 
     # compute dK and dV for blocks that don't need masking further from the diagonal
     start_m += num_steps * MASK_BLOCK_M1
 
     num_steps = (N_CTX - start_m) // BLOCK_M1
     dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX,
-                               BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False)
+                               BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=CAUSAL)
 
     DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d),
                                       offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0))
@@ -943,6 +942,7 @@ def backward(ctx, do, _):
             BLOCK = 64
         else:
             BLOCK = 128
+        num_stages = 1
         q, k, v, o, M = ctx.saved_tensors
         assert do.is_contiguous()
         assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
@@ -999,6 +999,7 @@ def backward(ctx, do, _):
            q.stride(3),
            N_HEAD,
            N_CTX,
+           CAUSAL=ctx.causal,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL,
            BLOCK_M1=BLOCK_M1,
            BLOCK_N1=BLOCK_N1,
@@ -1006,6 +1007,7 @@ def backward(ctx, do, _):
            BLOCK_N2=BLOCK_N2,
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,
            USE_ALIBI=False if ctx.alibi_slopes is None else True,
+           num_stages=1,
        )
 
        return dq, dk, dv, None, None
@@ -1259,91 +1261,86 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16
     #(1, 16, 8192, 63),
     #(1, 16, 1022, 64),
 ])
-@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None])
-@pytest.mark.parametrize('torch_sdpa_test', [False, True])
-@pytest.mark.parametrize('causal', [True])
+@pytest.mark.parametrize('causal', [False, True])
 @pytest.mark.parametrize('use_alibi', [False, True])
-def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi,
-                dtype=torch.float16):
-    pytest.skip()
-    torch.manual_seed(20)
-    if qseqlen_not_equal_kseqlen is not None:
-        seqlen_q = qseqlen_not_equal_kseqlen
-    else:
-        seqlen_q = N_CTX
-    seqlen_k = N_CTX
-
-    if causal and ((N_CTX - 1) & N_CTX):
-        pytest.skip()
-    if causal and seqlen_q != seqlen_k:
-        pytest.skip()
-
-    sm_scale = D_HEAD**-0.5
-    input_metadata = MetaData(sm_scale=sm_scale)
-    input_metadata.max_seqlens_q = seqlen_q
-    input_metadata.max_seqlens_k = seqlen_k
-
-    dropout_p = 0
-    q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
-    k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
-    v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
-    o = torch.empty_like(q)
-
+@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize('layout', ['bhsd'])
+def test_op_bwd(Z, H, N_CTX, D_HEAD, causal, use_alibi,
+                layout, dtype):
+    torch.manual_seed(20)
+
+    N_CTX_Q = N_CTX_K = N_CTX
+    HQ = HK = H
+
+    q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout)
+    dout = torch.randn_like(q)
+
     if causal:
         input_metadata.need_causal()
-    if use_alibi and not torch_sdpa_test:
+
+    if use_alibi:
         # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n)
-        alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32,
+        alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32,
                                     device="cuda").repeat(Z, 1)
-        input_metadata.need_alibi(alibi_slopes, Z, H)
-    dout = torch.randn_like(q)
-    # reference implementation
-    if torch_sdpa_test:
-        ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p,
-                                                                                  is_causal=causal, scale=sm_scale,
-                                                                                  dropout_mask=None)
-        ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype))
-        ref_dv, v.grad = v.grad.clone(), None
-        ref_dk, k.grad = k.grad.clone(), None
-        ref_dq, q.grad = q.grad.clone(), None
+        input_metadata.need_alibi(alibi_slopes, Z, HQ)
     else:
-        M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda"))
-        p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
-        if use_alibi:
-            p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX)
-        if causal:
-            p[:, :, M == 0] = float("-inf")
+        alibi_slopes = None
 
-        p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype)
-        ref_out = torch.matmul(p, v)
-        ref_out.backward(dout)
-        ref_dv, v.grad = v.grad.clone(), None
-        ref_dk, k.grad = k.grad.clone(), None
-        ref_dq, q.grad = q.grad.clone(), None
+    o = torch.empty_like(q)
 
-    # # triton implementation
+    # triton implementation
     tri_out, _ = attention(q, k, v, o,
                            input_metadata)
     tri_out.backward(dout)
     tri_dv, v.grad = v.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
     tri_dq, q.grad = q.grad.clone(), None
-    # test
-    #print("reference")
-    #print(ref_dv)
-    #print("tri")
-    #print(tri_dv)
+
+    # Transpose here if layout is bshd so we have same reference code for all layouts
+    if layout == 'bshd':
+        q = q.transpose(1, 2).clone()
+        k = k.transpose(1, 2).clone()
+        v = v.transpose(1, 2).clone()
+    # Replicate K and V if using MQA/GQA
+    if HQ != HK:
+        k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
+                   k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3])
+        v = v.view(v.shape[0], v.shape[1], -1, v.shape[2],
+                   v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3])
+
+    scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale
+    if causal:
+        mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q)
+        scores[:, :, mask == 0] = float("-inf")
+    if use_alibi:
+        scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K)
+
+    p = torch.softmax(scores, dim=-1)
+    if causal:
+        # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into
+        # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix
+        # this by converting the NaNs to 0s, which is what they should be out of the softmax.
+        nan_mask = torch.isnan(p)
+        p = torch.where(nan_mask == 1, 0, p)
+    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v)
     # compare
-    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
+    if layout == 'bshd':
+        ref_out = ref_out.transpose(1, 2).clone()
+
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
+
+    torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
+
     # The current block size for MI200 series is 64x64. This results in
     # larger differences in float results due to rounding.
     if dtype == torch.bfloat16:
-        ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
+        ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)
     if dtype == torch.float32:
-        ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
+        ATOL = 1e-3 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)
     else:
-        ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
+        ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)
 
     RTOL = 0
@@ -1351,7 +1348,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sd
     torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL)
     torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL)
 
-
 def nonvarlen_benchmark_configs():
     configs = [
         (16, 16, 16, 1024, 1024),
@@ -1396,6 +1392,15 @@ def varlen_benchmark_configs():
     ]
     return configs
 
 
+def nonvarlen_backward_benchmark_configs():
+    configs = [(16, 16, 16, 1024, 1024),
+               (8, 16, 16, 2048, 2048),
+               (4, 16, 16, 4096, 4096),
+               (2, 16, 16, 8192, 8192),
+               (1, 16, 16, 16384, 16384),
+               (2, 48, 48, 1024, 1024),
+               ]
+    return configs
 
 def run_benchmark(custom, args):
@@ -1403,7 +1408,7 @@ def run_benchmark(custom, args):
     hk = args.hq if not args.hk else args.hk
     sk = args.sq if not args.sk else args.sk
     head_size = 128 if not args.d else args.d
-    mode = 'fwd'
+    mode = args.direction
     x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
     causal = args.causal
     varlen = args.layout == 'thd'
@@ -1413,6 +1418,8 @@ def run_benchmark(custom, args):
     else:
         if varlen:
             x_vals_list = varlen_benchmark_configs()
+        elif mode == 'bwd':
+            x_vals_list = nonvarlen_backward_benchmark_configs()
         else:
            x_vals_list = nonvarlen_benchmark_configs()
     print_time = args.return_time
@@ -1436,10 +1443,6 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
     # bias = None
     # bias = None
 
-    # Bwd pass only supports causal=True right now
-    if mode == 'bwd':
-        causal = True
-
     flops_per_matmul = 0
     if varlen:
         q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype,
@@ -1502,6 +1505,7 @@ def parse_args():
     parser.add_argument("-dtype", default='fp16')
     parser.add_argument("-return_time", action='store_true', default=False)
     parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts())
+    parser.add_argument("-direction", default='fwd')
     return parser.parse_args()
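
For context, the reference path in the reworked test_op_bwd above uses a bottom-right-aligned causal mask (diagonal=N_CTX_K - N_CTX_Q) and zeroes out softmax rows that end up fully masked. Below is a minimal standalone sketch of just that logic, with illustrative shapes, running on CPU; it is not part of the patch.

# Standalone illustration (not part of the patch) of the reference masking
# used in test_op_bwd: bottom-right-aligned causal mask plus NaN cleanup for
# rows that are fully masked when N_CTX_Q > N_CTX_K.
import torch

N_CTX_Q, N_CTX_K, D_HEAD = 4, 2, 8
q = torch.randn(1, 1, N_CTX_Q, D_HEAD)
k = torch.randn(1, 1, N_CTX_K, D_HEAD)
v = torch.randn(1, 1, N_CTX_K, D_HEAD)

scores = torch.einsum('bhqd,bhkd->bhqk', q, k) * D_HEAD**-0.5
# diagonal=N_CTX_K - N_CTX_Q aligns the causal mask to the bottom-right corner,
# so the last query sees all keys while the earliest queries may see none.
mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K), diagonal=N_CTX_K - N_CTX_Q)
scores[:, :, mask == 0] = float("-inf")

p = torch.softmax(scores, dim=-1)
# Rows that were entirely -inf come out of the softmax as NaN; zero them out,
# which is the value they should logically have.
p = torch.where(torch.isnan(p), torch.zeros_like(p), p)
ref_out = torch.einsum('bhqk,bhkd->bhqd', p, v)
print(ref_out.shape)  # torch.Size([1, 1, 4, 8])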