Update Llama tutorials for vllm and general trtllm containers (#74)
Added vllm tutorials under Llama2 folder
Added general README under Llama2 folder for preprocessing
Updated HuggingFaceTransformers README to include Llama instructions
Added Llama2 model example for vLLM and HuggingFace
jbkyang-nvi authored Dec 14, 2023
1 parent ce3c9c0 commit d7521fe
Showing 8 changed files with 372 additions and 4 deletions.
40 changes: 40 additions & 0 deletions Popular_Models_Guide/Llama2/README.md
@@ -0,0 +1,40 @@
<!--
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

# Deploying Llama2 with Triton Inference Server

There are multiple ways to run Llama2 with the Triton Inference Server:
1. Infer with [TensorRT-LLM Backend](trtllm_guide.md#infer-with-tensorrt-llm-backend)
2. Infer with [vLLM Backend](vllm_guide.md#infer-with-vllm-backend)
3. Infer with [Python-based Backends as a HuggingFace model](../Quick_Deploy/HuggingFaceTransformers/README.md#deploying-hugging-face-transformer-models-in-triton)

## Pre-build instructions

These tutorials assume that the Llama2 model, weights, and tokenizer are cloned from the Hugging Face Llama2 repo [here](https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main).
To run the tutorials, you will need access permissions for the Llama2 repository as well as the Hugging Face CLI.
The CLI authenticates with [User access tokens](https://huggingface.co/docs/hub/security-tokens), which can be created at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
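
A minimal sketch of those steps (the token is a placeholder, and cloning the weights requires `git-lfs` and repository access):

```bash
pip install --upgrade huggingface_hub
huggingface-cli login --token <your_huggingface_access_token>
# Clone the model, weights, and tokenizer
git clone https://huggingface.co/meta-llama/Llama-2-7b-hf
```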
7 changes: 7 additions & 0 deletions Popular_Models_Guide/Llama2/llama2vllm/1/model.json
@@ -0,0 +1,7 @@
{
"model":"meta-llama/Llama-2-7b-hf",
"trust_remote_code":true,
"download_dir":"/opt/tritonserver/model_repository/llama2vllm/hf-cache",
"disable_log_requests": "true",
"gpu_memory_utilization": 0.5
}
37 changes: 37 additions & 0 deletions Popular_Models_Guide/Llama2/llama2vllm/config.pbtxt
@@ -0,0 +1,37 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Note: You do not need to change any fields in this configuration.

backend: "vllm"

# Device selection is deferred to the vLLM engine
instance_group [
{
count: 1
kind: KIND_MODEL
}
]
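
Together, the two files above make up the complete `llama2vllm` model folder that the Docker command in the vLLM guide mounts into the container:

```
llama2vllm/
├── 1/
│   └── model.json
└── config.pbtxt
```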
11 changes: 10 additions & 1 deletion Popular_Models_Guide/Llama2/trtllm_guide.md
@@ -26,6 +26,16 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

TensorRT-LLM is NVIDIA's recommended solution for running Large Language
Models (LLMs) on NVIDIA GPUs. Read more about TensorRT-LLM [here](https://github.com/NVIDIA/TensorRT-LLM)
and Triton's TensorRT-LLM Backend [here](https://github.com/triton-inference-server/tensorrtllm_backend).

*NOTE:* If some parts of this tutorial do not work, it is possible that there
is a version mismatch between the `tutorials` and `tensorrtllm_backend` repositories.
Refer to [llama.md](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/llama.md)
for more detailed modifications if necessary.


## Pre-build instructions

For this tutorial, we are using the Llama2-7B HuggingFace model with pre-trained weights.
@@ -158,4 +168,3 @@ python3 /tensorrtllm_backend/inflight_batcher_llm/client/inflight_batcher_llm_cl
2. The [generate endpoint](https://github.com/triton-inference-server/tensorrtllm_backend/tree/release/0.5.0#query-the-server-with-the-triton-generate-endpoint) if you are using the Triton TensorRT-LLM Backend container with versions greater than `r23.10`.
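
For instance, a hedged `generate` request against the default TensorRT-LLM `ensemble` model (the model name and request fields follow the `tensorrtllm_backend` docs linked above; adjust them to your deployment):

```bash
curl -X POST localhost:8000/v2/models/ensemble/generate -d \
    '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "", "stop_words": ""}'
```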



117 changes: 117 additions & 0 deletions Popular_Models_Guide/Llama2/vllm_guide.md
@@ -0,0 +1,117 @@
<!--
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

The vLLM Backend uses vLLM to perform inference. Read more about vLLM [here](https://blog.vllm.ai/2023/06/20/vllm.html) and the vLLM Backend [here](https://github.com/triton-inference-server/vllm_backend).

## Pre-build instructions

For this tutorial, we are using the Llama2-7B HuggingFace model with pre-trained weights. Please follow the [README.md](README.md) for pre-build instructions and links for how to run Llama with other backends.

## Installation

The Triton vLLM container can be pulled from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver) and launched with:

```bash
docker run --rm -it --net host --shm-size=2g \
--ulimit memlock=-1 --ulimit stack=67108864 --gpus all \
-v $PWD/llama2vllm:/opt/tritonserver/model_repository/llama2vllm \
nvcr.io/nvidia/tritonserver:23.11-vllm-python-py3
```
This mounts the local `llama2vllm` folder into the container at `/opt/tritonserver/model_repository/llama2vllm`. The model weights themselves will be pulled from Hugging Face into the `hf-cache` download directory when the model is first loaded.

Once inside the container, install the Hugging Face CLI and log in with your own credentials:
```bash
pip install --upgrade huggingface_hub
huggingface-cli login --token <your huggingface access token>
```
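
Optionally, verify that the login succeeded:

```bash
huggingface-cli whoami
```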


## Serving with Triton

You can then launch the Triton server as usual:
```bash
tritonserver --model-repository model_repository
```
The server has launched successfully when you see the following output in your console:

```
I0922 23:28:40.351809 1 grpc_server.cc:2451] Started GRPCInferenceService at 0.0.0.0:8001
I0922 23:28:40.352017 1 http_server.cc:3558] Started HTTPService at 0.0.0.0:8000
I0922 23:28:40.395611 1 http_server.cc:187] Started Metrics Service at 0.0.0.0:8002
```
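
Before sending inference requests, you can optionally confirm readiness with Triton's standard HTTP health endpoint (part of the core Triton API, not specific to this tutorial):

```bash
curl -v localhost:8000/v2/health/ready
# An HTTP 200 response indicates the server and its models are ready
```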

## Sending requests via the `generate` endpoint

As a quick check that the server is working, you can send a request to the `generate` endpoint. Read more about the generate endpoint [here](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md).

```bash
curl -X POST localhost:8000/v2/models/llama2vllm/generate -d '{"text_input": "What is Triton Inference Server?", "parameters": {"stream": false, "temperature": 0}}'
```

which returns (formatted for readability):

```json
{
  "model_name": "llama2vllm",
  "model_version": "1",
  "text_output": "What is Triton Inference Server?\nTriton Inference Server is a lightweight, high-performance"
}
```
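
The generate extension also provides a `generate_stream` endpoint, which returns Server-Sent Events when `"stream"` is set to `true`; a minimal sketch, assuming the same server as above:

```bash
curl -X POST localhost:8000/v2/models/llama2vllm/generate_stream \
    -d '{"text_input": "What is Triton Inference Server?", "parameters": {"stream": true, "temperature": 0}}'
```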

## Sending requests via the Triton client

The Triton vLLM Backend repository has a [samples folder](https://github.com/triton-inference-server/vllm_backend/tree/main/samples) with an example `client.py` for testing the Llama2 model.

```bash
pip3 install "tritonclient[all]"
# Assuming the Triton server is already running
git clone https://github.com/triton-inference-server/vllm_backend.git
cd vllm_backend/samples
python3 client.py -m llama2vllm
```
Running the client should produce a `results.txt` with content similar to the following:
```
Hello, my name is
I am a 20 year old student from the Netherlands. I am currently

=========

The most dangerous animal is
The most dangerous animal is the one that is not there.
The most dangerous

=========

The capital of France is
The capital of France is Paris.
The capital of France is Paris. The

=========

The future of AI is
The future of AI is in the hands of the people who use it.

=========
```
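
Alternatively, if you prefer not to clone the repository, the sketch below is a minimal hand-rolled gRPC client. It is an illustration, not the official sample: it assumes the server is reachable on the gRPC port `8001` shown in the launch output and uses the vLLM backend's `text_input`/`stream`/`text_output` tensor names. A streaming gRPC call is required because the vLLM backend is a decoupled model.

```python
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


def on_response(result_queue, result, error):
    # Decoupled models deliver responses through this stream callback.
    result_queue.put(error if error is not None else result)


results = queue.Queue()
client = grpcclient.InferenceServerClient(url="localhost:8001")
client.start_stream(callback=partial(on_response, results))

inputs = [
    grpcclient.InferInput("text_input", [1], "BYTES"),
    grpcclient.InferInput("stream", [1], "BOOL"),
]
inputs[0].set_data_from_numpy(
    np.array(["What is Triton Inference Server?"], dtype=np.object_)
)
# stream=False asks the backend for a single, complete response
inputs[1].set_data_from_numpy(np.array([False], dtype=bool))

client.async_stream_infer(model_name="llama2vllm", inputs=inputs)
response = results.get()  # blocks until the response (or an error) arrives
client.stop_stream()

if isinstance(response, Exception):
    raise response
print(response.as_numpy("text_output")[0].decode("utf-8"))
```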
15 changes: 12 additions & 3 deletions Quick_Deploy/HuggingFaceTransformers/README.md
@@ -29,10 +29,11 @@
# Deploying Hugging Face Transformer Models in Triton

The following tutorial demonstrates how to deploy an arbitrary Hugging Face transformer
model on the Triton Inference Server using Triton's [Python backend](https://github.com/triton-inference-server/python_backend).
For the purposes of this example, the following transformer models will be deployed:
- [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b)
- [adept/persimmon-8b-base](https://huggingface.co/adept/persimmon-8b-base)
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)

These models were selected because of their popularity and consistent response quality.
However, this tutorial is also generalizable to any transformer model provided
@@ -41,6 +42,10 @@ sufficient infrastructure.
*NOTE*: The tutorial is intended to be a reference example only. It may not be tuned for
optimal performance.

*NOTE*: Llama 2 models are not specifically mentioned in the steps below, but
can be run if `tiiuae/falcon-7b` is replaced with `meta-llama/Llama-2-7b-hf`
and the `falcon7b` folder is replaced by a `llama7b` folder, as sketched below.
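
A hedged sketch of that substitution, assuming the `model_repository/falcon7b/{1/model.py,config.pbtxt}` layout used elsewhere in this tutorial and that the Hugging Face model id is stored in the config's `huggingface_model` parameter:

```bash
# Hypothetical helper: copy the falcon7b example and retarget it at Llama2
cp -r model_repository/falcon7b model_repository/llama7b
sed -i 's/falcon7b/llama7b/g; s|tiiuae/falcon-7b|meta-llama/Llama-2-7b-hf|g' \
    model_repository/llama7b/config.pbtxt
```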

## Step 1: Create a Model Repository

The first step is to create a model repository containing the models we want the Triton
@@ -76,11 +81,15 @@ docker build -t triton_transformer_server .

Once the ```triton_transformer_server``` image is created, you can launch the Triton Inference
Server in a container with the following command:

```bash
docker run --gpus all -it --rm --net=host --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}/model_repository:/opt/tritonserver/model_repository triton_transformer_server tritonserver --model-repository=model_repository
```

**Note**: For gated models like `Llama2`, you need to [request access to the model](https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main) and pass your [access token](https://huggingface.co/settings/tokens) to the container by adding `-e PRIVATE_REPO_TOKEN=<hf_your_huggingface_access_token>` to the docker command:
```bash
docker run --gpus all -it --rm --net=host --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -e PRIVATE_REPO_TOKEN=<hf_your_huggingface_access_token> -v ${PWD}/model_repository:/opt/tritonserver/model_repository triton_transformer_server tritonserver --model-repository=model_repository
```

The server has launched successfully when you see output in your console indicating that the GRPC, HTTP, and metrics services have started.
114 changes: 114 additions & 0 deletions Quick_Deploy/HuggingFaceTransformers/llama7b/1/model.py
@@ -0,0 +1,114 @@
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os

os.environ["TRANSFORMERS_CACHE"] = "/opt/tritonserver/model_repository/llama7b/hf-cache"

import json

import numpy as np
import torch
import transformers
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])
        self.model_params = self.model_config.get("parameters", {})
        default_hf_model = "meta-llama/Llama-2-7b-hf"
        private_repo_token = None
        default_max_gen_length = "15"
        # Check for user-specified model name in model config parameters
        hf_model = self.model_params.get("huggingface_model", {}).get(
            "string_value", default_hf_model
        )
        if "PRIVATE_REPO_TOKEN" in os.environ:
            private_repo_token = os.environ["PRIVATE_REPO_TOKEN"]
        else:
            self.logger.log_warn(
                "envvar PRIVATE_REPO_TOKEN should be specified if running a restricted model like Llama2"
            )

        # Check for user-specified max length in model config parameters
        self.max_output_length = int(
            self.model_params.get("max_output_length", {}).get(
                "string_value", default_max_gen_length
            )
        )

        self.logger.log_info(f"Max output length: {self.max_output_length}")
        self.logger.log_info(f"Loading HuggingFace model: {hf_model}...")
        # Assume a tokenizer is available for the same model
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            hf_model, token=private_repo_token
        )

        self.pipeline = transformers.pipeline(
            "text-generation",
            model=hf_model,
            torch_dtype=torch.float16,
            tokenizer=self.tokenizer,
            device_map="auto",
            token=private_repo_token,
        )

    def execute(self, requests):
        responses = []
        for request in requests:
            # Read the "text_input" tensor declared in the model config
            input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
            prompt = input_tensor.as_numpy()[0].decode("utf-8")

            response = self.generate(prompt)
            responses.append(response)

        return responses

    def generate(self, prompt):
        sequences = self.pipeline(
            prompt,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            max_length=self.max_output_length,
        )

        output_tensors = []
        texts = []
        for i, seq in enumerate(sequences):
            text = seq["generated_text"]
            self.logger.log_info(f"Sequence {i+1}: {text}")
            texts.append(text)

        tensor = pb_utils.Tensor("text_output", np.array(texts, dtype=np.object_))
        output_tensors.append(tensor)
        response = pb_utils.InferenceResponse(output_tensors=output_tensors)
        return response

    def finalize(self):
        print("Cleaning up...")
