Groenenboomj/fixes causal #575

Open
wants to merge 3 commits into main_perf

Changes from all commits
160 changes: 82 additions & 78 deletions python/perf-kernels/flash-attention.py
@@ -316,7 +316,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
num_warps=4),
],
key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'],
use_cuda_graph=True,
#use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz,
@@ -632,14 +632,14 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D,
do = tl.load(DO_block_ptr)
# Compute dV.
ppT = pT
ppT = ppT.to(tl.float16)
ppT = ppT.to(do.dtype)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(tl.float16)
dsT = dsT.to(qT.dtype)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
@@ -685,7 +685,7 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope,
vT = tl.load(VT_block_ptr)
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - Di[:, None])
ds = ds.to(tl.float16)
ds = ds.to(kT.dtype)
# Compute dQ.0.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
@@ -695,13 +695,12 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope,
VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
return dq


@triton.jit
def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D,
# shared by Q/K/V/DO.
stride_z, stride_h, stride_tok, stride_d,
# H = 16, N_CTX = 1024
H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr,
H, N_CTX, CAUSAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr,
BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr):
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)

@@ -765,14 +764,14 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D,
# compute dK and dV for blocks close to the diagonal that need to be masked
num_steps = BLOCK_N1 // MASK_BLOCK_M1
dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX,
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True)
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=CAUSAL)

# compute dK and dV for blocks that don't need masking further from the diagonal
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1

dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX,
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False)
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=CAUSAL)

DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d),
offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0))
@@ -943,6 +942,7 @@ def backward(ctx, do, _):
BLOCK = 64
else:
BLOCK = 128
num_stages = 1
q, k, v, o, M = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
@@ -999,13 +999,15 @@ def backward(ctx, do, _):
q.stride(3),
N_HEAD,
N_CTX,
CAUSAL=ctx.causal,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
BLOCK_M1=BLOCK_M1,
BLOCK_N1=BLOCK_N1,
BLOCK_M2=BLOCK_M2,
BLOCK_N2=BLOCK_N2,
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,
USE_ALIBI=False if ctx.alibi_slopes is None else True,
num_stages=1,
)

return dq, dk, dv, None, None
@@ -1259,99 +1261,93 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16
#(1, 16, 8192, 63),
Collaborator comment:

We should test seqlens that are small, e.g. 1, 2, 4, 16, 32, 64, 128, 256, etc.
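A minimal sketch of what that extra coverage could look like, assuming the same (Z, H, N_CTX, D_HEAD) tuple shape as the parametrize list below; the batch, head-count, and head-size values are only illustrative and would simply be appended to the existing list:

# Hypothetical additional (Z, H, N_CTX, D_HEAD) cases covering small sequence
# lengths; the non-seqlen dimensions mirror the visible entries.
SMALL_SEQLEN_CASES = [
    (1, 16, 1, 64),
    (1, 16, 2, 64),
    (1, 16, 4, 64),
    (1, 16, 16, 64),
    (1, 16, 32, 64),
    (1, 16, 64, 64),
    (1, 16, 128, 64),
    (1, 16, 256, 64),
]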

#(1, 16, 1022, 64),
])
@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None])
@pytest.mark.parametrize('torch_sdpa_test', [False, True])
@pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('use_alibi', [False, True])
def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi,
dtype=torch.float16):
pytest.skip()
torch.manual_seed(20)
if qseqlen_not_equal_kseqlen is not None:
seqlen_q = qseqlen_not_equal_kseqlen
else:
seqlen_q = N_CTX
seqlen_k = N_CTX

if causal and ((N_CTX - 1) & N_CTX):
pytest.skip()
if causal and seqlen_q != seqlen_k:
pytest.skip()

sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.max_seqlens_q = seqlen_q
input_metadata.max_seqlens_k = seqlen_k

dropout_p = 0
q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
o = torch.empty_like(q)

@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('layout', ['bhsd'])
def test_op_bwd(Z, H, N_CTX, D_HEAD, causal, use_alibi,
layout, dtype):
torch.manual_seed(20)

N_CTX_Q = N_CTX_K = N_CTX
HQ = HK = H

q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout)
dout = torch.randn_like(q)

if causal:
input_metadata.need_causal()

if use_alibi and not torch_sdpa_test:
if use_alibi:
# for n heads, the set of slopes is the geometric sequence that starts at 2^(-8/n)
alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32,
alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32,
device="cuda").repeat(Z, 1)
input_metadata.need_alibi(alibi_slopes, Z, H)
dout = torch.randn_like(q)
# reference implementation
if torch_sdpa_test:
ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p,
is_causal=causal, scale=sm_scale,
dropout_mask=None)
ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype))
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
input_metadata.need_alibi(alibi_slopes, Z, HQ)
else:
M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if use_alibi:
p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX)
if causal:
p[:, :, M == 0] = float("-inf")
alibi_slopes = None

p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
o = torch.empty_like(q)

# # triton implementation
# triton implementation
tri_out, _ = attention(q, k, v, o, input_metadata)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# test
#print("reference")
#print(ref_dv)
#print("tri")
#print(tri_dv)

# Transpose here if layout is bshd so we have the same reference code for all layouts
if layout == 'bshd':
q = q.transpose(1, 2).clone()
k = k.transpose(1, 2).clone()
v = v.transpose(1, 2).clone()
# Replicate K and V if using MQA/GQA
if HQ != HK:
k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3])
v = v.view(v.shape[0], v.shape[1], -1, v.shape[2],
v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3])

scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale
if causal:
mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q)
scores[:, :, mask == 0] = float("-inf")
if use_alibi:
scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K)

p = torch.softmax(scores, dim=-1)
if causal:
# If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into
# the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix
# this by converting the NaNs to 0s, which is what they should be out of the softmax.
nan_mask = torch.isnan(p)
p = torch.where(nan_mask == 1, 0, p)
ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v)
# compare
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
if layout == 'bshd':
ref_out = ref_out.transpose(1, 2).clone()
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None

torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)

# The current block size for the MI200 series is 64x64. This results in
# larger differences in float results due to rounding.

if dtype == torch.bfloat16:
ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)
if dtype == torch.float32:
ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
ATOL = 1e-3 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)
else:
ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)

RTOL = 0

torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL)
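For intuition, a quick worked example of the tolerance scaling above, using one of the benchmarked shapes (a hypothetical standalone check, not part of the test itself):

# Worked example of the ATOL formula for N_CTX_Q=1024, D_HEAD=64 (fp16/bf16 path):
#   ATOL = 1e-1 * max(1.0, (1024 + 64) / 64.0) = 1e-1 * 17.0 = 1.7
N_CTX_Q, D_HEAD = 1024, 64
ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)
assert abs(ATOL - 1.7) < 1e-12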


def nonvarlen_benchmark_configs():
configs = [
(16, 16, 16, 1024, 1024),
@@ -1396,14 +1392,23 @@ def varlen_benchmark_configs():
]
return configs

def nonvarlen_backward_benchmark_configs():
configs = [(16, 16, 16, 1024, 1024),
(8, 16, 16, 2048, 2048),
(4, 16, 16, 4096, 4096),
(2, 16, 16, 8192, 8192),
(1, 16, 16, 16384, 16384),
(2, 48, 48, 1024, 1024),
]
return configs

def run_benchmark(custom, args):

dtype = arg_to_torch_dtype[args.dtype]
hk = args.hq if not args.hk else args.hk
sk = args.sq if not args.sk else args.sk
head_size = 128 if not args.d else args.d
mode = 'fwd'
mode = args.direction
x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
causal = args.causal
varlen = args.layout == 'thd'
@@ -1413,6 +1418,8 @@ def run_benchmark(custom, args):
else:
if varlen:
x_vals_list = varlen_benchmark_configs()
elif mode == 'bwd':
x_vals_list = nonvarlen_backward_benchmark_configs()
else:
x_vals_list = nonvarlen_benchmark_configs()
print_time = args.return_time
Expand All @@ -1436,10 +1443,6 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
# bias = None
# bias = None

# Bwd pass only supports causal=True right now
if mode == 'bwd':
causal = True

flops_per_matmul = 0
if varlen:
q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype,
@@ -1502,6 +1505,7 @@ def parse_args():
parser.add_argument("-dtype", default='fp16')
parser.add_argument("-return_time", action='store_true', default=False)
parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts())
parser.add_argument("-direction", default='fwd')
return parser.parse_args()

