README: ensure modeling code is patched before model instantiation #170
Conversation
Applying it before or after should both work, because it monkey-patches the HF module directly. It is more straightforward to apply it afterwards.
Unfortunately you are mistaken.

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import liger_kernel.transformers

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

config = AutoConfig.from_pretrained(
    model_name,
    attn_implementation='flash_attention_2',
    use_cache=False,
    trust_remote_code=True,
)

# instantiated BEFORE the patch is applied
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    config=config,
)

liger_kernel.transformers.apply_liger_kernel_to_llama()

print(model.model.norm.__class__)
print(model.model.norm.__class__ == liger_kernel.transformers.rms_norm.LigerRMSNorm)  # False: patched too late
print("\n")

# instantiated AFTER the patch is applied
model2 = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    config=config,
)
print(model2.model.norm.__class__)
print(model2.model.norm.__class__ == liger_kernel.transformers.rms_norm.LigerRMSNorm)  # True: new instances pick up LigerRMSNorm
Basically, if any code has already run which calls, e.g., …
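To see the underlying Python behavior without any HF or liger dependency, here is a minimal sketch (the module and class names are made up for illustration): rebinding a class name inside a module does not retroactively change objects that were already constructed from the old class.

import types

modeling = types.ModuleType("fake_modeling")   # stand-in for an HF modeling module

class OriginalNorm:                            # stand-in for LlamaRMSNorm
    pass

class PatchedNorm:                             # stand-in for LigerRMSNorm
    pass

modeling.Norm = OriginalNorm

layer = modeling.Norm()                        # "model instantiation" happens here
modeling.Norm = PatchedNorm                    # "apply the kernel patch" happens afterwards

print(type(layer) is PatchedNorm)              # False: the existing object keeps its old class
print(type(modeling.Norm()) is PatchedNorm)    # True: only new instances see the patch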
oh interesting! i vaguely remember i tried this before. will take another look. apologies for the oversight
You're right, thanks for pointing that out @tmm1!
wait, both should work fine. i think it is because even though the object is still the llama RMSNorm class, the forward call actually goes to liger's implementation.
let me verify more
you can check …
@tmm1 Just to follow up: the module classes won't get patched if you apply liger post model-init. However, patching model functions (e.g. …) does still take effect. Thanks again for pointing it out!
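For intuition, a small sketch of why function-level patches still apply post-init (the names below are illustrative, not the actual liger-kernel patch targets): reassigning a function on the class object changes behavior for every instance, including ones created before the patch.

class Norm:                       # stand-in for an HF module class
    def forward(self, x):
        return x                  # original behavior

layer = Norm()                    # instantiated before any patching

def fused_forward(self, x):       # stand-in for a Liger fused implementation
    return x * 2

Norm.forward = fused_forward      # patch the function on the class itself, not on an instance

print(layer.forward(3))           # 6: even the pre-existing instance now uses the patched function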
Related: is this test incorrect? We patch mini_llama3 in one test case (parameter 0) and then run the model in the second case (parameter 1) as if the patch had not been applied. So it seems like the bfloat16 tests are comparing a patched liger model against itself (see the sketch below).
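A stripped-down illustration of the concern, assuming a pytest-style parametrized test with made-up names: once one case applies the global patch, later cases in the same process still see it, so the "unpatched" baseline is not really unpatched.

import pytest

PATCHED = {"active": False}          # stand-in for the global monkey-patch state

def apply_fake_patch():
    PATCHED["active"] = True         # analogous to applying the liger kernel patch

def run_model():
    return "liger" if PATCHED["active"] else "huggingface"

@pytest.mark.parametrize("with_liger", [True, False])
def test_convergence(with_liger):
    if with_liger:
        apply_fake_patch()
    # the second case (with_liger=False) still sees the patch applied by the first case,
    # so both runs exercise the "liger" path and the comparison is liger vs. liger
    print(run_model())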
this is an essential finding!! Thanks @tmm1 for bearing with my oversight lol
@tyler-romero you are right! i ran only bf16 and they failed. looks like the tol is too tight. we need some intelligent way to unset the patching
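One possible direction, purely as a sketch (not necessarily what the repo ends up doing): snapshot the original module attributes before patching so a test fixture can restore them afterwards. The helper names here are hypothetical.

import importlib

def snapshot(module_name, names):
    """Save the current attributes of a module so they can be restored later (hypothetical helper)."""
    module = importlib.import_module(module_name)
    return {name: getattr(module, name) for name in names}

def revert_patch(module_name, saved):
    """Restore previously saved attributes on a module (hypothetical helper)."""
    module = importlib.import_module(module_name)
    for name, original in saved.items():
        setattr(module, name, original)

# usage sketch inside a test:
# saved = snapshot("transformers.models.llama.modeling_llama", ["LlamaRMSNorm"])
# apply_liger_kernel_to_llama()
# ... run the patched test ...
# revert_patch("transformers.models.llama.modeling_llama", saved)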
Thanks all! Happy to help. The incomplete patching issue is quite unintuitive and has tripped me up a number of times in the past.
Summary
Fixes the example in the README to make it functional: the Liger kernel patch must be applied before the model is instantiated.
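For reference, a minimal sketch of the corrected ordering (model name taken from the snippet above); the key point is that the patch call comes before from_pretrained:

import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# patch the HF Llama modeling code BEFORE the model is instantiated
apply_liger_kernel_to_llama()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)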
Testing Done
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence