# Fix AutoLigerKernelForCausalLM to pass through original kwargs (#263)
## Summary
- Fixes #250 to correctly pass all original kwargs to `.from_pretrained()`. Previously we were only passing args that were part of the model config, but there are additional valid kwargs beyond that (see the usage sketch after this list).
- We still need to filter out the kwargs that were passed into the `apply_liger_*` functions, or else they will cause model init errors.
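
For illustration, a hypothetical usage sketch of the fixed behavior (not from this PR; `torch_dtype` and `low_cpu_mem_usage` are standard Hugging Face `from_pretrained()` kwargs assumed for the example):

```python
# Hypothetical sketch: `rope` and `swiglu` are consumed by the
# apply_liger_kernel_to_llama fn, while the remaining kwargs are now
# forwarded untouched to AutoModelForCausalLM.from_pretrained().
import torch

from liger_kernel.transformers import AutoLigerKernelForCausalLM

model = AutoLigerKernelForCausalLM.from_pretrained(
    "/path/to/llama/model",
    rope=False,                  # consumed by apply_liger_kernel_to_llama
    swiglu=True,                 # consumed by apply_liger_kernel_to_llama
    torch_dtype=torch.bfloat16,  # forwarded to from_pretrained
    low_cpu_mem_usage=True,      # forwarded to from_pretrained
)
```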

## Testing Done
Tested on the Hugging Face example with some of the args from #250.

- Hardware Type: A100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
shimizust authored Sep 20, 2024
1 parent ce71d59 commit 1289cc4
Showing 2 changed files with 23 additions and 20 deletions.
### src/liger_kernel/transformers/auto_model.py (18 additions, 6 deletions)

```diff
@@ -1,6 +1,11 @@
+import inspect
+
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
+from liger_kernel.transformers.monkey_patch import (
+    MODEL_TYPE_TO_APPLY_LIGER_FN,
+    _apply_liger_kernel,
+)
 
 
 def _get_model_config(model_dir, **model_init_kwargs):
@@ -21,13 +26,20 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         # Determine the model type and apply the Liger Kernel if applicable
         # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
         model_type = model_config.model_type
+
         _apply_liger_kernel(model_type, **kwargs)
 
-        # Retain only the keyword args present in the model configuration
-        for k in list(kwargs.keys()):
-            if k not in model_config.__dict__:
-                del kwargs[k]
+        # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+        # model initialization errors otherwise
+        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+        apply_fn_signature = inspect.signature(apply_fn)
+
+        applicable_kwargs = {
+            key: value
+            for key, value in kwargs.items()
+            if key not in apply_fn_signature.parameters
+        }
 
         return super().from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
+            pretrained_model_name_or_path, *model_args, **applicable_kwargs
         )
```
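
The filtering relies on `inspect.signature`, whose `.parameters` mapping is keyed by parameter name. A minimal self-contained sketch of the same pattern, using a hypothetical stand-in `apply_fn`:

```python
import inspect

# Hypothetical stand-in for an apply_liger_kernel_to_* function.
def apply_fn(rope=True, swiglu=True):
    pass

kwargs = {"rope": False, "swiglu": True, "torch_dtype": "bfloat16"}

# Keys that name parameters of apply_fn are dropped; everything else
# passes through to from_pretrained().
applicable_kwargs = {
    key: value
    for key, value in kwargs.items()
    if key not in inspect.signature(apply_fn).parameters
}

assert applicable_kwargs == {"torch_dtype": "bfloat16"}
```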
### test/transformers/test_auto_model.py (5 additions, 14 deletions)

```diff
@@ -15,31 +15,22 @@ def test_auto_liger_kernel_for_causal_lm_from_pretrained():
     pretrained_model_name_or_path = "/path/to/llama/model"
     model_args = ("model_arg1", "model_arg2")
 
-    valid_kwargs = {
+    original_kwargs = {
         "valid_arg_1": "some_value_1",
         "valid_arg_2": 10,
     }
 
-    # This arg should be filtered out as it is not part of the model config
-    invalid_kwargs = {
-        "invalid_arg": "another_value",
-    }
-
     # These args should be passed through to apply_liger_kernel_to_llama fn
     apply_liger_kernel_kwargs = {
         "rope": False,
         "swiglu": True,
     }
 
-    kwargs = {**valid_kwargs, **invalid_kwargs, **apply_liger_kernel_kwargs}
+    kwargs = {**original_kwargs, **apply_liger_kernel_kwargs}
 
     # Mock the model config instance returned from AutoConfig.from_pretrained()
     mock_model_config = MagicMock()
-    mock_model_config.__dict__ = {
-        "model_type": "llama",
-        "valid_arg_1": "",
-        "valid_arg_2": 0,
-    }
+    mock_model_config.model_type = "llama"
     mock_llama = mock.Mock()
 
     with patch.dict(
@@ -59,8 +50,8 @@ def test_auto_liger_kernel_for_causal_lm_from_pretrained():
 
     # Check that the apply_liger_kernel_to_llama mock was called with the correct kwargs
     mock_llama.assert_called_once_with(rope=False, swiglu=True)
-    # Check that only valid kwargs are passed to super().from_pretrained
+    # Check that the original kwargs are passed to super().from_pretrained
     mock_super_from_pretrained.assert_called_once_with(
-        pretrained_model_name_or_path, *model_args, **valid_kwargs
+        pretrained_model_name_or_path, *model_args, **original_kwargs
     )
     assert model == "mock_model"
```
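
The `with patch.dict(` context is truncated in this view; presumably it swaps the `"llama"` entry of `MODEL_TYPE_TO_APPLY_LIGER_FN` for `mock_llama`. A minimal sketch of that `patch.dict` pattern (assumed, not the test's actual body):

```python
from unittest import mock
from unittest.mock import patch

from liger_kernel.transformers import monkey_patch

mock_llama = mock.Mock()

# patch.dict temporarily replaces the registered apply fn for "llama"
# and restores the original mapping when the block exits.
with patch.dict(
    monkey_patch.MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}
):
    monkey_patch.MODEL_TYPE_TO_APPLY_LIGER_FN["llama"](rope=False, swiglu=True)

mock_llama.assert_called_once_with(rope=False, swiglu=True)
```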
