feat: Add support for phi4 #764

Open
wants to merge 2 commits into main
Conversation

@jlonge4 commented Jan 18, 2025

This PR adds support for Microsoft's Phi-4 model by adapting the existing LLaMA implementation.

The Phi-4 architecture follows the LLaMA architecture closely, with the main difference being how the weights are stored (fused qkv_proj and gate_up_proj versus separate projections).
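
For illustration, a minimal sketch of the weight conversion this implies, assuming the checkpoint uses the Phi-3/Phi-4 tensor names (self_attn.qkv_proj, mlp.gate_up_proj) and that the target is LLaMA-style separate projections; the exact key names and split logic in the PR may differ:

import torch

def split_fused_phi_weights(state_dict, config):
    # Sketch only: map fused Phi-style tensors to separate LLaMA-style projections.
    head_dim = config.hidden_size // config.num_attention_heads
    q_size = config.num_attention_heads * head_dim
    kv_size = config.num_key_value_heads * head_dim
    out = {}
    for name, tensor in state_dict.items():
        if name.endswith("self_attn.qkv_proj.weight"):
            prefix = name.removesuffix("qkv_proj.weight")
            # q is larger than k and v when the model uses grouped-query attention
            q, k, v = torch.split(tensor, [q_size, kv_size, kv_size], dim=0)
            out[prefix + "q_proj.weight"] = q
            out[prefix + "k_proj.weight"] = k
            out[prefix + "v_proj.weight"] = v
        elif name.endswith("mlp.gate_up_proj.weight"):
            prefix = name.removesuffix("gate_up_proj.weight")
            # gate and up halves are stacked along the output dimension
            gate, up = torch.split(tensor, tensor.shape[0] // 2, dim=0)
            out[prefix + "gate_proj.weight"] = gate
            out[prefix + "up_proj.weight"] = up
        else:
            out[name] = tensor
    return out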

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dacorvo (Collaborator) commented Feb 4, 2025

@jlonge4 thank you very much for this pull-request: adding support for phi4 would be awesome.

We are, however, heavily refactoring the export mechanism to remove the dependency on transformers-neuronx and simplify the contribution of new models.

Can you take a look at that pull request and see if it would make it easier for you to add support for phi4 based on the new HLO backend?

@jlonge4 (Author) commented Feb 5, 2025

Hi there @dacorvo, I just took a look at the difference and it certainly seems a lot slimmer! I think my effort would be about the same for the most important part, the load_weights function, but it would obviously get rid of a lot of boilerplate. I am happy to rework this PR and merge it into the add_hlo branch if you prefer.

@dacorvo (Collaborator) commented Feb 11, 2025

@jlonge4 the pull-request has been merged. Please let me know if you need any help rebasing your branch.

@jlonge4 (Author) commented Feb 14, 2025

@dacorvo Hopefully the last commit is pretty close.

@dacorvo (Collaborator) commented Feb 24, 2025

@jlonge4 I rebased and squashed your branch into a new phi4 branch, then ran a few tests. The load_weights method required some significant changes, but eventually I managed to export and run microsoft/phi4.

I think the most efficient approach is for you to reset your branch to mine locally (assuming you have an "upstream" remote pointing to the main optimum-neuron repository):

git fetch upstream
git reset --hard upstream/phi4

Then you can resume the work on phi4 to add more tests (I only kept the export unit tests since TGI has now moved to the upstream repository).

@jlonge4 (Author) commented Feb 24, 2025

Just for my knowledge, what specifically had to be changed further in the load_weights function? And are there any specific tests you'd like added from here?

@dacorvo (Collaborator) commented Feb 25, 2025

Just for my knowledge, what specifically had to be changed further in the load_weights function? And are there any specific tests you'd like added from here?

Main changes:

  • unfused weights had to be transposed (like the other weights),
  • the q, k, v splits are not even because of GQA (q gets a bigger chunk; see the sketch below),
  • many other changes related to my refactoring (config parameters and attributes removed).
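
For illustration, a minimal sketch of the second point, where fused_qkv stands for the checkpoint tensor and the head counts are taken from the phi-4 config as I understand it (worth double-checking):

hidden_size, num_attention_heads, num_key_value_heads = 5120, 40, 10  # phi-4 config values
head_dim = hidden_size // num_attention_heads                         # 128
q_size = num_attention_heads * head_dim                               # 5120
kv_size = num_key_value_heads * head_dim                              # 1280
# an even 3-way chunk would be wrong: q gets 5120 rows, k and v only 1280 each
q, k, v = fused_qkv.split([q_size, kv_size, kv_size], dim=0)
q, k, v = q.T, k.T, v.T                                               # transposed like the other weights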

@dacorvo (Collaborator) commented Feb 25, 2025

@jlonge4 I think it would be good to add a phi3 config to the generation tests here:

DECODER_MODEL_CONFIGURATIONS = {

There are not many small models available, but you can use microsoft/Phi-3-mini-4k-instruct for instance.
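
The entry could look something like this, assuming it follows the same shape as the existing configurations in that dict (the keys below are modeled on the llama entry and may need adjusting):

    "phi3": {
        "model_id": "microsoft/Phi-3-mini-4k-instruct",
        "export_kwargs": {
            "batch_size": 4,
            "sequence_length": 4096,
            "num_cores": 2,
            "auto_cast_type": "bf16",
        },
    },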

@jlonge4 (Author) commented Feb 25, 2025

@dacorvo got it! Btw, so that we can effectively handle both, do you think we should add a check for attn_heads == kv_heads, since the phi3/3.5 mini models don't use GQA?

@dacorvo (Collaborator) commented Feb 25, 2025

@dacorvo got it! Btw, so that we can effectively handle both, do you think we should add a check for attn_heads == kv_heads, since the phi3/3.5 mini models don't use GQA?

I think it is covered by the calculation I used (stolen from transformers 😉).
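
For reference, the calculation in question presumably boils down to the grouped-query pattern used by transformers, which covers both cases without an extra check (assuming the usual head counts for these models):

num_key_value_groups = num_attention_heads // num_key_value_heads
# phi-4:           40 // 10 = 4  -> grouped-query attention
# Phi-3-mini-4k:   32 // 32 = 1  -> plain multi-head attention, same code path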

@dacorvo (Collaborator) commented Feb 26, 2025

@jlonge4 I realize you already added a phi4 test configuration with microsoft/phi-4. This model is too big for our unit tests IMHO, so I tried to use microsoft/Phi-3-mini-4k-instruct instead, but the results are garbage.

As a sanity test I checked the results of microsoft/phi-4 on gsm8k on CUDA and neuron, and there is a significant drop in accuracy.

Here is the CUDA result:

[screenshot: gsm8k accuracy on CUDA]

And the neuron one:

[screenshot: gsm8k accuracy on Neuron]

So now I am wondering whether we missed something in the modeling code that differs from llama.

@jlonge4 (Author) commented Feb 26, 2025

@dacorvo hmm, so strange. I was able to make phi3.5-mini-instruct and phi-4 work with my NXDI integration, and the only difference was the load_state_dict here.

NXDI logit matching results:

Expected Output:  [", that is the question:\nWhether 'tis nobler in the mind to suffer\nThe slings and arrows of outrageous fortune,\nOr to take arms against a sea of troubles\nAnd by opposing end them. To die, to sleep;\nNo"] tensor([[29892,   393,   338,   278,  1139, 29901,    13,  8809,  1979,   525,
         28898, 22182,  1358,   297,   278,  3458,   304,  8812,    13,  1576,
          2243,   886,   322,   564,  5727,   310,   714,  6617,   681, 19717,
         29892,    13,  2816,   304,  2125, 10188,  2750,   263,  7205,   310,
         18835,    13,  2855,   491,  9209,   292,  1095,   963, 29889,  1763,
           762, 29892,   304,  8709, 29936,    13,  3782]])
Expected Logits Shape:  torch.Size([57, 1, 32064])
2025-Jan-10 01:04:05.0778 8296:9764 [0] nccl_net_ofi_rdma_init:7734 CCOM WARN NET/OFI OFI fi_getinfo() call failed: No data available
2025-Jan-10 01:04:05.0829 8296:9764 [0] nccl_net_ofi_create_plugin:251 CCOM WARN NET/OFI Unable to find a protocol that worked.  Failing initialization.
2025-Jan-10 01:04:05.0880 8296:9764 [0] nccl_net_ofi_create_plugin:316 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2025-Jan-10 01:04:05.0922 8296:9764 [0] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed
2025-Jan-10 01:04:05.0959 8296:9764 [0] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?
Actual Output:  [", that is the question:\nWhether 'tis nobler in the mind to suffer\nThe slings and arrows of outrageous fortune,\nOr to take arms against a sea of troubles\nAnd by opposing end them. To die, to sleep;\nNo"] tensor([[29892,   393,   338,   278,  1139, 29901,    13,  8809,  1979,   525,
         28898, 22182,  1358,   297,   278,  3458,   304,  8812,    13,  1576,
          2243,   886,   322,   564,  5727,   310,   714,  6617,   681, 19717,
         29892,    13,  2816,   304,  2125, 10188,  2750,   263,  7205,   310,
         18835,    13,  2855,   491,  9209,   292,  1095,   963, 29889,  1763,
           762, 29892,   304,  8709, 29936,    13,  3782]])
Actual Logits Shape:  torch.Size([57, 1, 32064])
Passed logits validation

@dacorvo (Collaborator) commented Feb 27, 2025

@jlonge4 I did a lot of tests comparing with the GPU versions:

  • the outputs of microsoft/phi4 are almost identical under greedy decoding (at least for the first 256 tokens), so the degradation of the gsm8k results might be due to differences in the way I tested (through TGI on neuron, and HF directly on GPU, so maybe I am missing a system prompt somewhere). I nevertheless suspect that the rope code is not entirely identical, which would explain why the gpu and neuron results eventually deviate,
  • the outputs of microsoft/Phi-3-mini-4k-instruct are not complete garbage: they just don't make sense. This is probably related to the fact that it uses standard multi-head attention with attn_heads == kv_heads (that's the only difference), which means I might have a bug in the modeling code for that configuration.
