From 143a2d1979e016d9ca97477bf60d0fd439a79e14 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Mon, 17 Jun 2024 19:54:25 -0700 Subject: [PATCH 1/8] Llava tutorial --- .../Llava1.5/llava_trtllm_guide.md | 222 ++++++++ .../model_repository/llava-1.5/1/model.py | 206 +++++++ .../model_repository/llava-1.5/config.pbtxt | 83 +++ .../tensorrt_llm/config.pbtxt | 516 ++++++++++++++++++ .../vision_encoder/1/model.py | 104 ++++ .../vision_encoder/config.pbtxt | 22 + .../Llava1.5/multi_modal_client.py | 198 +++++++ README.md | 1 + 8 files changed, 1352 insertions(+) create mode 100644 Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md create mode 100644 Popular_Models_Guide/Llava1.5/model_repository/llava-1.5/1/model.py create mode 100644 Popular_Models_Guide/Llava1.5/model_repository/llava-1.5/config.pbtxt create mode 100644 Popular_Models_Guide/Llava1.5/model_repository/tensorrt_llm/config.pbtxt create mode 100644 Popular_Models_Guide/Llava1.5/model_repository/vision_encoder/1/model.py create mode 100644 Popular_Models_Guide/Llava1.5/model_repository/vision_encoder/config.pbtxt create mode 100644 Popular_Models_Guide/Llava1.5/multi_modal_client.py diff --git a/Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md b/Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md new file mode 100644 index 00000000..f5aa0b75 --- /dev/null +++ b/Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md @@ -0,0 +1,222 @@ + + +# Deploying Hugging Face Llava1.5-7b Model in Triton + +TensorRT-LLM is Nvidia's recommended solution of running Large Language +Models(LLMs) on Nvidia GPUs. Read more about TensoRT-LLM [here](https://github.com/NVIDIA/TensorRT-LLM) +and Triton's TensorRT-LLM Backend [here](https://github.com/triton-inference-server/tensorrtllm_backend). + +*NOTE:* If some parts of this tutorial doesn't work, it is possible that there +are some version mismatches between the `tutorials` and `tensorrtllm_backend` +repository. Refer to [llama.md](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/llama.md) +for more detailed modifications if necessary. And if you are familiar with +python, you can also try using +[High-level API](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/high-level-api/README.md) +for LLM workflow. + + +## Acquiring Llava1.5-7B model + +For this tutorial, we are using the Llava1.5-7B HuggingFace model with pre-trained +weights. Clone the repo of the model with weights and tokens +[here](https://huggingface.co/llava-hf/llava-1.5-7b-hf/tree/main). + +## Deploying with Triton Inference Server + +Next steps will guide you over the process of TensorRT and TensorRT-LLM engine +building and Triton model repository set up. + +### Prerequisite: TensorRT-LLM backend + +This tutorial requires TensorRT-LLM Backend repository. Please note, +that for best user experience we recommend using the latest +[release tag](https://github.com/triton-inference-server/tensorrtllm_backend/tags) +of `tensorrtllm_backend` and +the latest [Triton Server container.](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags) + +To clone TensorRT-LLM Backend repository, make sure to run the following +set of commands. +```bash +git clone https://github.com/triton-inference-server/tensorrtllm_backend.git --branch +# Update the submodules +cd tensorrtllm_backend +# Install git-lfs if needed +apt-get update && apt-get install git-lfs -y --no-install-recommends +git lfs install +git submodule update --init --recursive +``` + +### Launch Triton TensorRT-LLM container + +Launch Triton docker container with TensorRT-LLM backend. +Note that we're mounting `tensorrtllm_backend` to `/tensorrtllm_backend` +and the Llava1.5 model to `/Llava-1.5-7b-hf` in the docker container for simplicity. +Make an `engines` folder outside docker to reuse engines for future runs. +Please, make sure to replace with the version of Triton that you want +to use. + +```bash +docker run --rm -it --net host --shm-size=2g \ + --ulimit memlock=-1 --ulimit stack=67108864 --gpus all \ + -v :/tensorrtllm_backend \ + -v :/llava-1.5-7b-hf \ + -v :/engines \ + -v :/tutorials \ + nvcr.io/nvidia/tritonserver:-trtllm-python-py3 +``` + +Alternatively, you can follow instructions +[here](https://github.com/triton-inference-server/tensorrtllm_backend?tab=readme-ov-file#build-the-docker-container) +to build Triton Server with Tensorrt-LLM Backend if you want +to build a specialized container. + +Don't forget to allow gpu usage when you launch the container. + +### Create Engines for each model [skip this step if you already have engines] + +TensorRT-LLM requires each model to be compiled for the configuration +you need before running. To do so, before you run your model for the first time +on Triton Server you will need to create a TensorRT-LLM engine. + +Starting with [24.04 release](https://github.com/triton-inference-server/server/releases/tag/v2.45.0), +Triton Server TensrRT-LLM container comes with +pre-installed TensorRT-LLM package, which allows users to build engines inside +the Triton container. + +Llava1.5 requires 2 engines: a TensorRT engine for visual components, +and a TRT-LLM engine for the language components. This tutorial bases on 24.05 +release, which corresponds to `v0.9.0` version of TensorRT-LLM and +TensorRT-LLM backend and follows [this](https://github.com/NVIDIA/TensorRT-LLM/tree/v0.9.0/examples/multimodal#llava-and-vila) +TensorRT-LLM multi-modal guide. + +To generate engines, simply follow the next steps: + +```bash +HF_LLAVA_MODEL=/llava-1.5-7b-hf +UNIFIED_CKPT_PATH=/tmp/ckpt/llava/7b/ +ENGINE_DIR=/engines/llava1.5 +CONVERT_CHKPT_SCRIPT=/tensorrtllm_backend/tensorrt_llm/examples/llama/convert_checkpoint.py +python3 ${CONVERT_CHKPT_SCRIPT} --model_dir ${HF_LLAVA_MODEL} --output_dir ${UNIFIED_CKPT_PATH} --dtype float16 +trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ + --output_dir ${ENGINE_DIR} \ + --gemm_plugin float16 \ + --use_fused_mlp \ + --max_batch_size 1 \ + --max_input_len 2048 \ + --max_output_len 512 \ + --max_multimodal_len 576 # 1 (max_batch_size) * 576 (num_visual_features) + +python /tensorrtllm_backend/tensorrt_llm/examples/multimodal/build_visual_engine.py --model_path ${HF_LLAVA_MODEL} --model_type llava --output_dir ${ENGINE_DIR}/llava1.5 +``` + + +> Optional: You can check test the output of the model with `run.py` +> located in the same llama examples folder. +> +> ```bash +> python3 /tensorrtllm_backend/tensorrt_llm/examples/multimodal/run.py --max_new_tokens 30 --hf_model_dir ${HF_LLAVA_MODEL} --visual_engine_dir ${ENGINE_DIR} --llm_engine_dir ${ENGINE_DIR} --decoder_llm --input_text "Question: which city is this? Answer:" +> ``` +> You should expect the following response: +> ``` +> [TensorRT-LLM] TensorRT-LLM version: 0.9.0 +> ... +> [06/18/2024-01:02:24] [TRT-LLM] [I] --------------------------------------------------------- +> [06/18/2024-01:02:24] [TRT-LLM] [I] +> [Q] Question: which city is this? Answer: +> [06/18/2024-01:02:24] [TRT-LLM] [I] +> [A] ['Singapore'] +> [06/18/2024-01:02:24] [TRT-LLM] [I] Generated 1 tokens +> [06/18/2024-01:02:24] [TRT-LLM] [I] --------------------------------------------------------- +> ``` + +### Serving with Triton + +The last step is to set up a Triton model repository. For this tutorial, +we provide all necessary Triton related files under `model_repository/`. +You simply need to provide TensorRT-LLM engine location in its `config.pbtxt`: + +```bash +FILL_TEMPLATE_SCRIPT=/tensorrtllm_backend/tools/fill_template.py +python3 ${FILL_TEMPLATE_SCRIPT} -i /tutorials/Popular_Models_Guide/Llava1.5/model_repository/tensorrt_llm/config.pbtxt engine_dir:${ENGINE_DIR} +``` + +3. Launch Tritonserver + +Use the [launch_triton_server.py](https://github.com/triton-inference-server/tensorrtllm_backend/blob/release/0.5.0/scripts/launch_triton_server.py) script. This launches multiple instances of `tritonserver` with MPI. +```bash +export TRT_ENGINE_LOCATION="/engines/llava1.5/visual_encoder.engine" +export HF_LOCATION=/llava-1.5-7b-hf +python3 /tensorrtllm_backend/scripts/launch_triton_server.py --world_size= --model_repo=/opt/tritonserver/inflight_batcher_llm +``` +> You should expect the following response: +> ``` +> ... +> I0503 22:01:25.210518 1175 grpc_server.cc:2463] Started GRPCInferenceService at 0.0.0.0:8001 +> I0503 22:01:25.211612 1175 http_server.cc:4692] Started HTTPService at 0.0.0.0:8000 +> I0503 22:01:25.254914 1175 http_server.cc:362] Started Metrics Service at 0.0.0.0:8002 +> ``` + +To stop Triton Server inside the container, run: +```bash +pkill tritonserver +``` + +### Send an inference request + +You can test the results of the run with: +1. The [multi_modal_client.py](tutorials/Popular_Models_Guide/Llava1.5/multi_modal_client.py) script. + +```bash +# Using the SDK container as an example +docker run --rm -it --net host --shm-size=2g \ + --ulimit memlock=-1 --ulimit stack=67108864 --gpus all \ + -v /path/to/tutorials:/tutorials + nvcr.io/nvidia/tritonserver:-py3-sdk + +python3 python multi_modal_client.py --prompt "Describe the picture." --image_url "http://images.cocodataset.org/test2017/000000155781.jpg" --max-tokens=15 +``` +> You should expect the following response: +> ``` +> Got completed request +> The image features a city bus parked on the side of a street. +> ``` + +2. The [generate endpoint](https://github.com/triton-inference-server/tensorrtllm_backend/tree/release/0.5.0#query-the-server-with-the-triton-generate-endpoint). + +```bash +curl -X POST localhost:8000/v2/models/llava-1.5/generate -d '{"prompt":"USER: \nQuestion:Describe the picture. Answer:", "image":"http://images.cocodataset.org/test2017/000000155781.jpg", "max_tokens":100}' +``` +> You should expect the following response: +> ``` +> data: {"completion_tokens":77,"finish_reason":"stop","model_name":"llava-1.5","model_version":"1","prompt_tokens":592,"text":"The image features a city bus parked on the side of a street. The bus is positioned near a railroad crossing, and there is a stop sign visible in the scene. The bus is also displaying an \"Out of Service\" sign, indicating that it is not currently in operation. The street appears to be foggy, adding a sense of atmosphere to the scene.","total_tokens":669} +> ``` + +## References + +For more examples feel free to refer to [End to end workflow to run multi-modal models.](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/multimodal/README.md) \ No newline at end of file diff --git a/Popular_Models_Guide/Llava1.5/model_repository/llava-1.5/1/model.py b/Popular_Models_Guide/Llava1.5/model_repository/llava-1.5/1/model.py new file mode 100644 index 00000000..e32607fd --- /dev/null +++ b/Popular_Models_Guide/Llava1.5/model_repository/llava-1.5/1/model.py @@ -0,0 +1,206 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import os +from io import BytesIO + +import numpy as np +import requests as rq +import torch +import triton_python_backend_utils as pb_utils +from PIL import Image +from transformers import AutoProcessor, AutoTokenizer + + +class TritonPythonModel: + def initialize(self, args): + HF_LOCATION = os.getenv("HF_LOCATION", pb_utils.get_model_dir()) + self.image_processor = AutoProcessor.from_pretrained(HF_LOCATION) + self.logger = pb_utils.Logger + self.tokenizer = AutoTokenizer.from_pretrained(HF_LOCATION) + self.vocab_size = 32064 + self.max_input_len = 2048 + + def _tokenize(self, prompt, num_visual_tokens): + chunks = prompt.split("") + assert len(chunks) == 2, "Only support exactly one image per prompt" + + return ( + self.tokenizer.encode(chunks[0]) + + list(range(self.vocab_size, self.vocab_size + num_visual_tokens)) + + self.tokenizer.encode(chunks[1])[self.tokenizer.add_bos_token :] + ) + + def _parse_input(self, request, input_name, default=None): + input = pb_utils.get_input_tensor_by_name(request, input_name) + if input is not None: + return input.as_numpy()[0] + + return default + + def execute(self, requests): + """ + This function receives a list of requests (`pb_utils.InferenceRequest`), + performs inference on every request and appends it to responses. + """ + responses = [] + for request in requests: + # Get INPUT0 + image = ( + pb_utils.get_input_tensor_by_name(request, "image") + .as_numpy() + .flatten() + .tolist() + ) + if isinstance(image[0], bytes): + image = image[0].decode("utf-8") + pil_image = Image.open(rq.get(image, stream=True).raw).convert("RGB") + # Get INPUT1 + prompt = pb_utils.get_input_tensor_by_name(request, "prompt").as_numpy()[0] + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + + image = self.image_processor( + text=prompt, images=pil_image, return_tensors="np" + )["pixel_values"].astype(np.float16) + # Create inference request object + infer_request = pb_utils.InferenceRequest( + model_name="vision_encoder", + requested_output_names=["features"], + inputs=[pb_utils.Tensor("image", image)], + ) + + # Perform synchronous blocking inference request + vision_response = infer_request.exec() + response_sender = request.get_response_sender() + + image_features = pb_utils.get_output_tensor_by_name( + vision_response, "features" + ) + image_features = torch.from_dlpack(image_features.as_numpy()) + # parse input parameters + max_tokens = self._parse_input(request, "max_tokens", default=50) + temperature = self._parse_input(request, "temperature", default=0.5) + top_k = self._parse_input(request, "top_k", default=1) + frequency_penalty = self._parse_input( + request, "frequency_penalty", default=0.7 + ) + seed = self._parse_input(request, "seed", default=10) + + input_ids = self._tokenize(prompt, len(image_features[0])) + input_ids = np.array(input_ids, dtype=np.int32) + input_len = input_ids.shape[0] + if input_len > self.max_input_len: + error = pb_utils.InferenceResponse( + error=pb_utils.TritonError( + f"Input length ({input_len:d}) exceeds limit ({self.max_input_len:d})" + ) + ) + response_sender.send( + error, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + return + + # build embedding table + embedding_args = { + "prompt_vocab_size": np.array( + [[image_features[0].shape[0]]], dtype=np.uint32 + ), + "prompt_embedding_table": np.expand_dims(image_features[0], 0).astype( + np.float16 + ), + } + + llm_request_inputs = { + "input_ids": np.expand_dims(input_ids, 0), + "input_lengths": np.array([[input_len]], dtype=np.int32), + "request_output_len": np.array([[max_tokens]], dtype=np.int32), + "temperature": np.array([[temperature]], dtype=np.float32), + "runtime_top_k": np.array([[top_k]], dtype=np.int32), + "frequency_penalty": np.array([[frequency_penalty]], dtype=np.float32), + "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32), + "random_seed": np.array([[seed]], dtype=np.uint64), + "streaming": np.array([[1]], dtype=np.bool_), + **embedding_args, + } + llm_request = pb_utils.InferenceRequest( + model_name="tensorrt_llm", + requested_output_names=["output_ids", "sequence_length"], + inputs=[pb_utils.Tensor(k, v) for k, v in llm_request_inputs.items()], + ) + output_ids, output_len = [], 0 + for response in llm_request.exec(decoupled=True): + if response.has_error(): + raise pb_utils.TritonModelException(response.error().message()) + + stream_output_ids = ( + pb_utils.get_output_tensor_by_name(response, "output_ids") + .as_numpy() + .flatten() + .tolist() + ) + + finish_reason = "" + if len(stream_output_ids) == 0 or ( + len(stream_output_ids) != 0 + and stream_output_ids[-1] == self.tokenizer.eos_token_id + ): + finish_reason = "stop" + output_ids += stream_output_ids + if len(output_ids) >= max_tokens: + finish_reason = "length" + output_ids = output_ids[:max_tokens] + last_response = finish_reason != "" + output_len = len(output_ids) + + if last_response: + output_text = self.tokenizer.decode(output_ids).strip() + response = pb_utils.InferenceResponse( + output_tensors=[ + pb_utils.Tensor( + "text", np.array([output_text], np.object_) + ), + pb_utils.Tensor( + "finish_reason", np.array([finish_reason], np.object_) + ), + pb_utils.Tensor( + "completion_tokens", np.array([output_len], np.int32) + ), + pb_utils.Tensor( + "prompt_tokens", np.array([input_len], np.int32) + ), + pb_utils.Tensor( + "total_tokens", + np.array([output_len + input_len], np.int32), + ), + ] + ) + response_sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + + return None diff --git a/Popular_Models_Guide/Llava1.5/model_repository/llava-1.5/config.pbtxt b/Popular_Models_Guide/Llava1.5/model_repository/llava-1.5/config.pbtxt new file mode 100644 index 00000000..bf790ba6 --- /dev/null +++ b/Popular_Models_Guide/Llava1.5/model_repository/llava-1.5/config.pbtxt @@ -0,0 +1,83 @@ +model_transaction_policy { + decoupled: True +} + +input [ + { + name: "prompt" + data_type: TYPE_STRING + dims: [ 1 ] + }, + { + name: "image" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + } +] + +output [ + { + name: "text" + data_type: TYPE_STRING + dims: [ 1 ] + }, + { + name: "finish_reason" + data_type: TYPE_STRING + dims: [ 1 ] + }, + { + name: "prompt_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "completion_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "total_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + } +] + + +instance_group [ + { + count: 1 + kind: KIND_GPU + gpus: [ 0 ] + } +] \ No newline at end of file diff --git a/Popular_Models_Guide/Llava1.5/model_repository/tensorrt_llm/config.pbtxt b/Popular_Models_Guide/Llava1.5/model_repository/tensorrt_llm/config.pbtxt new file mode 100644 index 00000000..69fbd144 --- /dev/null +++ b/Popular_Models_Guide/Llava1.5/model_repository/tensorrt_llm/config.pbtxt @@ -0,0 +1,516 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "tensorrt_llm" +backend: "tensorrtllm" +max_batch_size: 1 + +model_transaction_policy { + decoupled: True +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + allow_ragged_batch: true + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "draft_input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "draft_logits" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "draft_acceptance_threshold" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "end_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "bad_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "embedding_bias" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "beam_width" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_min" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_decay" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_reset_ids" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "early_stopping" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "beam_search_diversity_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # the unique task ID for the given LoRA. + # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given. + # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`. + # If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached. + { + name: "lora_task_id" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ] + # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer + # each of the in / out tensors are first flattened and then concatenated together in the format above. + # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out. + { + name: "lora_weights" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + # module identifier (same size a first dimension of lora_weights) + # See LoraModule::ModuleType for model id mapping + # + # "attn_qkv": 0 # compbined qkv adapter + # "attn_q": 1 # q adapter + # "attn_k": 2 # k adapter + # "attn_v": 3 # v adapter + # "attn_dense": 4 # adapter for the dense layer in attention + # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection + # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection + # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate + # + # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ] + { + name: "lora_config" + data_type: TYPE_INT32 + dims: [ -1, 3 ] + optional: true + allow_ragged_batch: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + }, + { + name: "sequence_length" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "context_logits" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "1" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "inflight_fused_batching" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "${engine_dir}" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "${max_tokens_in_paged_kv_cache}" + } +} +parameters: { + key: "max_attention_window_size" + value: { + string_value: "${max_attention_window_size}" + } +} +parameters: { + key: "sink_token_length" + value: { + string_value: "${sink_token_length}" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "${batch_scheduler_policy}" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: "${kv_cache_free_gpu_mem_fraction}" + } +} +parameters: { + key: "kv_cache_host_memory_bytes" + value: { + string_value: "${kv_cache_host_memory_bytes}" + } +} +parameters: { + key: "kv_cache_onboard_blocks" + value: { + string_value: "${kv_cache_onboard_blocks}" + } +} +# enable_trt_overlap is deprecated and doesn't have any effect on the runtime +# parameters: { +# key: "enable_trt_overlap" +# value: { +# string_value: "${enable_trt_overlap}" +# } +# } +parameters: { + key: "exclude_input_in_output" + value: { + string_value: "${exclude_input_in_output}" + } +} +parameters: { + key: "cancellation_check_period_ms" + value: { + string_value: "${cancellation_check_period_ms}" + } +} +parameters: { + key: "stats_check_period_ms" + value: { + string_value: "${stats_check_period_ms}" + } +} +parameters: { + key: "iter_stats_max_iterations" + value: { + string_value: "${iter_stats_max_iterations}" + } +} +parameters: { + key: "request_stats_max_iterations" + value: { + string_value: "${request_stats_max_iterations}" + } +} +parameters: { + key: "enable_kv_cache_reuse" + value: { + string_value: "${enable_kv_cache_reuse}" + } +} +parameters: { + key: "normalize_log_probs" + value: { + string_value: "${normalize_log_probs}" + } +} +parameters: { + key: "enable_chunked_context" + value: { + string_value: "${enable_chunked_context}" + } +} +parameters: { + key: "gpu_device_ids" + value: { + string_value: "${gpu_device_ids}" + } +} +parameters: { + key: "lora_cache_optimal_adapter_size" + value: { + string_value: "${lora_cache_optimal_adapter_size}" + } +} +parameters: { + key: "lora_cache_max_adapter_size" + value: { + string_value: "${lora_cache_max_adapter_size}" + } +} +parameters: { + key: "lora_cache_gpu_memory_fraction" + value: { + string_value: "${lora_cache_gpu_memory_fraction}" + } +} +parameters: { + key: "lora_cache_host_memory_bytes" + value: { + string_value: "${lora_cache_host_memory_bytes}" + } +} +parameters: { + key: "decoding_mode" + value: { + string_value: "${decoding_mode}" + } +} +parameters: { + key: "executor_worker_path" + value: { + string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker" + } +} +parameters: { + key: "medusa_choices" + value: { + string_value: "${medusa_choices}" + } +} +parameters: { + key: "gpu_weights_percent" + value: { + string_value: "${gpu_weights_percent}" + } +} diff --git a/Popular_Models_Guide/Llava1.5/model_repository/vision_encoder/1/model.py b/Popular_Models_Guide/Llava1.5/model_repository/vision_encoder/1/model.py new file mode 100644 index 00000000..c58b577e --- /dev/null +++ b/Popular_Models_Guide/Llava1.5/model_repository/vision_encoder/1/model.py @@ -0,0 +1,104 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os + +import numpy as np +import tensorrt as trt +import torch +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + device = "cuda" if args["model_instance_kind"] == "GPU" else "cpu" + device_id = args["model_instance_device_id"] + self.device = f"{device}:{device_id}" + # Load TRT engine + self.logger = trt.Logger(trt.Logger.ERROR) + engine_path = os.getenv("TRT_ENGINE_LOCATION") + with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime: + assert runtime + self.engine = runtime.deserialize_cuda_engine(f.read()) + assert self.engine + self.context = self.engine.create_execution_context() + assert self.context + + # Setup I/O bindings + self.inputs = [] + self.outputs = [] + for i in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(i) + is_input = False + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + is_input = True + + dtype = self.engine.get_tensor_dtype(name) + shape = self.engine.get_tensor_shape(name) + if shape[0] < 0: + profile_shape = self.engine.get_tensor_profile_shape(name, 0) + # Set the *min* profile as binding shape, choices [min,opt,max] + self.context.set_input_shape(name, profile_shape[0]) + shape = self.context.get_tensor_shape(name) + + binding = { + "index": i, + "name": name, + "dtype": np.dtype(trt.nptype(dtype)), + "shape": list(shape), + "allocation": None, + } + if is_input: + self.inputs.append(binding) + else: + self.outputs.append(binding) + + def execute(self, requests): + """ + This function receives a list of requests (`pb_utils.InferenceRequest`), + performs inference on every request and appends it to responses. + """ + responses = [] + for request in requests: + allocations = [] + output = torch.asarray( + np.zeros(self.outputs[0]["shape"], self.outputs[0]["dtype"]), + device=self.device, + ) + input_tensor = torch.asarray( + pb_utils.get_input_tensor_by_name(request, "image").as_numpy(), + device=self.device, + ) + self.inputs[0]["allocation"] = input_tensor.data_ptr() + allocations.append(input_tensor.data_ptr()) + self.outputs[0]["allocation"] = output.data_ptr() + allocations.append(output.data_ptr()) + self.context.execute_v2(allocations) + out_tensor = pb_utils.Tensor.from_dlpack("features", output.cpu()) + responses.append(pb_utils.InferenceResponse([out_tensor])) + self.inputs[0]["allocation"] = None + self.outputs[0]["allocation"] = None + return responses diff --git a/Popular_Models_Guide/Llava1.5/model_repository/vision_encoder/config.pbtxt b/Popular_Models_Guide/Llava1.5/model_repository/vision_encoder/config.pbtxt new file mode 100644 index 00000000..0a1e4fb0 --- /dev/null +++ b/Popular_Models_Guide/Llava1.5/model_repository/vision_encoder/config.pbtxt @@ -0,0 +1,22 @@ +input [ + { + name: "image" + data_type: TYPE_FP16 + dims: [ -1, 3, 336, 336 ] + } +] +output [ + { + name: "features" + data_type: TYPE_FP16 + dims: [ 576 , -1] + } +] + +instance_group [ + { + count: 1 + kind: KIND_GPU + gpus: [ 0 ] + } +] diff --git a/Popular_Models_Guide/Llava1.5/multi_modal_client.py b/Popular_Models_Guide/Llava1.5/multi_modal_client.py new file mode 100644 index 00000000..86bc94d9 --- /dev/null +++ b/Popular_Models_Guide/Llava1.5/multi_modal_client.py @@ -0,0 +1,198 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import queue +import warnings +from functools import partial + +import numpy as np +import tritonclient.grpc as grpcclient +from PIL import Image +from transformers import AutoProcessor +from tritonclient.utils import * + +warnings.filterwarnings("ignore") + + +class UserData: + def __init__(self): + self._completed_requests = queue.Queue() + + +def callback(user_data, result, error): + if error: + user_data._completed_requests.put(error) + else: + user_data._completed_requests.put(result) + + +def prepare_tensor(name, input): + t = grpcclient.InferInput(name, input.shape, np_to_triton_dtype(input.dtype)) + t.set_data_from_numpy(input) + return t + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + required=False, + default="llava-1.5", + help="Model name", + ) + parser.add_argument( + "--image_url", + type=str, + required=False, + default="http://images.cocodataset.org/test2017/000000557146.jpg", + help="Image URL. Default is:\ + http://images.cocodataset.org/test2017/000000557146.jpg", + ) + parser.add_argument( + "--prompt", + type=str, + required=False, + default="What is shown on the picture?", + help="Prompt. Default is:\ + What is shown on the picture?", + ) + parser.add_argument( + "--max-tokens", + type=int, + required=False, + default=50, + help="Max amount of tokens in the output. Default is 50.", + ) + parser.add_argument( + "--temperature", + type=float, + required=False, + default=0.9, + help="Temperatue. Default is 0.9.", + ) + parser.add_argument( + "--top-k", + type=int, + required=False, + default=1, + help="Top K. Default is 1.", + ) + parser.add_argument( + "--frequency-penalty", + type=float, + required=False, + default=0.9, + help="Frequency penalty. Default is 0.9.", + ) + parser.add_argument( + "--seed", + type=int, + required=False, + default=10, + help="Random seed. Default is 10.", + ) + parser.add_argument( + "--url", + type=str, + required=False, + default="localhost:8000", + help="Inference server URL. Default is localhost:8000.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + required=False, + default=False, + help="Enable verbose output", + ) + + args = parser.parse_args() + user_data = UserData() + input_image = "image" + input_prompt = "prompt" + output_name = "features" + + input_text = "USER: \nQuestion:" + args.prompt + " Answer:" + image_url = np.array([args.image_url.encode("utf-8")], dtype=np.object_) + prompt_data = np.array([input_text.encode("utf-8")], dtype=np.object_) + max_tokens = np.array([args.max_tokens], dtype=np.int32) + temperature = np.array([args.temperature], dtype=np.float32) + top_k = np.array([args.top_k], dtype=np.int32) + frequency_penalty = np.array([args.frequency_penalty], dtype=np.float32) + seed = np.array([args.seed], dtype=np.uint64) + inputs = [ + prepare_tensor("image", image_url), + prepare_tensor("prompt", prompt_data), + prepare_tensor("max_tokens", max_tokens), + prepare_tensor("temperature", temperature), + prepare_tensor("top_k", top_k), + prepare_tensor("frequency_penalty", frequency_penalty), + prepare_tensor("seed", seed), + ] + outputs = [] + for output_name in [ + "text", + "finish_reason", + "prompt_tokens", + "completion_tokens", + "total_tokens", + ]: + outputs.append(grpcclient.InferRequestedOutput(output_name)) + output_text = "" + + with grpcclient.InferenceServerClient(url="localhost:8001") as client: + client.start_stream(partial(callback, user_data)) + req = client.async_stream_infer( + args.model_name, + inputs, + outputs=outputs, + ) + expected_responses = 1 + processed_count = 0 + while processed_count < expected_responses: + try: + result = user_data._completed_requests.get() + print("Got completed request", flush=True) + except Exception: + break + + if type(result) == InferenceServerException: + if result.status() == "StatusCode.CANCELLED": + print("Request is cancelled") + else: + print("Received an error from server:") + print(result) + raise result + else: + output_text = result.as_numpy("text") + print(output_text[0].decode("utf-8")) + + processed_count = processed_count + 1 + + client.stop_stream() diff --git a/README.md b/README.md index f328ff07..bc286518 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ The table below contains some popular models that are supported in our tutorials | [Llama-2-7B](https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main) |[TensorRT-LLM Tutorial](Popular_Models_Guide/Llama2/trtllm_guide.md) | | [Persimmon-8B](https://www.adept.ai/blog/persimmon-8b) | [HuggingFace Transformers Tutorial](https://github.com/triton-inference-server/tutorials/tree/main/Quick_Deploy/HuggingFaceTransformers) | [Falcon-7B](https://huggingface.co/tiiuae/falcon-7b) |[HuggingFace Transformers Tutorial](https://github.com/triton-inference-server/tutorials/tree/main/Quick_Deploy/HuggingFaceTransformers) | +[LLaVA-v1.5-7B](https://huggingface.co/llava-hf/llava-1.5-7b-hf) | [TensorRT-LLM Tutorial](tutorials/Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md) **Note:** This is not an exhausitive list of what Triton supports, just what is included in the tutorials. From b1fd4bca8395a331e8ac47d0658717bbb4493aac Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Mon, 17 Jun 2024 19:57:26 -0700 Subject: [PATCH 2/8] Follow up --- Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md | 2 +- .../Llava1.5/model_repository/tensorrt_llm/1/.gitkeep | 0 Popular_Models_Guide/Llava1.5/multi_modal_client.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 Popular_Models_Guide/Llava1.5/model_repository/tensorrt_llm/1/.gitkeep diff --git a/Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md b/Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md index f5aa0b75..2b49a2d7 100644 --- a/Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md +++ b/Popular_Models_Guide/Llava1.5/llava_trtllm_guide.md @@ -1,5 +1,5 @@