
Aot change merge #549

Draft · wants to merge 4 commits into triton-mlir
Conversation

groenenboomj

Merge in AOTriton backwards kernel changes

Bring in changes from AOT, unedited.

@@ -544,132 +591,198 @@ def _bwd_kernel_dk_dv(
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_ok,

The dkdv kernel needs strides from do, dk and dv (o's strides are not used, contradicting its name).
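A sketch of what the corrected parameter list for _bwd_kernel_dk_dv could look like; the stride_do*/stride_dk*/stride_dv* names are illustrative, not taken from this PR:

# hypothetical replacement for the unused o strides in _bwd_kernel_dk_dv
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_doz, stride_doh, stride_dom, stride_dok,  # strides of do
stride_dkz, stride_dkh, stride_dkn, stride_dkk,  # strides of dk
stride_dvz, stride_dvh, stride_dvk, stride_dvn,  # strides of dv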

    q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero")
    do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero")
else:
    q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")

Needs one more level of branching here.
However, if we are aiming for performance, we should consider commenting out the boundary_check for now; a sketch follows.
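A sketch of the extra branching level; PADDED_SEQ and PADDED_HEAD are illustrative tl.constexpr flags, not names from this PR:

if PADDED_SEQ and PADDED_HEAD:
    q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")
    do = tl.load(DO_block_ptr, boundary_check=(0, 1), padding_option="zero")
elif PADDED_SEQ:
    q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
    do = tl.load(DO_block_ptr, boundary_check=(0,), padding_option="zero")
elif PADDED_HEAD:
    q = tl.load(Q_block_ptr, boundary_check=(1,), padding_option="zero")
    do = tl.load(DO_block_ptr, boundary_check=(1,), padding_option="zero")
else:  # nothing can run past the tensor, so skip the checks entirely
    q = tl.load(Q_block_ptr)
    do = tl.load(DO_block_ptr)

Taking the unchecked else branch whenever no padding is in play is also what would address the performance concern.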

@@ -680,82 +793,118 @@ def _bwd_kernel_dq(
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
seqlen_q, seqlen_k, dropout_p, philox_seed, philox_offset_base,
stride_oz, stride_oh, stride_om, stride_ok,

Similar to dkdv, the dq kernel needs strides from do and dq (again, not the strides from o).
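Under the same illustrative naming assumption as the dkdv sketch above, _bwd_kernel_dq would take:

stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_doz, stride_doh, stride_dom, stride_dok,  # strides of do
stride_dqz, stride_dqh, stride_dqm, stride_dqk,  # strides of dq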

strides=(stride_qm, stride_qk),
offsets=(start_m, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty))
tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty), boundary_check=(0,1))

tl.store also causes some performance penalty with boundary checks, although it shouldn't.
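One way to sidestep the penalty, sketched with the same illustrative PADDED_* flags as above, is to boundary-check the store only when a block can actually run past the tensor:

if PADDED_SEQ or PADDED_HEAD:
    tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty), boundary_check=(0, 1))
else:
    tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty))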

q, k, v, ctx.sm_scale,
o, do_scaled,
dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),

One more reminder: o's strides are not used by the backward kernels.
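On the launcher side, a sketch of passing the gradient strides instead; do_scaled, dk and dv come from the surrounding call, but the replacement itself is only illustrative:

q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
do_scaled.stride(0), do_scaled.stride(1), do_scaled.stride(2), do_scaled.stride(3),  # replaces o.stride(...)
dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),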

@@ -893,28 +1046,41 @@ def backward(ctx, do, _):
seqlen_q = q.shape[2]

Remove these two assertions above, since they are not needed anymore.

q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
dropout_p = 0

Why is dropout_p hard-coded to 0 here? That leaves the dropout path untested.
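If dropout is meant to be covered, a sketch of sweeping it in the test instead of hard-coding 0; the parametrize axis, the 0.5 value, and the test name are illustrative:

@pytest.mark.parametrize('dropout_p', [0.0, 0.5])  # hypothetical: exercise the dropout path too
def test_op_bwd(Z, H, N_CTX, D_HEAD, causal, dropout_p, dtype=torch.float16):
    ...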

@@ -1186,45 +1361,88 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16
(4, 48, 2048, 64),
(4, 48, 4096, 64),
(1, 16, 8192, 64),
(1, 16, 128, 32),

One major concern is the unit-test coverage here.
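For example, a few additional (Z, H, N_CTX, D_HEAD) shapes that would widen coverage beyond power-of-two sizes; these are illustrative only and not part of this PR:

(1, 8, 577, 64),     # non-power-of-two seqlen, exercises the seq boundary checks
(1, 16, 1021, 128),  # prime seqlen with a larger head dim
(3, 5, 2048, 40),    # non-power-of-two head dim, exercises the padded-head path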
