Update Llama tutorials for vllm and general trtllm containers (#74)

- Added vLLM tutorials under the Llama2 folder
- Added a general README under the Llama2 folder for preprocessing
- Updated the HuggingFaceTransformers README to include Llama instructions
- Added a Llama2 model example for vLLM and HuggingFace
1 parent ce3c9c0, commit d7521fe. Showing 8 changed files with 372 additions and 4 deletions.
@@ -0,0 +1,40 @@
<!--
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

# Deploying Hugging Face Transformer Models in Triton

There are multiple ways to run Llama2 with Tritonserver.
1. Infer with [TensorRT-LLM Backend](trtllm_guide.md#infer-with-tensorrt-llm-backend)
2. Infer with [vLLM Backend](vllm_guide.md#infer-with-vllm-backend)
3. Infer with [Python-based Backends as a HuggingFace model](../Quick_Deploy/HuggingFaceTransformers/README.md#deploying-hugging-face-transformer-models-in-triton)

## Pre-build instructions

The tutorials assume that the Llama2 model, weights, and tokenizer files are cloned from the Hugging Face Llama2 repo [here](https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main).
To run the tutorials, you will need to be granted access to the Llama2 repository and have the Hugging Face CLI available.
The CLI uses [user access tokens](https://huggingface.co/docs/hub/security-tokens), which can be created at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
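
For example, once your Hugging Face account has been granted access to the repository, the weights can be cloned locally with Git LFS. This is only a minimal sketch using the repository URL above; it is not a step mandated by the tutorials themselves:

```bash
# Git LFS is required because the Llama2 weight files are stored with LFS.
git lfs install
# You will be prompted for your Hugging Face username and a user access token.
git clone https://huggingface.co/meta-llama/Llama-2-7b-hf
```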
@@ -0,0 +1,7 @@
{
    "model":"meta-llama/Llama-2-7b-hf",
    "trust_remote_code":true,
    "download_dir":"/opt/tritonserver/model_repository/llama2vllm/hf-cache",
    "disable_log_requests": "true",
    "gpu_memory_utilization": 0.5
}
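
The JSON above is the vLLM backend's `model.json`, and its fields are forwarded to the vLLM engine as engine arguments. As a hedged illustration only, the same file could also pin the compute dtype or cap the context length; `dtype` and `max_model_len` are standard vLLM engine arguments and are not values taken from this commit:

```json
{
    "model": "meta-llama/Llama-2-7b-hf",
    "trust_remote_code": true,
    "download_dir": "/opt/tritonserver/model_repository/llama2vllm/hf-cache",
    "disable_log_requests": "true",
    "gpu_memory_utilization": 0.5,
    "dtype": "float16",
    "max_model_len": 4096
}
```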
@@ -0,0 +1,37 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Note: You do not need to change any fields in this configuration.

backend: "vllm"

# The usage of device is deferred to the vLLM engine
instance_group [
  {
    count: 1
    kind: KIND_MODEL
  }
]
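
Taken together, the `model.json` and the `config.pbtxt` above make up the `llama2vllm` model directory that the vLLM guide later mounts into the Triton container. The layout below is a sketch based on the paths used in this commit; the version subdirectory `1/` follows the vLLM backend's usual convention and is an assumption rather than something shown in the diff:

```
llama2vllm/              # mounted at /opt/tritonserver/model_repository/llama2vllm
├── 1/
│   └── model.json       # vLLM engine arguments (the JSON file above)
└── config.pbtxt         # Triton model configuration (shown above)
```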
@@ -0,0 +1,117 @@
<!--
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

The vLLM Backend uses vLLM to do inference. Read more about vLLM [here](https://blog.vllm.ai/2023/06/20/vllm.html) and the vLLM Backend [here](https://github.com/triton-inference-server/vllm_backend).

## Pre-build instructions

For this tutorial, we are using the Llama2-7B HuggingFace model with pre-trained weights. Please follow the [README.md](README.md) for pre-build instructions and for links to running Llama2 with other backends.

## Installation

The Triton vLLM container can be pulled from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver) and started with:

```bash
docker run --rm -it --net host --shm-size=2g \
    --ulimit memlock=-1 --ulimit stack=67108864 --gpus all \
    -v $PWD/llama2vllm:/opt/tritonserver/model_repository/llama2vllm \
    nvcr.io/nvidia/tritonserver:23.11-vllm-python-py3
```
This will create a `/opt/tritonserver/model_repository` folder inside the container that contains the `llama2vllm` model. The model weights themselves will be pulled from Hugging Face when Triton loads the model.

Once in the container, install the Hugging Face CLI (it ships with the `huggingface_hub` package) and log in with your own credentials.
```bash
pip install --upgrade huggingface_hub
huggingface-cli login --token <your huggingface access token>
```
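
Optionally, you can confirm that the token was accepted before starting the server; `whoami` is a standard `huggingface-cli` subcommand rather than something specific to this tutorial:

```bash
# Prints the account associated with the access token supplied above.
huggingface-cli whoami
```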

## Serving with Triton

You can then start Triton as usual:
```bash
tritonserver --model-repository model_repository
```
The server has launched successfully when you see the following output in your console:

```
I0922 23:28:40.351809 1 grpc_server.cc:2451] Started GRPCInferenceService at 0.0.0.0:8001
I0922 23:28:40.352017 1 http_server.cc:3558] Started HTTPService at 0.0.0.0:8000
I0922 23:28:40.395611 1 http_server.cc:187] Started Metrics Service at 0.0.0.0:8002
```
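
If you prefer to script the startup check instead of watching the log, Triton's standard HTTP health endpoints can be polled as well; the endpoints below are part of Triton's KServe-style HTTP/REST API rather than anything added by this tutorial:

```bash
# Returns HTTP 200 once the server is ready to accept requests.
curl -v localhost:8000/v2/health/ready
# Returns HTTP 200 once the llama2vllm model has finished loading.
curl -v localhost:8000/v2/models/llama2vllm/ready
```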

## Sending requests via the `generate` endpoint

As a simple example to make sure the server works, you can send a request to the `generate` endpoint. More about the generate endpoint [here](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md).

```bash
$ curl -X POST localhost:8000/v2/models/llama2vllm/generate -d '{"text_input": "What is Triton Inference Server?", "parameters": {"stream": false, "temperature": 0}}'
# returns (formatted for better visualization)
> {
    "model_name":"llama2vllm",
    "model_version":"1",
    "text_output":"What is Triton Inference Server?\nTriton Inference Server is a lightweight, high-performance"
  }
```
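
The same extension also defines a `generate_stream` endpoint that returns partial results as server-sent events. The request below is a hedged sketch: it assumes the deployment supports streaming (the vLLM backend serves models in decoupled mode for this purpose), and the exact chunking of the streamed output depends on the model:

```bash
# Responses arrive as a series of "data: {...}" lines instead of a single JSON body.
curl -X POST localhost:8000/v2/models/llama2vllm/generate_stream \
    -d '{"text_input": "What is Triton Inference Server?", "parameters": {"stream": true, "temperature": 0}}'
```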

## Sending requests via the Triton client

The Triton vLLM Backend repository has a [samples folder](https://github.com/triton-inference-server/vllm_backend/tree/main/samples) with an example client.py that can be used to test the Llama2 model.

```bash
pip3 install "tritonclient[all]"
# Assuming the Triton server is already running
git clone https://github.com/triton-inference-server/vllm_backend.git
cd vllm_backend/samples
python3 client.py -m llama2vllm
```
Following these steps should result in a `results.txt` with content similar to the following:
```bash
Hello, my name is
I am a 20 year old student from the Netherlands. I am currently

=========

The most dangerous animal is
The most dangerous animal is the one that is not there.
The most dangerous

=========

The capital of France is
The capital of France is Paris.
The capital of France is Paris. The

=========

The future of AI is
The future of AI is in the hands of the people who use it.

=========
```
Quick_Deploy/HuggingFaceTransformers/llama7b/1/model.py (114 additions, 0 deletions)
@@ -0,0 +1,114 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os

os.environ["TRANSFORMERS_CACHE"] = "/opt/tritonserver/model_repository/llama7b/hf-cache"

import json

import numpy as np
import torch
import transformers
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])
        self.model_params = self.model_config.get("parameters", {})
        default_hf_model = "meta-llama/Llama-2-7b-hf"
        private_repo_token = ""
        default_max_gen_length = "15"
        # Check for user-specified model name in model config parameters
        hf_model = self.model_params.get("huggingface_model", {}).get(
            "string_value", default_hf_model
        )
if "PRIVATE_REPO_TOKEN" not in os.environ: | ||
print( | ||
"envvar PRIVATE_REPO_TOKEN should be specified if running a restricted model like Llama2" | ||
) | ||
private_repo_token = os.environ["PRIVATE_REPO_TOKEN"] | ||
|
||
        # Check for user-specified max length in model config parameters
        self.max_output_length = int(
            self.model_params.get("max_output_length", {}).get(
                "string_value", default_max_gen_length
            )
        )

        self.logger.log_info(f"Max output length: {self.max_output_length}")
        self.logger.log_info(f"Loading HuggingFace model: {hf_model}...")
        # Assume tokenizer available for same model
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            hf_model, token=private_repo_token
        )

        self.pipeline = transformers.pipeline(
            "text-generation",
            model=hf_model,
            torch_dtype=torch.float16,
            tokenizer=self.tokenizer,
            device_map="auto",
            token=private_repo_token,
        )

    def execute(self, requests):
        responses = []
        for request in requests:
            # Read the request prompt from the "text_input" input tensor
            input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
            prompt = input_tensor.as_numpy()[0].decode("utf-8")

            response = self.generate(prompt)
            responses.append(response)

        return responses

    def generate(self, prompt):
        sequences = self.pipeline(
            prompt,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            max_length=self.max_output_length,
        )

        output_tensors = []
        texts = []
        for i, seq in enumerate(sequences):
            text = seq["generated_text"]
            self.logger.log_info(f"Sequence {i+1}: {text}")
            texts.append(text)

        tensor = pb_utils.Tensor("text_output", np.array(texts, dtype=np.object_))
        output_tensors.append(tensor)
        response = pb_utils.InferenceResponse(output_tensors=output_tensors)
        return response

    def finalize(self):
        print("Cleaning up...")