[PyTorch] Add caching for attention backend selection results #1381

Draft · wants to merge 4 commits into base: main
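This draft adapts the PyTorch fused-attention tests to cached attention-backend selection: the single `backend_selection_requires_update` flag is replaced by two cache-control flags on the module-level `_attention_backends` dict, `update_selection` (request a full re-selection, e.g. when the `AttentionParams` have not been seen before) and `update_env_vars_only` (keep the cached selection and only re-read the `NVTE_*` environment variables). The sketch below illustrates how such a cache could behave; it is not the actual `get_attention_backend` implementation in Transformer Engine, and `run_backend_selection` and `apply_env_var_overrides` are hypothetical stand-ins for the real selection logic.

import os

def run_backend_selection(attention_params):
    # Hypothetical: pick backends from scratch for these parameters.
    return {"flash": True, "fused": True, "unfused": True}

def apply_env_var_overrides(selection):
    # Hypothetical: keep the cached selection, only honor the NVTE_* switches.
    return {
        "flash": selection["flash"] and os.getenv("NVTE_FLASH_ATTN", "1") == "1",
        "fused": selection["fused"] and os.getenv("NVTE_FUSED_ATTN", "1") == "1",
        "unfused": selection["unfused"] and os.getenv("NVTE_UNFUSED_ATTN", "1") == "1",
    }

_attention_backends = {
    "attention_params": [],         # AttentionParams seen so far
    "selection_result": None,       # cached result of the last full selection
    "update_selection": True,       # force a full re-selection on the next call
    "update_env_vars_only": False,  # only refresh the NVTE_* env-var view
}

def get_attention_backend_cached(attention_params):
    # Illustration of the caching the two flags suggest, not the library code.
    cache = _attention_backends
    if cache["update_selection"] or attention_params not in cache["attention_params"]:
        cache["selection_result"] = run_backend_selection(attention_params)
        cache["attention_params"].append(attention_params)
    elif cache["update_env_vars_only"]:
        cache["selection_result"] = apply_env_var_overrides(cache["selection_result"])
    cache["update_selection"] = False
    cache["update_env_vars_only"] = False
    return cache["selection_result"]

Under this scheme, repeated calls with parameters already in the cache avoid redoing (and re-logging) the full backend selection, which appears to be the goal stated in the PR title.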
tests/pytorch/fused_attn/test_fused_attn.py (122 changes: 65 additions & 57 deletions)
@@ -135,7 +135,6 @@ def _get_attention_backends(
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True

alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
@@ -156,48 +155,35 @@ def _get_attention_backends(
):
core_attention_bias_requires_grad = True

fused_attn_backends = []
available_backends = None
fused_attention_backend = None

def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
_, _, fused_attention_backend, _, available_backends = get_attention_backend(
attention_params
)
return available_backends, fused_attention_backend

backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
_attention_backends["update_selection"] = (
attention_params not in _attention_backends["attention_params"]
)
_attention_backends["update_env_vars_only"] = False
available_backends, _ = get_attention_backend(attention_params)
return available_backends


model_configs_base = {
@@ -222,7 +208,7 @@ def test():
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("workspace_opt", [True]) # , False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
@@ -250,14 +236,16 @@ def test_dot_product_attention(
if swa:
window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, window_size)
available_backends, fused_attn_backends = _get_attention_backends(
available_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
flash_attn_supported, fused_attn_supported, unfused_attn_supported, fused_attn_backends = (
available_backends
)
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
if pad_between_seqs and not (
@@ -679,12 +667,15 @@ def _run_dot_product_attention(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
_attention_backends["backend_selection_requires_update"] = True
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["update_env_vars_only"] = True

# Create seqlens
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
@@ -1045,12 +1036,14 @@ def test_transformer_layer(
workspace_opt = True

# Test backend availability
available_backends, fused_attn_backends = _get_attention_backends(
available_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
flash_attn_supported, fused_attn_supported, unfused_attn_supported, fused_attn_backends = (
available_backends
)

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
@@ -1168,11 +1161,14 @@ def _run_transformer_layer(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["update_env_vars_only"] = True

# Create input tensor
inp = torch.randn(
@@ -1362,15 +1358,17 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
if _flash_attn_3_is_installed and not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["update_selection"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
)

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["update_selection"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
@@ -1541,15 +1539,17 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
if _flash_attn_3_is_installed and not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["update_selection"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
)

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["update_selection"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
@@ -1753,7 +1753,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
config = model_configs_fp8[model]

fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
dtype, config, "UnfusedDotProductAttention"
)

atol = 5e-1
rtol = 5e-1
@@ -1784,11 +1786,14 @@ def _run_custom_mha_fp8(dtype, config, backend):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["update_selection"] = True

inp = 0.0001 * torch.randint(
-100,
@@ -1838,11 +1843,14 @@ def _run_ref_mha_f16(dtype, config, backend):

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["update_selection"] = True

inp = torch.load("qkv.pt").to(device="cuda")
inp.requires_grad = True
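For quick reference, the per-backend `_run_*` helpers in this diff follow the same setup pattern when switching backends. The function below paraphrases those hunks; the `_select_backend_for_test` name and the explicit cache argument are only for illustration and do not exist in the test file.

import os

def _select_backend_for_test(backend, _attention_backends):
    # Disable every backend, enable only the one under test, then ask the
    # selection cache to refresh just its view of the NVTE_* env vars.
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    os.environ["NVTE_UNFUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
    if backend == "UnfusedDotProductAttention":
        os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["update_env_vars_only"] = True

The FP8-vs-F16 comparisons (`test_mha_fp8_vs_f16`, `test_dpa_fp8_vs_f16`) instead set `_attention_backends["update_selection"] = True` after flipping the env vars, forcing a full re-selection rather than the cheaper env-var-only refresh.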