
README: ensure modeling code is patched before model instantiation #170

Merged · 3 commits into linkedin:main · Aug 30, 2024

Conversation

@tmm1 (Contributor) commented Aug 29, 2024

Summary

Fixes the example in the README to make it functional.

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@ByronHsu (Collaborator) left a comment


Before or after both work, because it monkey-patches the HF module directly. It is more straightforward to apply it afterwards.

@ByronHsu closed this Aug 29, 2024
@tmm1 (Contributor, Author) commented Aug 29, 2024

Unfortunately you are mistaken.

import torch
from transformers import AutoConfig, AutoModelForCausalLM
import liger_kernel.transformers

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
config = AutoConfig.from_pretrained(
    model_name,
    attn_implementation='flash_attention_2',
    use_cache=False,
    trust_remote_code=True,
)
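# First model: instantiated BEFORE the Liger patch is applied.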
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    config=config,
)

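# Apply the Liger patch only now, after `model` has been instantiated.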
liger_kernel.transformers.apply_liger_kernel_to_llama()

print(model.model.norm.__class__)
print(model.model.norm.__class__ == liger_kernel.transformers.rms_norm.LigerRMSNorm)
print("\n")


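# Second model: instantiated AFTER the patch.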
model2 = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    config=config,
)
print(model2.model.norm.__class__)
print(model2.model.norm.__class__ == liger_kernel.transformers.rms_norm.LigerRMSNorm)

Output:

<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>
False

<class 'liger_kernel.transformers.rms_norm.LigerRMSNorm'>
True

Basically, if any code has already run that calls e.g. self.norm = LlamaRMSNorm(...), then it doesn't matter that you swapped out modeling_llama.LlamaRMSNorm afterwards: the old object was already instantiated, and its old methods will be used.
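Here is a minimal, self-contained sketch of the same mechanism (the stub names are illustrative; no transformers install needed):

class LlamaRMSNormStub:  # stands in for modeling_llama.LlamaRMSNorm
    def forward(self):
        return "original forward"

class ModelStub:  # stands in for the HF model
    def __init__(self):
        # the class name is looked up in module globals at call time
        self.norm = LlamaRMSNormStub()

model = ModelStub()  # analogous to the first from_pretrained(...)

class LigerRMSNormStub:  # stands in for LigerRMSNorm
    def forward(self):
        return "patched forward"

LlamaRMSNormStub = LigerRMSNormStub  # the monkey patch

model2 = ModelStub()  # analogous to the second from_pretrained(...)

print(model.norm.forward())   # "original forward" -- pre-patch object unaffected
print(model2.norm.forward())  # "patched forward"  -- post-patch object picks it up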

@ByronHsu (Collaborator)

oh interesting! i vaguely remember trying this before. will take another look. apologies for the oversight

@shimizust (Collaborator)

You're right, thanks for pointing that out @tmm1 !

@ByronHsu (Collaborator) left a comment


wait, both should work fine. i think it is because, even though the object is still the llama rms norm, the call actually goes to liger's

@ByronHsu (Collaborator)

let me verify more

@tmm1 (Contributor, Author) commented Aug 30, 2024

you can check model.model.norm.forward and the other methods; they are not the liger versions
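For instance (continuing the script above, on the model created before the patch):

print(model.model.norm.forward)
# prints a bound method of the original LlamaRMSNorm class, not Liger's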

@shimizust (Collaborator)

@tmm1 Just to follow up: the module classes won't get patched if you apply liger post model-init. However, patches to module-level functions (e.g. modeling_mistral.apply_rotary_pos_emb) or class methods (e.g. modeling_mistral.MistralForCausalLM.forward) do take effect, so the result is incomplete patching.

Thanks again for pointing it out!
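A sketch of why the patching is incomplete rather than absent (stub names are illustrative, not the real Mistral classes): patching an attribute on the class reaches instances that already exist, because Python resolves methods on the class at call time; rebinding the class name in the module does not.

class MistralForCausalLMStub:  # stands in for modeling_mistral.MistralForCausalLM
    def forward(self):
        return "original forward"

model = MistralForCausalLMStub()  # created before any patching

def liger_forward(self):
    return "liger forward"

# Patching a method ON the class reaches the existing instance,
# because model.forward is looked up on the class at call time.
MistralForCausalLMStub.forward = liger_forward
print(model.forward())  # "liger forward"

# By contrast, rebinding the class name in the module (as with
# modeling_llama.LlamaRMSNorm above) leaves model's submodules untouched.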

@shimizust enabled auto-merge (squash) August 30, 2024 01:42
@tyler-romero (Collaborator)

Related, is this test incorrect?
https://github.com/tyler-romero/Liger-Kernel/blob/main/test/convergence/test_mini_models_no_logits.py#L324

We patch mini_llama3 in one test (parameter 0), and then in the second test (parameter 1) we run the model as if the patch had not been applied. So it seems like the bfloat16 tests are comparing a patched liger model against itself.

@ByronHsu (Collaborator)

this is an essential finding!! Thanks @tmm1 for bearing with my oversight lol

@ByronHsu (Collaborator)

@tyler-romero you are right! i ran only bf16 and they failed. looks like the tolerance is too tight. we need some intelligent way to unset the patching
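One possible shape for that (a sketch, not an existing Liger-Kernel API; the helper name is hypothetical): snapshot the original symbols before patching so a test teardown can restore them.

import transformers.models.llama.modeling_llama as modeling_llama

# Snapshot taken before apply_liger_kernel_to_llama() runs.
_ORIGINALS = {
    "LlamaRMSNorm": modeling_llama.LlamaRMSNorm,
    "apply_rotary_pos_emb": modeling_llama.apply_rotary_pos_emb,
}

def unpatch_llama():
    # Restore the module-level symbols so the next test sees vanilla HF code.
    for name, obj in _ORIGINALS.items():
        setattr(modeling_llama, name, obj)

Note this only restores the module symbols; any model instantiated while the patch was active keeps its Liger submodules, which is the same instantiation-time issue this PR documents.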

@shimizust merged commit 80d6c0f into linkedin:main Aug 30, 2024
2 checks passed
@tmm1 (Contributor, Author) commented Aug 30, 2024

Thanks all! Happy to help. The incomplete patching issue is quite unintuitive and has tripped me up a number of times in the past.
