diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index dea31b5971..0929d22762 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -135,7 +135,6 @@ def _get_attention_backends(
     os.environ["NVTE_FLASH_ATTN"] = "1"
     os.environ["NVTE_FUSED_ATTN"] = "1"
     os.environ["NVTE_UNFUSED_ATTN"] = "1"
-    _attention_backends["backend_selection_requires_update"] = True

     alibi_slopes_shape = None
     if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
@@ -156,48 +155,35 @@ def _get_attention_backends(
     ):
         core_attention_bias_requires_grad = True

-    fused_attn_backends = []
-    available_backends = None
-    fused_attention_backend = None
-
-    def test():
-        attention_params = AttentionParams(
-            qkv_dtype=qkv_dtype,
-            qkv_layout=qkv_layout,
-            batch_size=config.batch_size,
-            num_heads=config.num_heads,
-            num_gqa_groups=config.num_gqa_groups,
-            max_seqlen_q=config.max_seqlen_q,
-            max_seqlen_kv=config.max_seqlen_kv,
-            head_dim_qk=config.head_dim_qk,
-            head_dim_v=config.head_dim_v,
-            attn_mask_type=config.attn_mask_type,
-            window_size=window_size,
-            alibi_slopes_shape=alibi_slopes_shape,
-            core_attention_bias_type=config.attn_bias_type,
-            core_attention_bias_shape=core_attention_bias_shape,
-            core_attention_bias_requires_grad=core_attention_bias_requires_grad,
-            pad_between_seqs=pad_between_seqs,
-            attention_dropout=config.dropout_p,
-            context_parallel=context_parallel,
-            deterministic=deterministic,
-            fp8=fp8,
-            fp8_meta=fp8_meta,
-        )
-        _, _, fused_attention_backend, _, available_backends = get_attention_backend(
-            attention_params
-        )
-        return available_backends, fused_attention_backend
-
-    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
-    with logging_context():
-        for i in range(3):
-            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
-            _attention_backends["backend_selection_requires_update"] = True
-            available_backends, fused_attention_backend = test()
-            if fused_attention_backend == FusedAttnBackend[backends[i]]:
-                fused_attn_backends.append(fused_attention_backend)
-    return available_backends, fused_attn_backends
+    attention_params = AttentionParams(
+        qkv_dtype=qkv_dtype,
+        qkv_layout=qkv_layout,
+        batch_size=config.batch_size,
+        num_heads=config.num_heads,
+        num_gqa_groups=config.num_gqa_groups,
+        max_seqlen_q=config.max_seqlen_q,
+        max_seqlen_kv=config.max_seqlen_kv,
+        head_dim_qk=config.head_dim_qk,
+        head_dim_v=config.head_dim_v,
+        attn_mask_type=config.attn_mask_type,
+        window_size=window_size,
+        alibi_slopes_shape=alibi_slopes_shape,
+        core_attention_bias_type=config.attn_bias_type,
+        core_attention_bias_shape=core_attention_bias_shape,
+        core_attention_bias_requires_grad=core_attention_bias_requires_grad,
+        pad_between_seqs=pad_between_seqs,
+        attention_dropout=config.dropout_p,
+        context_parallel=context_parallel,
+        deterministic=deterministic,
+        fp8=fp8,
+        fp8_meta=fp8_meta,
+    )
+    _attention_backends["update_selection"] = (
+        attention_params not in _attention_backends["attention_params"]
+    )
+    _attention_backends["update_env_vars_only"] = False
+    available_backends, _ = get_attention_backend(attention_params)
+    return available_backends


 model_configs_base = {
@@ -222,7 +208,7 @@ def test():
 @pytest.mark.parametrize("model_configs", [model_configs_base])
 @pytest.mark.parametrize("model", model_configs_base.keys())
 @pytest.mark.parametrize("ckpt_attn", [False])
-@pytest.mark.parametrize("workspace_opt", [True, False])
+@pytest.mark.parametrize("workspace_opt", [True])  # , False])
 @pytest.mark.parametrize("qkv_layout", [None])
 @pytest.mark.parametrize("swa", [False])
 @pytest.mark.parametrize("pad_between_seqs", [False])
@@ -250,14 +236,16 @@ def test_dot_product_attention(
     if swa:
         window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, window_size)
-    available_backends, fused_attn_backends = _get_attention_backends(
+    available_backends = _get_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         window_size=config.window_size,
         pad_between_seqs=pad_between_seqs,
     )
-    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+    flash_attn_supported, fused_attn_supported, unfused_attn_supported, fused_attn_backends = (
+        available_backends
+    )
     # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
     # mannually pads and unpads the input and output of FlashAttention for testing purposes
     if pad_between_seqs and not (
@@ -679,12 +667,15 @@ def _run_dot_product_attention(
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
-    _attention_backends["backend_selection_requires_update"] = True
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
+    _attention_backends["update_env_vars_only"] = True

     # Create seqlens
     qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
@@ -1045,12 +1036,14 @@ def test_transformer_layer(
     workspace_opt = True

     # Test backend availability
-    available_backends, fused_attn_backends = _get_attention_backends(
+    available_backends = _get_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
     )
-    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+    flash_attn_supported, fused_attn_supported, unfused_attn_supported, fused_attn_backends = (
+        available_backends
+    )

     # Skip if only unfused backend is supported
     if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
@@ -1168,11 +1161,14 @@ def _run_transformer_layer(
     reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
-    _attention_backends["backend_selection_requires_update"] = True
+    if backend == "UnfusedDotProductAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
+    _attention_backends["update_env_vars_only"] = True

     # Create input tensor
     inp = torch.randn(
@@ -1362,7 +1358,8 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     if _flash_attn_3_is_installed and not is_training:
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
-        _attention_backends["backend_selection_requires_update"] = True
+        os.environ["NVTE_UNFUSED_ATTN"] = "0"
+        _attention_backends["update_selection"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
         flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
             dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
@@ -1370,7 +1367,8 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["update_selection"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( dtype, config, True, qkv_format, input_layernorm, RoPE, is_training @@ -1541,7 +1539,8 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): if _flash_attn_3_is_installed and not is_training: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" - _attention_backends["backend_selection_requires_update"] = True + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["update_selection"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training @@ -1549,7 +1548,8 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["update_selection"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( dtype, config, True, qkv_layout, is_training @@ -1753,7 +1753,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model): config = model_configs_fp8[model] fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") - unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") + unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16( + dtype, config, "UnfusedDotProductAttention" + ) atol = 5e-1 rtol = 5e-1 @@ -1784,11 +1786,14 @@ def _run_custom_mha_fp8(dtype, config, backend): reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" + _attention_backends["update_selection"] = True inp = 0.0001 * torch.randint( -100, @@ -1838,11 +1843,14 @@ def _run_ref_mha_f16(dtype, config, backend): os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" + _attention_backends["update_selection"] = True inp = torch.load("qkv.pt").to(device="cuda") inp.requires_grad = True diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index be0d176520..4966961439 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -206,13 +206,20 @@ def _get_supported_versions(version_min, version_max): _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") _use_flash_attn_3 = True +# maximum number of configs allowed in _attention_backends 
+_max_configs = 10
+# attention config cache
 _attention_backends = {
-    "attention_params": None,
-    "use_flash_attention": None,
-    "use_fused_attention": None,
-    "fused_attention_backend": None,
-    "use_unfused_attention": None,
-    "backend_selection_requires_update": False,
+    # list of cached configs [attention_params]
+    "attention_params": [],
+    # available backends for each cached config
+    "available_backends": [],
+    # last used backend for each cached config
+    "selected_backend": [],
+    # update selected_backend if only environment variables, e.g. NVTE_FUSED_ATTN, have changed
+    "update_env_vars_only": False,
+    # update both available_backends and selected_backend
+    "update_selection": True,
 }

@@ -366,7 +373,7 @@ def get_attention_backend(
     fp8 = attention_params.fp8
     fp8_meta = attention_params.fp8_meta

-    # Run config
+    # Print config
     logger = logging.getLogger("DotProductAttention")
     logger.setLevel(_log_level)
     if not logger.hasHandlers():
@@ -393,22 +400,84 @@ def get_attention_backend(
     run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
     logger.debug("Running with config=%s", run_config)

+    # check if attention_params already exists in the cache
+    global _attention_backends
+    if (
+        attention_params in _attention_backends["attention_params"]
+        and not _attention_backends["update_env_vars_only"]
+        and not _attention_backends["update_selection"]
+    ):
+        config_id = _attention_backends["attention_params"].index(attention_params)
+        available_backends = _attention_backends["available_backends"][config_id]
+        selected_backend = _attention_backends["selected_backend"][config_id]
+        return available_backends, selected_backend
+
+    # re-select the backend if environment variables such as NVTE_FUSED_ATTN change mid-run
+    def update_env_vars(logger, available_backends):
+        # Filter: Environment variables
+        (
+            use_flash_attention,
+            use_fused_attention,
+            use_unfused_attention,
+            fused_attention_backends,
+        ) = available_backends
+        _use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
+        _use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
+        _use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
+        if use_flash_attention and not _use_flash_attention and _flash_attn_is_installed:
+            logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0")
+        if use_fused_attention and not _use_fused_attention:
+            logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
+        if use_unfused_attention and not _use_unfused_attention:
+            logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")
+        use_flash_attention = use_flash_attention and _use_flash_attention
+        use_fused_attention = use_fused_attention and _use_fused_attention
+        use_unfused_attention = use_unfused_attention and _use_unfused_attention
+        fused_attention_backend = None
+        if use_fused_attention:
+            if len(fused_attention_backends) == 1:
+                fused_attention_backend = fused_attention_backends[0]
+            if len(fused_attention_backends) == 2:
+                sub_backend = int(os.getenv("NVTE_FUSED_ATTN_BACKEND", "1"))
+                fused_attention_backend = fused_attention_backends[sub_backend]
+        selected_backend = [
+            use_flash_attention,
+            use_fused_attention,
+            use_unfused_attention,
+            fused_attention_backend,
+        ]
+        return selected_backend
+
+    if (
+        attention_params in _attention_backends["attention_params"]
+        and _attention_backends["update_env_vars_only"]
+        and not _attention_backends["update_selection"]
+    ):
+        config_id = _attention_backends["attention_params"].index(attention_params)
+        available_backends = _attention_backends["available_backends"][config_id]
_attention_backends["available_backends"][config_id] + selected_backend = update_env_vars(logger, available_backends) + _attention_backends["selected_backend"][config_id] = selected_backend + return available_backends, selected_backend + + # keep unique attention_params + if ( + attention_params in _attention_backends["attention_params"] + and _attention_backends["update_selection"] + ): + config_id = _attention_backends["attention_params"].index(attention_params) + _attention_backends["attention_params"].pop(config_id) + _attention_backends["available_backends"].pop(config_id) + _attention_backends["selected_backend"].pop(config_id) + # The following sections check if `FlashAttention` supports the provided attention params, # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is # necessary for performance/functionality, a warning will be issued to prompt users to # install an appropriate FA version. global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3 - # Filter: Environment variables - use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) - use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) - use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - if not use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") - if not use_fused_attention: - logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") - if not use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") + use_flash_attention = True + use_fused_attention = True + use_unfused_attention = True # Filter: ONNX mode if is_in_onnx_export_mode(): @@ -799,6 +868,7 @@ def get_attention_backend( os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" # Filter: cuDNN support + fused_attn_avail_backends = [] fused_attention_backend = None if use_fused_attention: q_type = TE_DType[qkv_dtype] @@ -806,22 +876,37 @@ def get_attention_backend( if fp8 and fp8_meta["recipe"].fp8_dpa: q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) kv_type = q_type - fused_attention_backend = tex.get_fused_attn_backend( - q_type, - kv_type, - QKVLayout[qkv_layout], - AttnBiasType[fu_core_attention_bias_type], - AttnMaskType[attn_mask_type], - attention_dropout, - num_heads, - num_gqa_groups, - max_seqlen_q, - max_seqlen_kv, - head_dim_qk, - head_dim_v, - window_size[0], - window_size[1], - ) + + def get_fused_attn_backend(): + fused_attention_backend = tex.get_fused_attn_backend( + q_type, + kv_type, + QKVLayout[qkv_layout], + AttnBiasType[fu_core_attention_bias_type], + AttnMaskType[attn_mask_type], + attention_dropout, + num_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size[0], + window_size[1], + ) + return fused_attention_backend + + # all available cuDNN sub-backends + for k, v in FusedAttnBackend.items(): + if k != "No_Backend": + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(int(v)) + backend = get_fused_attn_backend() + if backend == FusedAttnBackend[k]: + fused_attn_avail_backends.append(backend) + del os.environ["NVTE_FUSED_ATTN_BACKEND"] + # selected cuDNN sub-backend + fused_attention_backend = get_fused_attn_backend() + if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False @@ -833,7 +918,7 @@ def get_attention_backend( and fused_attention_backend != 
FusedAttnBackend["F16_arbitrary_seqlen"] ): logger.debug( - "Disabling FusedAttention as only sub-backend %s does not support " + "Disabling FusedAttention as sub-backend %s does not support " "slidng window attention", int(fused_attention_backend), ) @@ -891,7 +976,17 @@ def get_attention_backend( use_fused_attention = False # All available backends - available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + available_backends = [ + use_flash_attention, + use_fused_attention, + use_unfused_attention, + fused_attn_avail_backends, + ] + + # Filter: Environment variables + use_flash_attention, use_fused_attention, use_unfused_attention, fused_attention_backend = ( + update_env_vars(logger, available_backends) + ) # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`. # When `FusedAttention` does not support the provided attention params, and `FlashAttention` @@ -909,16 +1004,17 @@ def get_attention_backend( use_flash_attention = False available_backends[0] = False + fused_attn_str = "" + if len(fused_attn_avail_backends) > 0: + fused_attn_str = ( + " (sub-backend " + " and ".join([str(int(x)) for x in fused_attn_avail_backends]) + ")" + ) logger.debug( "Available backends = {FlashAttention=%s, FusedAttention=%s%s," " UnfusedDotProductAttention=%s}", bool(available_backends[0]), bool(available_backends[1]), - ( - f" (sub-backend {int(fused_attention_backend)})" - if fused_attention_backend is not None - else "" - ), + fused_attn_str, bool(available_backends[2]), ) @@ -949,31 +1045,37 @@ def get_attention_backend( # Selected backend if use_flash_attention: use_fused_attention = False + fused_attention_backend = [] use_unfused_attention = False elif use_fused_attention: use_unfused_attention = False - selected_backend = "NoBackend" + backend = "NoBackend" if use_flash_attention: - selected_backend = "FlashAttention" + backend = "FlashAttention" elif use_fused_attention: - selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})" + backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})" elif use_unfused_attention: - selected_backend = "UnfusedDotProductAttention" - logger.debug("Selected backend = %s", selected_backend) - - global _attention_backends - _attention_backends["use_flash_attention"] = use_flash_attention - _attention_backends["use_fused_attention"] = use_fused_attention - _attention_backends["fused_attention_backend"] = fused_attention_backend - _attention_backends["use_unfused_attention"] = use_unfused_attention - _attention_backends["backend_selection_requires_update"] = False - - return ( + backend = "UnfusedDotProductAttention" + logger.debug("Selected backend = %s", backend) + selected_backend = [ use_flash_attention, use_fused_attention, - fused_attention_backend, use_unfused_attention, + fused_attention_backend, + ] + + _attention_backends["attention_params"].append(attention_params) + _attention_backends["available_backends"].append(available_backends) + _attention_backends["selected_backend"].append(selected_backend) + if len(_attention_backends["attention_params"]) > _max_configs: + _attention_backends["attention_params"].pop(0) + _attention_backends["available_backends"].pop(0) + _attention_backends["selected_backend"].pop(0) + _attention_backends["update_selection"] = False + + return ( available_backends, + selected_backend, ) @@ -8193,22 +8295,19 @@ def forward( fp8=self.fp8, fp8_meta=self.fp8_meta, ) - global _attention_backends, _use_flash_attn_3 - if ( - 
_attention_backends["attention_params"] is None - or attention_params != _attention_backends["attention_params"] - ): - _attention_backends["attention_params"] = attention_params - _attention_backends["backend_selection_requires_update"] = True - if _attention_backends["backend_selection_requires_update"]: - _use_flash_attn_3 = _flash_attn_3_is_installed - ( - use_flash_attention, - use_fused_attention, - fused_attention_backend, - use_unfused_attention, - _, - ) = get_attention_backend(attention_params) + is_new_config = ( + attention_params not in _attention_backends["attention_params"] + or _attention_backends["update_env_vars_only"] + or _attention_backends["update_selection"] + ) + _, selected_backend = get_attention_backend(attention_params) + ( + use_flash_attention, + use_fused_attention, + use_unfused_attention, + fused_attention_backend, + ) = selected_backend + if is_new_config: if use_flash_attention: self.logger.info( "Running with FlashAttention backend (version %s)", @@ -8221,11 +8320,6 @@ def forward( ) elif use_unfused_attention: self.logger.info("Running with UnfusedDotProductAttention backend") - else: - use_flash_attention = _attention_backends["use_flash_attention"] - use_fused_attention = _attention_backends["use_fused_attention"] - fused_attention_backend = _attention_backends["fused_attention_backend"] - use_unfused_attention = _attention_backends["use_unfused_attention"] if use_flash_attention: if core_attention_bias_type == "alibi":