Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for phi4 #764

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
version: '3.7'

services:
tgi-1:
image: neuronx-tgi:latest
ports:
- "8081:8081"
environment:
- PORT=8081
- MODEL_ID=${MODEL_ID}
- HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
- HF_NUM_CORES=8
- MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
- MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
- MAX_CONCURRENT_REQUESTS=512
- HF_TOKEN=${HF_TOKEN}
devices:
- "/dev/neuron0"
- "/dev/neuron1"
- "/dev/neuron2"
- "/dev/neuron3"

tgi-2:
image: neuronx-tgi:latest
ports:
- "8082:8082"
environment:
- PORT=8082
- MODEL_ID=${MODEL_ID}
- HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
- HF_NUM_CORES=8
- MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
- MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
- MAX_CONCURRENT_REQUESTS=512
- HF_TOKEN=${HF_TOKEN}
devices:
- "/dev/neuron4"
- "/dev/neuron5"
- "/dev/neuron6"
- "/dev/neuron7"

tgi-3:
image: neuronx-tgi:latest
ports:
- "8083:8083"
environment:
- PORT=8083
- MODEL_ID=${MODEL_ID}
- HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
- HF_NUM_CORES=8
- MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
- MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
- MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
- MAX_CONCURRENT_REQUESTS=512
- HF_TOKEN=${HF_TOKEN}
devices:
- "/dev/neuron8"
- "/dev/neuron9"
- "/dev/neuron10"
- "/dev/neuron11"

loadbalancer:
image: nginx:alpine
ports:
- "8080:80"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf:ro
depends_on:
- tgi-1
- tgi-2
- tgi-3
deploy:
placement:
constraints: [node.role == manager]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
### Nginx TGI Load Balancer
events {}
http {
upstream tgicluster {
server tgi-1:8081;
server tgi-2:8082;
server tgi-3:8083;
}
server {
listen 80;
location / {
proxy_pass http://tgicluster;
}
}
}
9 changes: 8 additions & 1 deletion optimum/exporters/neuron/model_configs/decoder_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neuron export configurations for decoder models."""
"""Neuron export configurations for models using transformers_neuronx."""

import importlib

Expand All @@ -25,6 +25,7 @@
from ....neuron.backends.hlo.decoder import NeuronHloDecoderModel
from ....neuron.models.granite.model import GraniteForSampling
from ....neuron.models.llama.model import LlamaHloModel
from ....neuron.models.phi4.model import Phi4ForSampling
from ....neuron.models.qwen2.model import Qwen2ForSampling
from ..base import NeuronExportConfig

Expand Down Expand Up @@ -166,3 +167,9 @@ class Qwen2NeuronConfig(NeuronDecoderExportConfig):
class GraniteNeuronConfig(NeuronDecoderExportConfig):
NEURONX_CLASS = GraniteForSampling
CONTINUOUS_BATCHING = True


@register_in_tasks_manager("phi3", "text-generation")
class Phi4NeuronConfig(NeuronDecoderExportConfig):
NEURONX_CLASS = Phi4ForSampling
CONTINUOUS_BATCHING = True
14 changes: 14 additions & 0 deletions optimum/neuron/models/phi4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
96 changes: 96 additions & 0 deletions optimum/neuron/models/phi4/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch
from transformers import PretrainedConfig

from ...backends.hlo.config import NeuronConfig
from ...backends.hlo.dtypes import to_torch_dtype
from ..llama.model import LlamaHloModel
from .modules import Phi4ForCausalLM


class Phi4ForSampling(LlamaHloModel):
"""The Phi4 model is essentially a LLama model with fused qkv and gate_up projections.

The implementation in this class is very similar to the one used for Llama in Tnx.
The only difference is that the fused qkv and gate/up linear projection are split when
loading weights (note that they might be fused again when transferring the weights to the
neuron device if the NeuronConfig specifies it).
"""

def __init__(
self,
config: PretrainedConfig,
neuron_config: NeuronConfig,
):
dtype = to_torch_dtype(neuron_config.amp)
super().__init__(config, neuron_config, cpu_model=Phi4ForCausalLM(config, dtype))

def load_weights(self):
# Materialize the embedding to CPU
self.cpu_model.model.embed_tokens.materialize()

for layer in self.cpu_model.model.layers:
layer.materialize()
attn = layer.self_attn
mlp = layer.mlp
new_layer = self.decoder_lm_head.new_layer()
new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None)
# Transpose and split fused qkv_proj into separate weights
fused_attn = attn.qkv_proj.weight.clone().detach().T
# Handle GQA
if self.config.num_kv_heads < self.config.num_attention_heads:
# Extract the larger query weights first
q_features = attn.num_heads * attn.head_dim
q_weight = fused_attn[:, :q_features]
# Then split the remaining into key and value weights
k_weight, v_weight = torch.chunk(fused_attn[:, q_features:], 2, dim=1)
# Handle MHA
else:
q_weight, k_weight, v_weight = torch.chunk(fused_attn, 3, dim=1)
new_layer.add_attention_query(q_weight, None)
new_layer.add_attention_key(k_weight, None)
new_layer.add_attention_value(v_weight, None)
if self.neuron_config and self.neuron_config.attn_output_transposed:
new_layer.add_attention_output(attn.o_proj.weight.T.detach(), None, sharding=0, transposed=True)
else:
new_layer.add_attention_output(attn.o_proj.weight.detach(), None, sharding=1, transposed=False)

new_layer.add_pre_mlp_layer_norm(layer.post_attention_layernorm.weight.detach(), None)
# Tanspose and split fused mlp into separate weights
fused_gate_up = mlp.gate_up_proj.weight.clone().detach().T
gate, up = torch.chunk(fused_gate_up, 2, dim=1)
new_layer.add_parameter(gate, sharding=1, allow_transform=True)
new_layer.add_parameter(up, sharding=1, allow_transform=True)
new_layer.add_parameter(mlp.down_proj.weight, sharding=1)
new_layer.to_neuron()
layer.nullify()

ln_f = self.cpu_model.model.norm
ln_f.materialize()
self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), None)
ln_f.nullify()

lm_head = self.cpu_model.lm_head
lm_head.materialize()
self.decoder_lm_head.add_lm_head(lm_head.weight.detach().T)
lm_head.nullify()

self.decoder_lm_head.to_neuron()
self.decoder_lm_head.use_executor = True

self.decoder_lm_head_for_context.load_shared_weights(self.decoder_lm_head)
self.decoder_lm_head_for_context.use_executor = True
73 changes: 73 additions & 0 deletions optimum/neuron/models/phi4/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from transformers.models.phi import PhiConfig

from ...backends.hlo import module


class Phi4ForCausalLM(module.PretrainedModel):
def __init__(self, config: PhiConfig, dtype):
super().__init__()
self.model = Phi4Model(config, dtype)
self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False)

def get_tied_parameters(self):
return [(self.model.embed_tokens.weight, self.lm_head.weight)]

def get_base_model(self):
return self.model


class Phi4Model(module.LowMemoryModule):
def __init__(self, config: PhiConfig, dtype):
super().__init__()
self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size)
self.layers = module.LowMemoryModuleList(
[Phi4DecoderLayer(config, dtype) for _ in range(config.num_hidden_layers)]
)
self.norm = Phi4RMSNorm()


class Phi4RMSNorm(module.LowMemoryModule):
def __init__(self) -> None:
super().__init__()
self.weight = module.UninitializedParameter()


class Phi4DecoderLayer(module.LowMemoryModule):
def __init__(self, config: PhiConfig, dtype):
super().__init__()
self.self_attn = Phi4Attention(config, dtype)
self.mlp = Phi4MLP(config, dtype)
self.input_layernorm = Phi4RMSNorm()
self.post_attention_layernorm = Phi4RMSNorm()


class Phi4Attention(module.LowMemoryModule):
def __init__(self, config: PhiConfig, dtype):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
self.qkv_proj = module.LowMemoryLazyLinear(op_size, bias=False, dtype=dtype)
self.o_proj = module.LowMemoryLazyLinear(self.hidden_size, bias=False, dtype=dtype)


class Phi4MLP(module.LowMemoryModule):
def __init__(self, config, dtype):
super().__init__()
self.gate_up_proj = module.LowMemoryLazyLinear(2 * config.intermediate_size, bias=False, dtype=dtype)
self.down_proj = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype)
4 changes: 4 additions & 0 deletions tests/decoder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@
"model_id": "dacorvo/Mixtral-tiny",
"export_kwargs": {"batch_size": 4, "sequence_length": 1024, "num_cores": 2, "auto_cast_type": "fp16"},
},
"phi4": {
"model_id": "microsoft/phi-4",
"export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
},
}


Expand Down
3 changes: 2 additions & 1 deletion tests/decoder/test_decoder_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx


DECODER_MODEL_ARCHITECTURES = ["bloom", "gpt2", "llama", "mistral", "mixtral", "opt"]
DECODER_MODEL_ARCHITECTURES = ["bloom", "gpt2", "llama", "mistral", "mixtral", "opt", "phi3"]
DECODER_MODEL_NAMES = {
"bloom": "hf-internal-testing/tiny-random-BloomForCausalLM",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
Expand All @@ -32,6 +32,7 @@
"opt": "hf-internal-testing/tiny-random-OPTForCausalLM",
"qwen2": "yujiepan/qwen2.5-128k-tiny-random",
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
"phi3": "yujiepan/phi-4-tiny-random",
}


Expand Down