diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index e4e27214508e..059a0cabb439 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -299,7 +299,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri return acc, l_i, m_i -def get_MI_autotune_config(): +def get_MI_autotune_configs(): return [ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), @@ -316,7 +316,7 @@ def get_MI_autotune_config(): num_warps=4), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] -def get_NAVI_autotune_config(): +def get_NAVI_autotune_configs(): return [ ], [ @@ -331,9 +331,9 @@ def is_navi(): def get_autotune_configs(): if is_navi(): - return get_NAVI_autotune_config() + return get_NAVI_autotune_configs() else: - return get_MI_autotune_config() + return get_MI_autotune_configs() autotune_configs, autotune_keys = get_autotune_configs()