diff --git a/benchmark/text-generation-inference/performance/phi4-14b/docker-compose.yaml b/benchmark/text-generation-inference/performance/phi4-14b/docker-compose.yaml
new file mode 100644
index 000000000..960e93ea8
--- /dev/null
+++ b/benchmark/text-generation-inference/performance/phi4-14b/docker-compose.yaml
@@ -0,0 +1,76 @@
+version: '3.7'
+
+services:
+  tgi-1:
+    image: neuronx-tgi:latest
+    ports:
+      - "8081:8081"
+    environment:
+      - PORT=8081
+      - MODEL_ID=${MODEL_ID}
+      - HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
+      - HF_NUM_CORES=8
+      - MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
+      - MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
+      - MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
+      - MAX_CONCURRENT_REQUESTS=512
+      - HF_TOKEN=${HF_TOKEN}
+    devices:
+      - "/dev/neuron0"
+      - "/dev/neuron1"
+      - "/dev/neuron2"
+      - "/dev/neuron3"
+
+  tgi-2:
+    image: neuronx-tgi:latest
+    ports:
+      - "8082:8082"
+    environment:
+      - PORT=8082
+      - MODEL_ID=${MODEL_ID}
+      - HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
+      - HF_NUM_CORES=8
+      - MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
+      - MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
+      - MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
+      - MAX_CONCURRENT_REQUESTS=512
+      - HF_TOKEN=${HF_TOKEN}
+    devices:
+      - "/dev/neuron4"
+      - "/dev/neuron5"
+      - "/dev/neuron6"
+      - "/dev/neuron7"
+
+  tgi-3:
+    image: neuronx-tgi:latest
+    ports:
+      - "8083:8083"
+    environment:
+      - PORT=8083
+      - MODEL_ID=${MODEL_ID}
+      - HF_AUTO_CAST_TYPE=${HF_AUTO_CAST_TYPE}
+      - HF_NUM_CORES=8
+      - MAX_BATCH_SIZE=${MAX_BATCH_SIZE}
+      - MAX_INPUT_TOKENS=${MAX_INPUT_TOKENS}
+      - MAX_TOTAL_TOKENS=${MAX_TOTAL_TOKENS}
+      - MAX_CONCURRENT_REQUESTS=512
+      - HF_TOKEN=${HF_TOKEN}
+    devices:
+      - "/dev/neuron8"
+      - "/dev/neuron9"
+      - "/dev/neuron10"
+      - "/dev/neuron11"
+
+  loadbalancer:
+    image: nginx:alpine
+    ports:
+      - "8080:80"
+    volumes:
+      - ./nginx.conf:/etc/nginx/nginx.conf:ro
+    depends_on:
+      - tgi-1
+      - tgi-2
+      - tgi-3
+    deploy:
+      placement:
+        constraints: [node.role == manager]
diff --git a/benchmark/text-generation-inference/performance/phi4-14b/nginx.conf b/benchmark/text-generation-inference/performance/phi4-14b/nginx.conf
new file mode 100644
index 000000000..37a3b8721
--- /dev/null
+++ b/benchmark/text-generation-inference/performance/phi4-14b/nginx.conf
@@ -0,0 +1,15 @@
+### Nginx TGI Load Balancer
+events {}
+http {
+    upstream tgicluster {
+        server tgi-1:8081;
+        server tgi-2:8082;
+        server tgi-3:8083;
+    }
+    server {
+        listen 80;
+        location / {
+            proxy_pass http://tgicluster;
+        }
+    }
+}
diff --git a/optimum/exporters/neuron/model_configs/decoder_configs.py b/optimum/exporters/neuron/model_configs/decoder_configs.py
index f6e4e3661..2e461b5c1 100644
--- a/optimum/exporters/neuron/model_configs/decoder_configs.py
+++ b/optimum/exporters/neuron/model_configs/decoder_configs.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Neuron export configurations for decoder models."""
+"""Neuron export configurations for models using transformers_neuronx."""
 
 import importlib
 
@@ -25,6 +25,7 @@
 from ....neuron.backends.hlo.decoder import NeuronHloDecoderModel
 from ....neuron.models.granite.model import GraniteForSampling
 from ....neuron.models.llama.model import LlamaHloModel
+from ....neuron.models.phi4.model import Phi4ForSampling
 from ....neuron.models.qwen2.model import Qwen2ForSampling
 
 from ..base import NeuronExportConfig
@@ -166,3 +167,9 @@ class Qwen2NeuronConfig(NeuronDecoderExportConfig):
 class GraniteNeuronConfig(NeuronDecoderExportConfig):
     NEURONX_CLASS = GraniteForSampling
     CONTINUOUS_BATCHING = True
+
+
+@register_in_tasks_manager("phi3", "text-generation")
+class Phi4NeuronConfig(NeuronDecoderExportConfig):
+    NEURONX_CLASS = Phi4ForSampling
+    CONTINUOUS_BATCHING = True
diff --git a/optimum/neuron/models/phi4/__init__.py b/optimum/neuron/models/phi4/__init__.py
new file mode 100644
index 000000000..fdc025786
--- /dev/null
+++ b/optimum/neuron/models/phi4/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/optimum/neuron/models/phi4/model.py b/optimum/neuron/models/phi4/model.py
new file mode 100644
index 000000000..dc56618dc
--- /dev/null
+++ b/optimum/neuron/models/phi4/model.py
@@ -0,0 +1,96 @@
+# Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+from transformers import PretrainedConfig
+
+from ...backends.hlo.config import NeuronConfig
+from ...backends.hlo.dtypes import to_torch_dtype
+from ..llama.model import LlamaHloModel
+from .modules import Phi4ForCausalLM
+
+
+class Phi4ForSampling(LlamaHloModel):
+    """The Phi4 model is essentially a Llama model with fused qkv and gate_up projections.
+
+    The implementation in this class is very similar to the one used for Llama in transformers_neuronx.
+    The only difference is that the fused qkv and gate_up linear projections are split when
+    loading weights (note that they might be fused again when transferring the weights to the
+    neuron device if the NeuronConfig specifies it).
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        neuron_config: NeuronConfig,
+    ):
+        dtype = to_torch_dtype(neuron_config.amp)
+        super().__init__(config, neuron_config, cpu_model=Phi4ForCausalLM(config, dtype))
+
+    def load_weights(self):
+        # Materialize the embedding to CPU
+        self.cpu_model.model.embed_tokens.materialize()
+
+        for layer in self.cpu_model.model.layers:
+            layer.materialize()
+            attn = layer.self_attn
+            mlp = layer.mlp
+            new_layer = self.decoder_lm_head.new_layer()
+            new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None)
+            # Transpose and split fused qkv_proj into separate weights
+            fused_attn = attn.qkv_proj.weight.clone().detach().T
+            # Handle GQA
+            if self.config.num_kv_heads < self.config.num_attention_heads:
+                # Extract the larger query weights first
+                q_features = attn.num_heads * attn.head_dim
+                q_weight = fused_attn[:, :q_features]
+                # Then split the remaining into key and value weights
+                k_weight, v_weight = torch.chunk(fused_attn[:, q_features:], 2, dim=1)
+            # Handle MHA
+            else:
+                q_weight, k_weight, v_weight = torch.chunk(fused_attn, 3, dim=1)
+            new_layer.add_attention_query(q_weight, None)
+            new_layer.add_attention_key(k_weight, None)
+            new_layer.add_attention_value(v_weight, None)
+            if self.neuron_config and self.neuron_config.attn_output_transposed:
+                new_layer.add_attention_output(attn.o_proj.weight.T.detach(), None, sharding=0, transposed=True)
+            else:
+                new_layer.add_attention_output(attn.o_proj.weight.detach(), None, sharding=1, transposed=False)
+
+            new_layer.add_pre_mlp_layer_norm(layer.post_attention_layernorm.weight.detach(), None)
+            # Transpose and split fused gate_up_proj into separate weights
+            fused_gate_up = mlp.gate_up_proj.weight.clone().detach().T
+            gate, up = torch.chunk(fused_gate_up, 2, dim=1)
+            new_layer.add_parameter(gate, sharding=1, allow_transform=True)
+            new_layer.add_parameter(up, sharding=1, allow_transform=True)
+            new_layer.add_parameter(mlp.down_proj.weight, sharding=1)
+            new_layer.to_neuron()
+            layer.nullify()
+
+        ln_f = self.cpu_model.model.norm
+        ln_f.materialize()
+        self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), None)
+        ln_f.nullify()
+
+        lm_head = self.cpu_model.lm_head
+        lm_head.materialize()
+        self.decoder_lm_head.add_lm_head(lm_head.weight.detach().T)
+        lm_head.nullify()
+
+        self.decoder_lm_head.to_neuron()
+        self.decoder_lm_head.use_executor = True
+
+        self.decoder_lm_head_for_context.load_shared_weights(self.decoder_lm_head)
+        self.decoder_lm_head_for_context.use_executor = True
diff --git a/optimum/neuron/models/phi4/modules.py b/optimum/neuron/models/phi4/modules.py
new file mode 100644
index 000000000..78d90c5ed
--- /dev/null
+++ b/optimum/neuron/models/phi4/modules.py
@@ -0,0 +1,73 @@
+# Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from transformers.models.phi import PhiConfig
+
+from ...backends.hlo import module
+
+
+class Phi4ForCausalLM(module.PretrainedModel):
+    def __init__(self, config: PhiConfig, dtype):
+        super().__init__()
+        self.model = Phi4Model(config, dtype)
+        self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False)
+
+    def get_tied_parameters(self):
+        return [(self.model.embed_tokens.weight, self.lm_head.weight)]
+
+    def get_base_model(self):
+        return self.model
+
+
+class Phi4Model(module.LowMemoryModule):
+    def __init__(self, config: PhiConfig, dtype):
+        super().__init__()
+        self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size)
+        self.layers = module.LowMemoryModuleList(
+            [Phi4DecoderLayer(config, dtype) for _ in range(config.num_hidden_layers)]
+        )
+        self.norm = Phi4RMSNorm()
+
+
+class Phi4RMSNorm(module.LowMemoryModule):
+    def __init__(self) -> None:
+        super().__init__()
+        self.weight = module.UninitializedParameter()
+
+
+class Phi4DecoderLayer(module.LowMemoryModule):
+    def __init__(self, config: PhiConfig, dtype):
+        super().__init__()
+        self.self_attn = Phi4Attention(config, dtype)
+        self.mlp = Phi4MLP(config, dtype)
+        self.input_layernorm = Phi4RMSNorm()
+        self.post_attention_layernorm = Phi4RMSNorm()
+
+
+class Phi4Attention(module.LowMemoryModule):
+    def __init__(self, config: PhiConfig, dtype):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
+        self.qkv_proj = module.LowMemoryLazyLinear(op_size, bias=False, dtype=dtype)
+        self.o_proj = module.LowMemoryLazyLinear(self.hidden_size, bias=False, dtype=dtype)
+
+
+class Phi4MLP(module.LowMemoryModule):
+    def __init__(self, config, dtype):
+        super().__init__()
+        self.gate_up_proj = module.LowMemoryLazyLinear(2 * config.intermediate_size, bias=False, dtype=dtype)
+        self.down_proj = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype)
diff --git a/tests/decoder/conftest.py b/tests/decoder/conftest.py
index 861439cb9..047de247d 100644
--- a/tests/decoder/conftest.py
+++ b/tests/decoder/conftest.py
@@ -48,6 +48,10 @@
         "model_id": "dacorvo/Mixtral-tiny",
         "export_kwargs": {"batch_size": 4, "sequence_length": 1024, "num_cores": 2, "auto_cast_type": "fp16"},
     },
+    "phi4": {
+        "model_id": "microsoft/phi-4",
+        "export_kwargs": {"batch_size": 4, "sequence_length": 4096, "num_cores": 2, "auto_cast_type": "bf16"},
+    },
 }
diff --git a/tests/decoder/test_decoder_export.py b/tests/decoder/test_decoder_export.py
index 61aa57481..92289c895 100644
--- a/tests/decoder/test_decoder_export.py
+++ b/tests/decoder/test_decoder_export.py
@@ -22,7 +22,7 @@
 from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
 
 
-DECODER_MODEL_ARCHITECTURES = ["bloom", "gpt2", "llama", "mistral", "mixtral", "opt"]
+DECODER_MODEL_ARCHITECTURES = ["bloom", "gpt2", "llama", "mistral", "mixtral", "opt", "phi3"]
 DECODER_MODEL_NAMES = {
     "bloom": "hf-internal-testing/tiny-random-BloomForCausalLM",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
@@ -32,6 +32,7 @@
     "opt": "hf-internal-testing/tiny-random-OPTForCausalLM",
     "qwen2": "yujiepan/qwen2.5-128k-tiny-random",
     "granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
+    "phi3": "yujiepan/phi-4-tiny-random",
 }
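
For context, a minimal sketch of how the exported Phi-4 model might be exercised end to end. It mirrors the `export_kwargs` added to `tests/decoder/conftest.py`; the prompt and generation settings are illustrative assumptions, not values taken from this change.

```python
from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForCausalLM

# Compile microsoft/phi-4 for Neuron, mirroring the export_kwargs used in the tests.
# Batch size, sequence length, core count and dtype are illustrative, not prescriptive.
model = NeuronModelForCausalLM.from_pretrained(
    "microsoft/phi-4",
    export=True,
    batch_size=4,
    sequence_length=4096,
    num_cores=2,
    auto_cast_type="bf16",
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-4")

# Run a short generation on the compiled model.
inputs = tokenizer("What is the AWS Neuron SDK?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```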
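The benchmark compose file fronts the three `neuronx-tgi` replicas with nginx on port 8080, so a load-generation client only needs a single endpoint. A minimal client sketch follows, assuming the stack is running locally and that TGI's standard `/generate` route is passed through the proxy unchanged.

```python
import requests

# Send one request to the nginx load balancer, which distributes traffic
# across the three TGI containers defined in docker-compose.yaml.
response = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 64},
    },
    timeout=120,
)
response.raise_for_status()
print(response.json()["generated_text"])
```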