export recent AMD-focussed triton_splitk work
ghstack-source-id: 8f2dbaabff2a5935c45b58bad4de3de803bf0fc1
Pull Request resolved: fairinternal/xformers#1238

__original_commit__ = fairinternal/xformers@f4cbd36
bottler authored and xFormers Bot committed Oct 15, 2024
1 parent 3a25236 commit 916611d
Showing 2 changed files with 42 additions and 29 deletions.
16 changes: 13 additions & 3 deletions tests/test_mem_eff_attention.py
@@ -548,8 +548,6 @@ def test_logsumexp(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
         kv,
     ) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
 
-    if op is fmha.ck.FwOp:
-        pytest.skip("logsumexp is not yet supported by ck-tiled fmha!")
     query, key, value, attn_bias = create_tensors(
         *opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
         fmt="BMHK",
@@ -588,7 +586,11 @@ def test_logsumexp(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
     if op is fmha.cutlass.FwOp:
         # CUTLASS kernel pads the last dimention of LSE to 32
         lse = lse[:, :, : ref_lse.shape[2]]
-    assert_allclose(lse, ref_lse, atol=2e-4)
+    if op is fmha.ck.FwOp:
+        # relax numerical tolerance for CK FwOp
+        assert_allclose(lse, ref_lse, atol=2e-4, rtol=2e-4)
+    else:
+        assert_allclose(lse, ref_lse, atol=2e-4)
 
 
 @cuda_only
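Side note on the comparison above: the test checks the log-sum-exp returned by the forward op against a reference computed in plain PyTorch. A minimal sketch of such a reference, assuming BMHK inputs and the default 1/sqrt(K) scale and ignoring attn_bias (the helper name is illustrative, not the test's actual reference function):

    import torch

    def reference_lse(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        # query/key: [B, M, H, K] ("BMHK"); returns LSE of shape [B, H, Mq].
        q = query.permute(0, 2, 1, 3)   # -> [B, H, Mq, K]
        k = key.permute(0, 2, 1, 3)     # -> [B, H, Mkv, K]
        scores = (q @ k.transpose(-2, -1)) * query.shape[-1] ** -0.5
        return scores.float().logsumexp(dim=-1)

Against such a reference, the CK forward op now gets the extra rtol=2e-4 above, while the other backends keep the absolute tolerance alone.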
@@ -674,8 +676,16 @@ def test_backward(
 
     if op_bw == fmha.ck.BwOp:
         op_fw = fmha.ck.FwOp
+        if dtype == torch.bfloat16:
+            pytest.skip(
+                "CK Fmha backward for bfloat16 currently is not very accurate for some cases!"
+            )
         if grad_out_contiguous is False:
             pytest.skip("CK Fmha does not support contiguous layout for grad_out!")
+        if k % 2 != 0:
+            pytest.skip(
+                "CK Fmha currently requires the headdim size of query input be an even value!"
+            )
 
     qkv = None
 
55 changes: 29 additions & 26 deletions xformers/ops/fmha/triton_splitk.py
@@ -579,18 +579,18 @@ def _fwd_kernel_splitK(
         # Maybe we can unroll the last iteration instead?
         if BOUNDS_CHECKS_N:
             qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
-        if IS_CAUSAL:
-            # -- apply the causal mask --
-            qk = tl.where(diag_idx_shifted >= start_n, qk, float("-inf"))
         # -- compute scaling constant ---
         m_i_new = tl.maximum(m_i, tl.max(qk, 1))
         alpha = tl.math.exp2(m_i - m_i_new)
         p = tl.math.exp2(qk - m_i_new[:, None])
-        if HAS_ADDITIVE_BIAS:
+        if HAS_ADDITIVE_BIAS or IS_CAUSAL:
             # NOTE: It's possible that an entire block is masked out.
             # if this is the case, `m_i_new=nan` and everything becomes nan
             alpha = tl.where(m_i_new == float("-inf"), 0, alpha)
             p = tl.where(m_i_new[:, None] == float("-inf"), 0, p)
+        if IS_CAUSAL:
+            # -- apply the causal mask --
+            p = tl.where(diag_idx_shifted >= start_n, p, 0)
 
         # -- update m_i and l_i --
         l_i = l_i * alpha + tl.sum(p, 1)
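The hunk above moves the causal mask from the pre-softmax scores qk to the probabilities p and widens the all-masked-block guard to cover IS_CAUSAL. A standalone sketch in plain PyTorch (one query row against one fully masked K block) of why the old ordering can produce NaNs while the new one stays finite; BLOCK_N, m_i, alpha and p mirror the kernel names, everything else is illustrative:

    import torch

    BLOCK_N = 16
    qk = torch.randn(BLOCK_N)                 # raw scores for one K block
    m_i = torch.tensor(float("-inf"))         # running max before this block

    # Old ordering: mask qk with -inf first. If the whole block is masked,
    # m_i_new stays -inf and both exp2 arguments become -inf - (-inf) = nan.
    qk_old = torch.full_like(qk, float("-inf"))
    m_i_new = torch.maximum(m_i, qk_old.max())
    alpha_old = torch.exp2(m_i - m_i_new)     # nan
    p_old = torch.exp2(qk_old - m_i_new)      # nan everywhere

    # New ordering: take the max/exp over the unmasked qk, zero alpha/p in the
    # degenerate case where the running max is still -inf, then zero the
    # causally masked entries of p directly.
    m_i_new = torch.maximum(m_i, qk.max())
    alpha = torch.exp2(m_i - m_i_new)         # exp2(-inf) = 0, finite
    p = torch.exp2(qk - m_i_new)
    alpha = torch.where(m_i_new == float("-inf"), torch.zeros_like(alpha), alpha)
    p = torch.where(m_i_new == float("-inf"), torch.zeros_like(p), p)
    keep = torch.zeros(BLOCK_N, dtype=torch.bool)   # causal mask: keep nothing
    p = torch.where(keep, p, torch.zeros_like(p))   # finite, contributes 0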
@@ -717,6 +717,7 @@ def autotune_kernel(kernel: Callable):
     kernel = triton.autotune(
         configs=TRITON_CONFIGS,
         key=AUTOTUNER_KEY,
+        use_cuda_graph=True,
     )(kernel)
     return kernel
 
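use_cuda_graph=True asks the autotuner to benchmark candidate configs under a CUDA graph, which keeps launch overhead out of the timing loop; it assumes a Triton version whose triton.autotune exposes this flag. A minimal sketch on a toy kernel (the kernel is illustrative, not from xformers):

    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK": 128}, num_warps=4),
            triton.Config({"BLOCK": 256}, num_warps=8),
        ],
        key=["n_elements"],
        use_cuda_graph=True,
    )
    @triton.jit
    def scale_kernel(x_ptr, out_ptr, n_elements, alpha, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x * alpha, mask=mask)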
@@ -734,15 +735,17 @@ def autotune_kernel(kernel: Callable):
         _get_splitk_kernel(num_groups)
     )
 
-def get_autotuner_cache(num_groups: int) -> Dict[Tuple[int], triton.Config]:
+def get_autotuner_cache(
+    num_groups: int,
+) -> Dict[Tuple[Union[int, str]], triton.Config]:
     """Returns a triton.runtime.autotuner.AutoTuner.cache object, which
     represents mappings from kernel autotune keys (tuples describing kernel inputs)
     to triton.Config
     """
     return _fwd_kernel_splitK_autotune[num_groups].cache
 
 def set_autotuner_cache(
-    cache: Dict[Tuple[int], triton.Config], num_groups: int
+    cache: Dict[Tuple[Union[int, str]], triton.Config], num_groups: int
 ) -> None:
     _fwd_kernel_splitK_autotune[num_groups].cache = cache
 
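These two helpers expose the Triton autotuner's config cache per num_groups, which makes it possible to tune once and reuse the result elsewhere. A hypothetical usage sketch, assuming the helpers are importable from this module and that the cache dict is picklable (neither the file name nor the pickling is part of xformers):

    import pickle

    from xformers.ops.fmha.triton_splitk import (
        get_autotuner_cache,
        set_autotuner_cache,
    )

    NUM_GROUPS = 1  # hypothetical: the 1-group split-K kernel

    # After a warm-up forward pass that triggered autotuning:
    with open("splitk_autotune_cache.pkl", "wb") as f:
        pickle.dump(get_autotuner_cache(NUM_GROUPS), f)

    # In a later process, restore the tuned configs before the first call:
    with open("splitk_autotune_cache.pkl", "rb") as f:
        set_autotuner_cache(pickle.load(f), NUM_GROUPS)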
@@ -871,26 +874,26 @@ def _splitK_reduce(
     LSE,  # [B, H, M]
     split_k: tl.constexpr,
     splitK_pow2: tl.constexpr,
-    stride_osk_z,
-    stride_osk_g,
-    stride_osk_h,
-    stride_osk_s,
-    stride_osk_m,
-    stride_osk_k,
-    stride_lsek_z,
-    stride_lsek_g,
-    stride_lsek_h,
-    stride_lsek_s,
-    stride_lsek_m,
-    stride_oz,
-    stride_og,
-    stride_oh,
-    stride_om,
-    stride_ok,
-    stride_lse_z,
-    stride_lse_g,
-    stride_lse_h,
-    stride_lse_m,
+    stride_osk_z: tl.constexpr,
+    stride_osk_g: tl.constexpr,
+    stride_osk_h: tl.constexpr,
+    stride_osk_s: tl.constexpr,
+    stride_osk_m: tl.constexpr,
+    stride_osk_k: tl.constexpr,
+    stride_lsek_z: tl.constexpr,
+    stride_lsek_g: tl.constexpr,
+    stride_lsek_h: tl.constexpr,
+    stride_lsek_s: tl.constexpr,
+    stride_lsek_m: tl.constexpr,
+    stride_oz: tl.constexpr,
+    stride_og: tl.constexpr,
+    stride_oh: tl.constexpr,
+    stride_om: tl.constexpr,
+    stride_ok: tl.constexpr,
+    stride_lse_z: tl.constexpr,
+    stride_lse_g: tl.constexpr,
+    stride_lse_h: tl.constexpr,
+    stride_lse_m: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     H: tl.constexpr,
     G: tl.constexpr,
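Marking the reduce kernel's strides as tl.constexpr (the bulk of this hunk) turns them into compile-time constants: the address arithmetic can be folded and the kernel is specialized per stride combination, at the cost of a recompilation when a new combination shows up. A toy sketch of the same idea (kernel and names are illustrative, not the xformers reduce kernel):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_rows(src_ptr, dst_ptr, n_cols, stride_row: tl.constexpr, BLOCK: tl.constexpr):
        # stride_row is a constexpr: row * stride_row is folded at compile time.
        row = tl.program_id(0)
        offs = tl.arange(0, BLOCK)
        mask = offs < n_cols
        vals = tl.load(src_ptr + row * stride_row + offs, mask=mask)
        tl.store(dst_ptr + row * stride_row + offs, vals, mask=mask)

    x = torch.randn(64, 100, device="cuda")
    y = torch.empty_like(x)
    copy_rows[(x.shape[0],)](x, y, x.shape[1], stride_row=x.stride(0), BLOCK=128)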
