diff --git a/src/liger_kernel/transformers/llama_flex_attention.py b/src/liger_kernel/transformers/llama_flex_attention.py
deleted file mode 100644
index a01d81425..000000000
--- a/src/liger_kernel/transformers/llama_flex_attention.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from typing import Optional
-from typing import Tuple
-
-import torch
-
-from torch.nn.attention.flex_attention import create_block_mask
-from torch.nn.attention.flex_attention import flex_attention
-
-flex_attention = torch.compile(flex_attention)
-
-# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py#L12
-
-
-def flex_attention_forward(
-    module: torch.nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: Optional[float] = None,
-    softcap: Optional[float] = None,
-    **kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    print("using liger_llama_flex_attention_forward..")
-    causal_mask = attention_mask
-    if causal_mask is not None:
-        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
-
-    def causal_mod(score, b, h, q_idx, kv_idx):
-        if softcap is not None:
-            score = softcap * torch.tanh(score / softcap)
-        if causal_mask is not None:
-            score = score + causal_mask[b][0][q_idx][kv_idx]
-        return score
-
-    def causal_mask(b, h, q_idx, kv_idx):
-        return q_idx >= kv_idx
-
-    block_mask = create_block_mask(causal_mask, None, None, query.shape[-2], query.shape[-2], device="cuda")
-
-    attn_output, attention_weights = flex_attention(
-        query,
-        key,
-        value,
-        score_mod=causal_mod,
-        block_mask=block_mask,
-        enable_gqa=True,
-        scale=scaling,
-        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
-        # For simplification, we thus always return it as no additional computations are introduced.
-        return_lse=True,
-        kernel_options={
-            "BLOCK_M": 32,
-            "BLOCK_N": 32,
-            "BLOCK_M1": 16,
-            "BLOCK_N1": 32,
-            "BLOCK_M2": 32,
-            "BLOCK_N2": 16,
-        },
-    )
-    # lse is returned in float32
-    attention_weights = attention_weights.to(value.dtype)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attention_weights
diff --git a/src/liger_kernel/transformers/model/llama.py b/src/liger_kernel/transformers/model/llama.py
index e4dde0f55..089d10679 100644
--- a/src/liger_kernel/transformers/model/llama.py
+++ b/src/liger_kernel/transformers/model/llama.py
@@ -8,6 +8,8 @@
 import torch.nn.functional as F
 
 from torch.nn import CrossEntropyLoss
+from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import flex_attention
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
 from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
@@ -19,6 +21,8 @@
 if TYPE_CHECKING:
     from transformers.cache_utils import Cache
 
+flex_attention = torch.compile(flex_attention)
+
 
 @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -249,3 +253,62 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py#L12
+
+
+def flex_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    softcap: Optional[float] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    causal_mask = attention_mask
+    if causal_mask is not None:
+        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+
+    def causal_mod(score, b, h, q_idx, kv_idx):
+        if softcap is not None:
+            score = softcap * torch.tanh(score / softcap)
+        if causal_mask is not None:
+            score = score + causal_mask[b][0][q_idx][kv_idx]
+        return score
+
+    # We only receive an `attention_mask` tensor here, so we recreate the causal mask function specific to Llama causal attention
+    # TODO: Consider other customized `attention_mask` in the future, e.g., shared prefix
+    def causal_mask_fn(b, h, q_idx, kv_idx):
+        return q_idx >= kv_idx
+
+    # Construct a block attention mask that leverages sparsity.
+    sparse_causal_mask = create_block_mask(causal_mask_fn, None, None, query.shape[-2], query.shape[-2], device="cuda")
+
+    attn_output, attention_weights = flex_attention(
+        query,
+        key,
+        value,
+        score_mod=causal_mod,
+        block_mask=sparse_causal_mask,
+        enable_gqa=True,
+        scale=scaling,
+        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
+        # For simplification, we thus always return it as no additional computations are introduced.
+        return_lse=True,
+        kernel_options={  # different hardware might need different configs
+            "BLOCK_M": 32,
+            "BLOCK_N": 32,
+            "BLOCK_M1": 16,
+            "BLOCK_N1": 32,
+            "BLOCK_M2": 32,
+            "BLOCK_N2": 16,
+        },
+    )
+    # lse is returned in float32
+    attention_weights = attention_weights.to(value.dtype)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attention_weights
diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index a89dbcd42..5a924cb25 100644
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -13,11 +13,11 @@
 from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
-from liger_kernel.transformers.llama_flex_attention import flex_attention_forward as llama_flex_attention_forward
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.llama import flex_attention_forward as llama_flex_attention_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
@@ -39,6 +39,7 @@
 logger = logging.getLogger(__name__)
 SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
 TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
+FLEX_ATTENTION_NOT_SUPPORT_WARNING = "Flex attention is not supported for this model yet"
 
 
 def _bind_method_to_module(module, method_name: str, new_method: Callable):
@@ -118,7 +119,7 @@ def apply_liger_kernel_to_llama(
             modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
     if flex_attn:
-        logger.warning("Patched HuggingFace default PyTorch SDPA to liger_flex_attention.")
+        # Patching HuggingFace's default attention implementation (PyTorch SDPA) with Liger's `llama_flex_attention_forward`
         modeling_llama.ALL_ATTENTION_FUNCTIONS.update(
             {"sdpa": llama_flex_attention_forward, "flex_attention": llama_flex_attention_forward}
         )
@@ -149,7 +150,7 @@ def apply_liger_kernel_to_mllama(
     rms_norm: bool = True,
     swiglu: bool = True,
     model: PreTrainedModel = None,
-    flex_attn: bool = False,  # Not support yet.
+    flex_attn: bool = False,  # Not supported by HuggingFace
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
@@ -203,6 +204,8 @@ def apply_liger_kernel_to_mllama(
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+    if flex_attn:
+        logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -253,7 +256,7 @@ def apply_liger_kernel_to_mistral(
     rms_norm: bool = True,
     swiglu: bool = True,
     model: PreTrainedModel = None,
-    flex_attn: bool = True,
+    flex_attn: bool = False,  # Not supported by Liger yet
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models
@@ -288,6 +291,8 @@ def apply_liger_kernel_to_mistral(
         modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP
+    if flex_attn:
+        logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -314,7 +319,7 @@ def apply_liger_kernel_to_mixtral(
     rms_norm: bool = True,
     swiglu: bool = True,
     model: PreTrainedModel = None,
-    flex_attn: bool = True,
+    flex_attn: bool = False,  # Not supported by Liger yet
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
@@ -360,6 +365,8 @@ def apply_liger_kernel_to_mixtral(
             modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
+    if flex_attn:
+        logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -387,7 +394,7 @@ def apply_liger_kernel_to_gemma(
     rms_norm: bool = True,
     geglu: bool = True,
     model: PreTrainedModel = None,
-    flex_attn: bool = True,
+    flex_attn: bool = False,  # Not supported by Liger yet
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Gemma
@@ -436,6 +443,8 @@ def apply_liger_kernel_to_gemma(
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+    if flex_attn:
+        logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -462,7 +471,7 @@ def apply_liger_kernel_to_gemma2(
     rms_norm: bool = True,
     geglu: bool = True,
     model: PreTrainedModel = None,
-    flex_attn: bool = True,
+    flex_attn: bool = False,  # Not supported by Liger yet
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Gemma2
@@ -513,6 +522,8 @@ def apply_liger_kernel_to_gemma2(
             modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
+    if flex_attn:
+        logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -541,7 +552,7 @@ def apply_liger_kernel_to_qwen2(
     rms_norm: bool = True,
     swiglu: bool = True,
     model: PreTrainedModel = None,
-    flex_attn: bool = False,  # Not support yet.
+    flex_attn: bool = False,  # Not supported by HuggingFace
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
@@ -588,6 +599,8 @@ def apply_liger_kernel_to_qwen2(
 
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
+    if flex_attn:
+        logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -616,7 +629,7 @@ def apply_liger_kernel_to_qwen2_vl(
     layer_norm: bool = True,
     swiglu: bool = True,
     model: PreTrainedModel = None,
-    flex_attn: bool = False,  # Not support yet.
+    flex_attn: bool = False,  # Not supported by HuggingFace
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
@@ -656,6 +669,8 @@ def apply_liger_kernel_to_qwen2_vl(
         modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
     if swiglu:
         modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
+    if flex_attn:
+        logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -688,7 +703,7 @@ def apply_liger_kernel_to_phi3(
     rms_norm: bool = True,
     swiglu: bool = True,
     model: PreTrainedModel = None,
-    flex_attn: bool = True,
+    flex_attn: bool = False,  # Not supported by Liger yet
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
@@ -732,6 +747,8 @@ def apply_liger_kernel_to_phi3(
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+    if flex_attn:
+        logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py
index e81fb7bfb..b7acff9d2 100644
--- a/test/convergence/test_mini_models.py
+++ b/test/convergence/test_mini_models.py
@@ -414,6 +414,7 @@ def run_mini_model(
         kwargs = {
             "rope": True,
             "rms_norm": True,
+            "flex_attn": False,
         }
 
         model_supports_layer_norm = "qwen2_vl" in model_name
@@ -428,7 +429,9 @@ def run_mini_model(
             kwargs["fused_linear_cross_entropy"] = True
             kwargs["cross_entropy"] = False
 
-        kwargs["flex_attn"] = True
+        model_supports_flex_attn = "llama3" in model_name  # excluding mllama
+        if model_supports_flex_attn:
+            kwargs["flex_attn"] = True
 
         MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
     else:
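
Usage note (illustrative, not part of the diff): the sketch below shows how the relocated `flex_attention_forward` is reached through the patch API after this change. It is a minimal sketch, assuming the package-level import `from liger_kernel.transformers import apply_liger_kernel_to_llama`, an arbitrary Llama 3 checkpoint name, and a CUDA device (the block mask in `flex_attention_forward` is created with device="cuda"); all other patch kwargs keep their defaults.

    # Illustrative sketch only; the checkpoint name and dtype are assumptions.
    import torch
    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers import apply_liger_kernel_to_llama

    # flex_attn=True swaps the "sdpa" and "flex_attention" entries of
    # ALL_ATTENTION_FUNCTIONS for llama_flex_attention_forward, now defined in
    # liger_kernel.transformers.model.llama.
    apply_liger_kernel_to_llama(flex_attn=True)

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B",  # assumed checkpoint
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",  # routed to the patched flex attention forward
    ).to("cuda")  # block mask construction targets CUDA

This mirrors the updated convergence test: `flex_attn` now defaults to `False` in `run_mini_model` and is enabled only for model names containing "llama3".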