handle MHA and GQA
jlonge4 committed Mar 2, 2025
1 parent d1e1954 · commit c794cc9
Showing 1 changed file with 10 additions and 5 deletions.
optimum/neuron/models/phi4/model.py: 10 additions & 5 deletions
@@ -51,11 +51,16 @@ def load_weights(self):
             new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None)
             # Transpose and split fused qkv_proj into separate weights
             fused_attn = attn.qkv_proj.weight.clone().detach().T
-            # Extract the larger query weights first
-            q_features = attn.num_heads * attn.head_dim
-            q_weight = fused_attn[:, :q_features]
-            # Then split the remaining into key and value weights
-            k_weight, v_weight = torch.chunk(fused_attn[:, q_features:], 2, dim=1)
+            # Handle GQA
+            if self.config.num_kv_heads < self.config.num_attention_heads:
+                # Extract the larger query weights first
+                q_features = attn.num_heads * attn.head_dim
+                q_weight = fused_attn[:, :q_features]
+                # Then split the remaining into key and value weights
+                k_weight, v_weight = torch.chunk(fused_attn[:, q_features:], 2, dim=1)
+            # Handle MHA
+            else:
+                q_weight, k_weight, v_weight = torch.chunk(fused_attn, 3, dim=1)
             new_layer.add_attention_query(q_weight, None)
             new_layer.add_attention_key(k_weight, None)
             new_layer.add_attention_value(v_weight, None)
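For context on what the branch is handling: with grouped-query attention (GQA) the fused qkv_proj packs a wide query block followed by two narrower, equal-width key and value blocks, whereas with multi-head attention (MHA) all three blocks have the same width. Below is a minimal standalone sketch of the split; the dimensions are illustrative placeholders, not taken from the actual Phi-4 config.

import torch

# Illustrative dimensions only; in the real code these come from the model config.
hidden_size = 5120
num_attention_heads = 40
num_kv_heads = 10                              # GQA: fewer key/value heads than query heads
head_dim = hidden_size // num_attention_heads  # 128

# Fused projection weight, already transposed as in the diff:
# columns are [q | k | v] = [40*128 | 10*128 | 10*128] features.
qkv_features = (num_attention_heads + 2 * num_kv_heads) * head_dim
fused_attn = torch.randn(hidden_size, qkv_features)

if num_kv_heads < num_attention_heads:
    # GQA: slice off the wider query block first ...
    q_features = num_attention_heads * head_dim
    q_weight = fused_attn[:, :q_features]
    # ... then halve the remainder, since the k and v blocks are equal width.
    k_weight, v_weight = torch.chunk(fused_attn[:, q_features:], 2, dim=1)
else:
    # MHA: q, k, and v blocks are all the same width.
    q_weight, k_weight, v_weight = torch.chunk(fused_attn, 3, dim=1)

print(q_weight.shape, k_weight.shape, v_weight.shape)
# GQA case: torch.Size([5120, 5120]) torch.Size([5120, 1280]) torch.Size([5120, 1280])

Note that the GQA arithmetic would also produce a correct split when num_kv_heads == num_attention_heads (the remainder still halves into equal k and v blocks), so the explicit else branch appears to be a readability choice: it keeps the common MHA layout on a single three-way torch.chunk.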
