
Commit

re-org
Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Jan 27, 2025
1 parent 9f5462d commit 725b256
Showing 4 changed files with 94 additions and 76 deletions.
65 changes: 0 additions & 65 deletions src/liger_kernel/transformers/llama_flex_attention.py

This file was deleted.

63 changes: 63 additions & 0 deletions src/liger_kernel/transformers/model/llama.py
@@ -8,6 +8,8 @@
import torch.nn.functional as F

from torch.nn import CrossEntropyLoss
from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.flex_attention import flex_attention
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
@@ -19,6 +21,8 @@
if TYPE_CHECKING:
from transformers.cache_utils import Cache

flex_attention = torch.compile(flex_attention)


@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -249,3 +253,62 @@ def lce_forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py#L12


def flex_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def causal_mod(score, b, h, q_idx, kv_idx):
        if softcap is not None:
            score = softcap * torch.tanh(score / softcap)
        if causal_mask is not None:
            score = score + causal_mask[b][0][q_idx][kv_idx]
        return score

    # We only receive `attention_mask` tensors here, so we recreate the `causal_mask` function for Llama's causal attention.
    # TODO: Consider other customized `attention_mask` variants in the future, e.g., shared prefix.
    def causal_mask_fn(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    # Construct a block attention mask that leverages sparsity.
    sparse_causal_mask = create_block_mask(causal_mask_fn, None, None, query.shape[-2], query.shape[-2], device="cuda")

    attn_output, attention_weights = flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
        block_mask=sparse_causal_mask,
        enable_gqa=True,
        scale=scaling,
        # Last checked on PyTorch 2.5.1: Flex Attention always computes the LSE regardless.
        # For simplicity, we always return it, since no additional computation is introduced.
        return_lse=True,
        kernel_options={  # different hardware might need different configs
            "BLOCK_M": 32,
            "BLOCK_N": 32,
            "BLOCK_M1": 16,
            "BLOCK_N1": 32,
            "BLOCK_M2": 32,
            "BLOCK_N2": 16,
        },
    )
    # The LSE is returned in float32; cast it back to the value dtype.
    attention_weights = attention_weights.to(value.dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attention_weights
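
For reference, a minimal standalone sketch of the same create_block_mask + flex_attention pattern on dummy tensors (not part of this commit). The shapes, dtype, and CUDA device are illustrative assumptions, and it requires PyTorch >= 2.5:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 8, 128, 64  # assumed batch, heads, sequence length, head dim
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

def causal(b, h, q_idx, kv_idx):
    # Same mask_mod as causal_mask_fn above: a query position attends only to positions at or before it.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, None, None, S, S, device="cuda")
out, lse = flex_attention(q, k, v, block_mask=block_mask, return_lse=True)
print(out.shape, lse.shape)  # (1, 8, 128, 64) and (1, 8, 128); the LSE comes back in float32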
37 changes: 27 additions & 10 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -13,11 +13,11 @@
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.llama_flex_attention import flex_attention_forward as llama_flex_attention_forward
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
from liger_kernel.transformers.model.llama import flex_attention_forward as llama_flex_attention_forward
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
@@ -39,6 +39,7 @@
logger = logging.getLogger(__name__)
SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
FLEX_ATTENTION_NOT_SUPPORT_WARNING = "Flex attention is not supported for this model yet."


def _bind_method_to_module(module, method_name: str, new_method: Callable):
@@ -118,7 +119,7 @@ def apply_liger_kernel_to_llama(
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

if flex_attn:
logger.warning("Patched HuggingFace default PyTorch SDPA to liger_flex_attention.")
# Patch the HuggingFace default attn_impl from `torch.sdpa` to Liger's `llama_flex_attention_forward`.
modeling_llama.ALL_ATTENTION_FUNCTIONS.update(
{"sdpa": llama_flex_attention_forward, "flex_attention": llama_flex_attention_forward}
)
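
As a usage sketch (my assumption; only the `flex_attn` flag itself comes from this diff), the patch would typically be applied before the model is instantiated; the model id below is a placeholder:

from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Route Llama's "sdpa" / "flex_attention" implementations through the Liger flex-attention forward.
apply_liger_kernel_to_llama(fused_linear_cross_entropy=True, flex_attn=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")  # placeholder model id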
@@ -149,7 +150,7 @@ def apply_liger_kernel_to_mllama(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not support yet.
flex_attn: bool = False,  # Not supported by HuggingFace
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
@@ -203,6 +204,8 @@
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -253,7 +256,7 @@ def apply_liger_kernel_to_mistral(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = True,
flex_attn: bool = False,  # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Mistral models
@@ -288,6 +291,8 @@
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
if swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -314,7 +319,7 @@ def apply_liger_kernel_to_mixtral(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = True,
flex_attn: bool = False,  # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
@@ -360,6 +365,8 @@
modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
if swiglu:
modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -387,7 +394,7 @@ def apply_liger_kernel_to_gemma(
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = True,
flex_attn: bool = False,  # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Gemma
@@ -436,6 +443,8 @@
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -462,7 +471,7 @@ def apply_liger_kernel_to_gemma2(
rms_norm: bool = True,
geglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = True,
flex_attn: bool = False,  # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Gemma2
@@ -513,6 +522,8 @@
modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
if geglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -541,7 +552,7 @@ def apply_liger_kernel_to_qwen2(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not support yet.
flex_attn: bool = False,  # Not supported by HuggingFace
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
@@ -588,6 +599,8 @@

if swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -616,7 +629,7 @@ def apply_liger_kernel_to_qwen2_vl(
layer_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = False, # Not support yet.
flex_attn: bool = False,  # Not supported by HuggingFace
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
@@ -656,6 +669,8 @@
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
if swiglu:
modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
@@ -688,7 +703,7 @@ def apply_liger_kernel_to_phi3(
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
flex_attn: bool = True,
flex_attn: bool = False,  # Not supported by Liger yet
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
@@ -732,6 +747,8 @@
else: # if version < 4.46.1
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
if flex_attn:
logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

if model is not None:
# The model instance already exists, so we need to additionally patch the
5 changes: 4 additions & 1 deletion test/convergence/test_mini_models.py
@@ -414,6 +414,7 @@ def run_mini_model(
kwargs = {
"rope": True,
"rms_norm": True,
"flex_attn": False,
}

model_supports_layer_norm = "qwen2_vl" in model_name
@@ -428,7 +429,9 @@
kwargs["fused_linear_cross_entropy"] = True
kwargs["cross_entropy"] = False

kwargs["flex_attn"] = True
model_supports_flex_attn = "llama3" in model_name # excluding mllama
if model_supports_flex_attn:
kwargs["flex_attn"] = True

MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
else:
