Commit

fix backward compatibility issues
Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Jan 28, 2025
1 parent c350396 commit 7d342a6
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -38,8 +38,9 @@

 logger = logging.getLogger(__name__)
 SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
+FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION = "4.48.0"
 TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
-FLEX_ATTENTION_NOT_SUPPORT_WARNING = "Not support flex attention for this model yet"
+FLEX_ATTENTION_NOT_SUPPORT_WARNING = "Flex attention is not supported."


 def _bind_method_to_module(module, method_name: str, new_method: Callable):
@@ -120,9 +121,12 @@ def apply_liger_kernel_to_llama(

     if flex_attn:
         # Patching HuggingFace default attn_impl from `torch.sdpa` to liger's `llama_flex_attention_forward`
-        modeling_llama.ALL_ATTENTION_FUNCTIONS.update(
-            {"sdpa": llama_flex_attention_forward, "flex_attention": llama_flex_attention_forward}
-        )
+        if transformer_version >= version.parse(FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION):
+            modeling_llama.ALL_ATTENTION_FUNCTIONS.update(
+                {"sdpa": llama_flex_attention_forward, "flex_attention": llama_flex_attention_forward}
+            )
+        else:
+            logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
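For context, the gate added in this commit follows the usual packaging.version comparison pattern: patch flex attention only when the installed transformers release is new enough, otherwise warn and fall back. A minimal, self-contained sketch of that check follows; this is not the library's code, and the helper name can_patch_flex_attention is hypothetical.

import logging

import transformers
from packaging import version

logger = logging.getLogger(__name__)

FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION = "4.48.0"
FLEX_ATTENTION_NOT_SUPPORT_WARNING = "Flex attention is not supported."


def can_patch_flex_attention() -> bool:
    # Compare the installed transformers version against the minimum release
    # the patch targets; only then is it safe to swap in the flex attention forward.
    installed = version.parse(transformers.__version__)
    if installed >= version.parse(FLEXATTENTION_SUPPORTED_TRANSFORMER_VERSION):
        return True
    logger.warning(FLEX_ATTENTION_NOT_SUPPORT_WARNING)
    return False

On an older transformers install this returns False and logs the warning, which mirrors the else branch added in the diff above.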
