
Support for Qwen models #21

Open
gsarti opened this issue Jan 23, 2025 · 6 comments

gsarti commented Jan 23, 2025

Hi @rachtibat,

I was wondering if you'd consider adding support for Qwen 2 models in LXT!

Judging by the modular implementation in HF, the differences should be minimal (no bias in MLP, sliding window attention). It would be very cool to have this, as it would enable attribution on the capable Qwen 2.5 family and the R1 reasoning models derived from it!
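As a quick sanity check on that claim, one can diff the HF configs. A minimal sketch, assuming illustrative model IDs and the standard `transformers` config fields:

```python
# Minimal sketch: diff the Hugging Face configs to see how far apart
# Qwen2 and Llama really are. Model IDs here are illustrative assumptions.
from transformers import AutoConfig

llama_cfg = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
qwen_cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-1.5B")

llama_keys, qwen_keys = set(llama_cfg.to_dict()), set(qwen_cfg.to_dict())
# Fields only one config has hint at the architectural gap,
# e.g. Qwen2's `use_sliding_window` / `sliding_window` settings.
print("Qwen2-only fields:", sorted(qwen_keys - llama_keys))
print("Llama-only fields:", sorted(llama_keys - qwen_keys))
```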


samlewislim commented Feb 4, 2025

Hi @gsarti,

I had the same thought and had a go at it today. I mostly just adapted the existing Llama implementation. I'm not sure it's working correctly, though, as the heatmap from the example doesn't seem to highlight very relevant tokens (see attached PDF). @rachtibat, if you get some time, it would be great if you could check whether this looks like how you would approach adding the Qwen 2 models :).

R1-Distill-Qwen-1.5B_heatmap.pdf

My Fork:
Model file: https://github.com/samlewislim/LRP-eXplains-Transformers/blob/main/lxt/models/qwen2.py
Example: https://github.com/samlewislim/LRP-eXplains-Transformers/blob/main/examples/qwen.py
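For anyone who wants to try the fork, here is a minimal usage sketch. It assumes the qwen2 module mirrors the interface of lxt.models.llama (i.e. it exports Qwen2ForCausalLM and an attnlrp composite); the flow follows the LXT Llama quickstart, and the model ID is an assumption based on the attached heatmap's filename:

```python
import torch
from transformers import AutoTokenizer
from lxt.models.qwen2 import Qwen2ForCausalLM, attnlrp  # assumed fork interface

model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # assumed from the heatmap filename
model = Qwen2ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Attach the AttnLRP rules to the model's submodules
attnlrp.register(model)

prompt = "The highest mountain in the world is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
input_embeds = model.get_input_embeddings()(input_ids)

# Forward on embeddings so relevance can be read off their gradients
logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
max_logits, _ = torch.max(logits[0, -1, :], dim=-1)
# Seed the backward pass with the logit value, as in the LXT Llama example
max_logits.backward(max_logits)

# Token-level relevance: gradient summed over the embedding dimension
relevance = input_embeds.grad.float().sum(-1).cpu()[0]
print(list(zip(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()), relevance.tolist())))
```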

rachtibat commented Feb 5, 2025

Hey guys,

thank you a lot @samlewislim for providing the implementation! The heatmap does seem off: it is not looking at the correct token '320', right?

Good news!
I'll update the repo at the end of the month with an even better LXT version that is much faster and easier to use (just change 4-5 lines of code to adapt any model on the face of the earth!!!).
I'd like to split LXT into the original paper version (for reproducibility) and a more efficient, easier-to-use version.

For this, I'd of course like to add your contribution first, @samlewislim! (:


gsarti commented Feb 5, 2025

Great news @rachtibat, looking forward to the new implementation!

samlewislim commented Feb 5, 2025

Hey both,

That is great news, looking forward to the new implementation too!

Yeah, I thought it seemed off too, as it wasn't looking at the '320' token, while other models of similar size (e.g. Llama 3.2 1B or TinyLlama) did. But this may just be down to model differences?

@rachtibat would you like me to create a PR for Qwen?

rachtibat commented:

@samlewislim I think there might be an error in the Qwen implementation.
I had a quick look at your implementation and found two issues (see the sketch after this list):

  1. The softmax module should be an attribute of the attention module, not created on the fly during the forward pass:
     https://github.com/samlewislim/LRP-eXplains-Transformers/blob/5250f7be3a5f542b6ec07ca2d8b38bb9787c1e56/lxt/models/qwen2.py#L183
  2. You need to replace all LayerNorm modules, e.g. at https://github.com/samlewislim/LRP-eXplains-Transformers/blob/5250f7be3a5f542b6ec07ca2d8b38bb9787c1e56/lxt/models/qwen2.py#L294, with lxt.modules.RMSNormIdentity, as in the Llama implementation at https://github.com/samlewislim/LRP-eXplains-Transformers/blob/5250f7be3a5f542b6ec07ca2d8b38bb9787c1e56/lxt/models/llama.py#L66C25-L66C40 and https://github.com/samlewislim/LRP-eXplains-Transformers/blob/5250f7be3a5f542b6ec07ca2d8b38bb9787c1e56/lxt/models/llama.py#L427
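A minimal sketch of both fixes, assuming the fork keeps the structure of lxt/models/llama.py. nn.Softmax below is a stand-in for whichever LXT softmax module llama.py actually instantiates, and the RMSNormIdentity signature is assumed to match HF's (hidden_size, eps):

```python
import torch.nn as nn
from lxt.modules import RMSNormIdentity  # the module named in point 2

class Qwen2Attention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        # Fix 1: instantiate the softmax ONCE as a module attribute, so the
        # LRP rule stays attached to a stable module (nn.Softmax is a
        # stand-in here; use the same LXT softmax module as llama.py).
        self.softmax = nn.Softmax(dim=-1)
        # ... q/k/v/o projections, rotary embeddings, etc. ...

    def forward(self, query, key, value):
        attn_scores = query @ key.transpose(-2, -1) / query.size(-1) ** 0.5
        # reuse the attribute; do NOT create a new softmax module here
        attn_weights = self.softmax(attn_scores)
        return attn_weights @ value

class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        # Fix 2: replace the stock norm layers with RMSNormIdentity
        # (identity LRP rule), as in llama.py.
        self.input_layernorm = RMSNormIdentity(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNormIdentity(config.hidden_size, eps=config.rms_norm_eps)
        # ... self-attention and MLP submodules ...
```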

When it works, you can create a pull request.

Best regards

samlewislim commented:

@rachtibat thank you so much for taking a look! I have made those fixes and just opened a PR. The heatmap is also looking correct now.

qwen_updated_heatmap.pdf
