diff --git a/Quick_Deploy/vLLM/Dockerfile b/Quick_Deploy/vLLM/Dockerfile deleted file mode 100644 index 1876189a..00000000 --- a/Quick_Deploy/vLLM/Dockerfile +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -FROM nvcr.io/nvidia/tritonserver:23.09-py3 -RUN pip install vllm==0.2.1.post1 diff --git a/Quick_Deploy/vLLM/README.md b/Quick_Deploy/vLLM/README.md index a0fb6b26..ee48f2af 100644 --- a/Quick_Deploy/vLLM/README.md +++ b/Quick_Deploy/vLLM/README.md @@ -31,38 +31,43 @@ The following tutorial demonstrates how to deploy a simple [facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model on -Triton Inference Server using Triton's [Python backend](https://github.com/triton-inference-server/python_backend) and the -[vLLM](https://github.com/vllm-project/vllm) library. +Triton Inference Server using Triton's +[Python-based](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md#python-based-backends) +[vLLM](https://github.com/triton-inference-server/vllm_backend/tree/main) +backend. *NOTE*: The tutorial is intended to be a reference example only and has [known limitations](#limitations). -## Step 1: Build a Triton Container Image with vLLM +## Step 1: Prepare your model repository -We will build a new container image derived from tritonserver:23.08-py3 with vLLM. +To use Triton, we need to build a model repository. For this tutorial, we will +use the model repository provided in the [samples](https://github.com/triton-inference-server/vllm_backend/tree/main/samples) +folder of the [vllm_backend](https://github.com/triton-inference-server/vllm_backend/tree/main) +repository.
+The following set of commands will create a `model_repository/vllm_model/1` +directory and copy 2 files: +[`model.json`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json) +and +[`config.pbtxt`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/config.pbtxt), +required to serve the [facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model. ``` -docker build -t tritonserver_vllm . +mkdir -p model_repository/vllm_model/1 +wget -P model_repository/vllm_model/1 https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/model_repository/vllm_model/1/model.json +wget -P model_repository/vllm_model/ https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/model_repository/vllm_model/config.pbtxt ``` -The above command should create the tritonserver_vllm image with vLLM and all of its dependencies. - - -## Step 2: Start Triton Inference Server - -A sample model repository for deploying `facebook/opt-125m` using vLLM in Triton is -included with this demo as `model_repository` directory. The model repository should look like this: ``` model_repository/ -`-- vllm - |-- 1 - | `-- model.py - |-- config.pbtxt - |-- vllm_engine_args.json +└── vllm_model + ├── 1 + │   └── model.json + └── config.pbtxt ``` -The content of `vllm_engine_args.json` is: +The content of `model.json` is: ```json { @@ -71,53 +76,116 @@ The content of `vllm_engine_args.json` is: "gpu_memory_utilization": 0.5 } ``` + This file can be modified to provide further settings to the vLLM engine. See vLLM [AsyncEngineArgs](https://github.com/vllm-project/vllm/blob/32b6816e556f69f1672085a6267e8516bcb8e622/vllm/engine/arg_utils.py#L165) and [EngineArgs](https://github.com/vllm-project/vllm/blob/32b6816e556f69f1672085a6267e8516bcb8e622/vllm/engine/arg_utils.py#L11) -for supported key-value pairs. +for supported key-value pairs. Inflight batching and paged attention are handled +by the vLLM engine. -For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified in [`vllm_engine_args.json`](model_repository/vllm/vllm_engine_args.json). +For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified +in [`model.json`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json). *Note*: vLLM greedily consumes up to 90% of the GPU's memory under default settings. This tutorial updates this behavior by setting `gpu_memory_utilization` to 50%. You can tweak this behavior using fields like `gpu_memory_utilization` and other settings -in [`vllm_engine_args.json`](model_repository/vllm/vllm_engine_args.json). +in [`model.json`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/model_repository/vllm_model/1/model.json). -Read through the documentation in [`model.py`](model_repository/vllm/1/model.py) to understand how -to configure this sample for your use-case. +Read through the documentation in [`model.py`](https://github.com/triton-inference-server/vllm_backend/blob/main/src/model.py) +to understand how to configure this sample for your use-case. -Run the following commands to start the server container: +## Step 2: Launch Triton Inference Server +Once you have the model repository set up, it is time to launch the Triton server.
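Before launching, you can optionally edit `model.json` to change the engine settings described in Step 1. As a minimal sketch, a configuration that enables the multi-GPU tensor parallelism mentioned above might look like the following; the `tensor_parallel_size` value of 2 is an illustrative assumption and should match the number of GPUs you actually intend to use.

```json
{
    "model": "facebook/opt-125m",
    "disable_log_requests": "true",
    "gpu_memory_utilization": 0.5,
    "tensor_parallel_size": 2
}
```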
+Starting with the 23.10 release, a dedicated container with vLLM pre-installed +is available on [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver). +To use this container to launch Triton, you can use the docker command below. ``` -docker run --gpus all -it --rm -p 8001:8001 --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}:/work -w /work tritonserver_vllm tritonserver --model-store ./model_repository +docker run --gpus all -it --net=host --rm -p 8001:8001 --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}:/work -w /work nvcr.io/nvidia/tritonserver:<xx.yy>-vllm-python-py3 tritonserver --model-store ./model_repository ``` +Throughout the tutorial, \<xx.yy\> is the version of Triton +that you want to use. Please note that Triton's vLLM +container was first published in the 23.10 release, so any prior version +will not work. -Upon successful start of the server, you should see the following at the end of the output. +After you start Triton, you will see output on the console showing +the server starting up and loading the model. When you see output +like the following, Triton is ready to accept inference requests. ``` -I0901 23:39:08.729123 1 grpc_server.cc:2451] Started GRPCInferenceService at 0.0.0.0:8001 -I0901 23:39:08.729640 1 http_server.cc:3558] Started HTTPService at 0.0.0.0:8000 -I0901 23:39:08.772522 1 http_server.cc:187] Started Metrics Service at 0.0.0.0:8002 +I1030 22:33:28.291908 1 grpc_server.cc:2513] Started GRPCInferenceService at 0.0.0.0:8001 +I1030 22:33:28.292879 1 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000 +I1030 22:33:28.335154 1 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002 ``` -## Step 3: Use a Triton Client to Query the Server +## Step 3: Use a Triton Client to Send Your First Inference Request -We will run the client within Triton's SDK container to issue multiple async requests using the +In this tutorial, we will show how to send an inference request to the +[facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model in 2 ways: + +* [Using the generate endpoint](#using-the-generate-endpoint) +* [Using the gRPC asyncio client](#using-the-grpc-asyncio-client) + +### Using the Generate Endpoint +After you start Triton with the sample model_repository, +you can quickly run your first inference request with the +[generate](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md) +endpoint. + +Start Triton's SDK container with the following command: ``` +docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:<xx.yy>-py3-sdk bash ``` + +Now, let's send an inference request: ``` +curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "What is Triton Inference Server?", "parameters": {"stream": false, "temperature": 0}}' ``` + +Upon success, you should see a response from the server like this one: ``` +{"model_name":"vllm_model","model_version":"1","text_output":"What is Triton Inference Server?\n\nTriton Inference Server is a server that is used by many"} ``` + +### Using the gRPC Asyncio Client +Now, we will see how to run the client within Triton's SDK container +to issue multiple async requests using the [gRPC asyncio client](https://github.com/triton-inference-server/client/blob/main/src/python/library/tritonclient/grpc/aio/__init__.py) library.
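Before downloading the full example, the short sketch below illustrates what such a client does under the hood. It is a minimal illustration only, not a replacement for `client.py`: the tensor names (`text_input`, `stream`, `sampling_parameters`, `text_output`) are assumed to match the sample `config.pbtxt` downloaded in Step 1, and the prompt and sampling values are arbitrary.

```python
import asyncio
import json

import numpy as np
import tritonclient.grpc.aio as grpcclient


async def main():
    # Tensor names are assumed to match the sample config.pbtxt from Step 1;
    # adjust them if your model configuration differs.
    prompt = "The future of AI is"
    sampling_parameters = {"temperature": "0.1", "top_p": "0.95"}

    text_input = grpcclient.InferInput("text_input", [1], "BYTES")
    text_input.set_data_from_numpy(np.array([prompt.encode("utf-8")], dtype=np.object_))

    stream_flag = grpcclient.InferInput("stream", [1], "BOOL")
    stream_flag.set_data_from_numpy(np.array([False], dtype=bool))

    params = grpcclient.InferInput("sampling_parameters", [1], "BYTES")
    params.set_data_from_numpy(
        np.array([json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_)
    )

    async def request_iterator():
        # A single request; the full client.py below does this for every prompt
        # in prompts.txt.
        yield {
            "model_name": "vllm_model",
            "inputs": [text_input, stream_flag, params],
            "outputs": [grpcclient.InferRequestedOutput("text_output")],
            "request_id": "1",
        }

    async with grpcclient.InferenceServerClient(url="localhost:8001") as client:
        # stream_infer opens a streaming call and yields (result, error) pairs
        # as responses arrive from the server.
        async for result, error in client.stream_infer(inputs_iterator=request_iterator()):
            if error is not None:
                raise error
            print(result.as_numpy("text_output")[0].decode("utf-8"))


if __name__ == "__main__":
    asyncio.run(main())
```

Note that a streaming call is used even though this request produces a single response; as mentioned in the [limitations](#limitations), the model runs with Triton's decoupled transaction policy.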
+The complete example requires the +[client.py](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/client.py) +script and a set of +[prompts](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt), +which are provided in the +[samples](https://github.com/triton-inference-server/vllm_backend/tree/main/samples) +folder of the +[vllm_backend](https://github.com/triton-inference-server/vllm_backend/tree/main) +repository. + +Use the following commands to download `client.py` and `prompts.txt` to your +current directory: ``` -docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:23.08-py3-sdk bash +wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/client.py +wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/prompts.txt ``` -Within the container, run [`client.py`](client.py) with: +Now, we are ready to start Triton's SDK container: +``` +docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:<xx.yy>-py3-sdk bash +``` +Within the container, run +[`client.py`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/client.py) +with: ``` python3 client.py ``` -The client reads prompts from the [prompts.txt](prompts.txt) file, sends them to Triton server for +The client reads prompts from the +[prompts.txt](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt) +file, sends them to the Triton server for inference, and stores the results into a file named `results.txt` by default. The output of the client should look like below: @@ -128,15 +196,22 @@ Storing results into `results.txt`... PASS: vLLM example ``` -You can inspect the contents of the `results.txt` for the response from the server. The `--iterations` -flag can be used with the client to increase the load on the server by looping through the list of -provided prompts in [`prompts.txt`](prompts.txt). +You can inspect the contents of `results.txt` for the responses +from the server. The `--iterations` flag can be used with the client +to increase the load on the server by looping through the list of +provided prompts in +[prompts.txt](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt). -When you run the client in verbose mode with the `--verbose` flag, the client will print more details -about the request/response transactions. +When you run the client in verbose mode with the `--verbose` flag, +the client will print more details about the request/response transactions. ## Limitations - We use decoupled streaming protocol even if there is exactly 1 response for each request. - The asyncio implementation is exposed to model.py. - Does not support providing specific subset of GPUs to be used. +- If you are running multiple instances of Triton server with +a Python-based vLLM backend, you need to specify a different +`shm-region-prefix-name` for each server. See +[here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server) +for more information. diff --git a/Quick_Deploy/vLLM/client.py b/Quick_Deploy/vLLM/client.py deleted file mode 100644 index db1aa2db..00000000 --- a/Quick_Deploy/vLLM/client.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import argparse -import asyncio -import json -import queue -import sys -from os import system - -import numpy as np -import tritonclient.grpc.aio as grpcclient -from tritonclient.utils import * - - -def create_request( - prompt, - stream, - request_id, - sampling_parameters, - model_name, - send_parameters_as_tensor=True, -): - inputs = [] - prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_) - try: - inputs.append(grpcclient.InferInput("PROMPT", [1], "BYTES")) - inputs[-1].set_data_from_numpy(prompt_data) - except Exception as e: - print(f"Encountered an error {e}") - - stream_data = np.array([stream], dtype=bool) - inputs.append(grpcclient.InferInput("STREAM", [1], "BOOL")) - inputs[-1].set_data_from_numpy(stream_data) - - # Request parameters are not yet supported via BLS. Provide an - # optional mechanism to send serialized parameters as an input - # tensor until support is added - - if send_parameters_as_tensor: - sampling_parameters_data = np.array( - [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_ - ) - inputs.append(grpcclient.InferInput("SAMPLING_PARAMETERS", [1], "BYTES")) - inputs[-1].set_data_from_numpy(sampling_parameters_data) - - # Add requested outputs - outputs = [] - outputs.append(grpcclient.InferRequestedOutput("TEXT")) - - # Issue the asynchronous sequence inference. 
- return { - "model_name": model_name, - "inputs": inputs, - "outputs": outputs, - "request_id": str(request_id), - "parameters": sampling_parameters, - } - - -async def main(FLAGS): - model_name = "vllm" - sampling_parameters = {"temperature": "0.1", "top_p": "0.95"} - stream = FLAGS.streaming_mode - with open(FLAGS.input_prompts, "r") as file: - print(f"Loading inputs from `{FLAGS.input_prompts}`...") - prompts = file.readlines() - - results_dict = {} - - async with grpcclient.InferenceServerClient( - url=FLAGS.url, verbose=FLAGS.verbose - ) as triton_client: - # Request iterator that yields the next request - async def async_request_iterator(): - try: - for iter in range(FLAGS.iterations): - for i, prompt in enumerate(prompts): - prompt_id = FLAGS.offset + (len(prompts) * iter) + i - results_dict[str(prompt_id)] = [] - yield create_request( - prompt, stream, prompt_id, sampling_parameters, model_name - ) - except Exception as error: - print(f"caught error in request iterator: {error}") - - try: - # Start streaming - response_iterator = triton_client.stream_infer( - inputs_iterator=async_request_iterator(), - stream_timeout=FLAGS.stream_timeout, - ) - # Read response from the stream - async for response in response_iterator: - result, error = response - if error: - print(f"Encountered error while processing: {error}") - else: - output = result.as_numpy("TEXT") - for i in output: - results_dict[result.get_response().id].append(i) - - except InferenceServerException as error: - print(error) - sys.exit(1) - - with open(FLAGS.results_file, "w") as file: - for id in results_dict.keys(): - for result in results_dict[id]: - file.write(result.decode("utf-8")) - file.write("\n") - file.write("\n=========\n\n") - print(f"Storing results into `{FLAGS.results_file}`...") - - if FLAGS.verbose: - print(f"\nContents of `{FLAGS.results_file}` ===>") - system(f"cat {FLAGS.results_file}") - - print("PASS: vLLM example") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "-v", - "--verbose", - action="store_true", - required=False, - default=False, - help="Enable verbose output", - ) - parser.add_argument( - "-u", - "--url", - type=str, - required=False, - default="localhost:8001", - help="Inference server URL and it gRPC port. Default is localhost:8001.", - ) - parser.add_argument( - "-t", - "--stream-timeout", - type=float, - required=False, - default=None, - help="Stream timeout in seconds. Default is None.", - ) - parser.add_argument( - "--offset", - type=int, - required=False, - default=0, - help="Add offset to request IDs used", - ) - parser.add_argument( - "--input-prompts", - type=str, - required=False, - default="prompts.txt", - help="Text file with input prompts", - ) - parser.add_argument( - "--results-file", - type=str, - required=False, - default="results.txt", - help="The file with output results", - ) - parser.add_argument( - "--iterations", - type=int, - required=False, - default=1, - help="Number of iterations through the prompts file", - ) - parser.add_argument( - "-s", - "--streaming-mode", - action="store_true", - required=False, - default=False, - help="Enable streaming mode", - ) - FLAGS = parser.parse_args() - asyncio.run(main(FLAGS)) diff --git a/Quick_Deploy/vLLM/model_repository/vllm/1/model.py b/Quick_Deploy/vLLM/model_repository/vllm/1/model.py deleted file mode 100644 index d70cad57..00000000 --- a/Quick_Deploy/vLLM/model_repository/vllm/1/model.py +++ /dev/null @@ -1,249 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import asyncio -import json -import os -import threading -from typing import AsyncGenerator - -import numpy as np -import triton_python_backend_utils as pb_utils -from vllm import SamplingParams -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.utils import random_uuid - -_VLLM_ENGINE_ARGS_FILENAME = "vllm_engine_args.json" - - -class TritonPythonModel: - def initialize(self, args): - self.logger = pb_utils.Logger - self.model_config = json.loads(args["model_config"]) - - # assert are in decoupled mode. Currently, Triton needs to use - # decoupled policy for asynchronously forwarding requests to - # vLLM engine. - self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy( - self.model_config - ) - assert ( - self.using_decoupled - ), "vLLM Triton backend must be configured to use decoupled model transaction policy" - - engine_args_filepath = os.path.join( - args["model_repository"], _VLLM_ENGINE_ARGS_FILENAME - ) - assert os.path.isfile( - engine_args_filepath - ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{args['model_repository']}'" - with open(engine_args_filepath) as file: - vllm_engine_config = json.load(file) - - # Create an AsyncLLMEngine from the config from JSON - self.llm_engine = AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(**vllm_engine_config) - ) - - output_config = pb_utils.get_output_config_by_name(self.model_config, "TEXT") - self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) - - # Counter to keep track of ongoing request counts - self.ongoing_request_count = 0 - - # Starting asyncio event loop to process the received requests asynchronously. 
- self._loop = asyncio.get_event_loop() - self._loop_thread = threading.Thread( - target=self.engine_loop, args=(self._loop,) - ) - self._shutdown_event = asyncio.Event() - self._loop_thread.start() - - def create_task(self, coro): - """ - Creates a task on the engine's event loop which is running on a separate thread. - """ - assert ( - self._shutdown_event.is_set() is False - ), "Cannot create tasks after shutdown has been requested" - - return asyncio.run_coroutine_threadsafe(coro, self._loop) - - def engine_loop(self, loop): - """ - Runs the engine's event loop on a separate thread. - """ - asyncio.set_event_loop(loop) - self._loop.run_until_complete(self.await_shutdown()) - - async def await_shutdown(self): - """ - Primary coroutine running on the engine event loop. This coroutine is responsible for - keeping the engine alive until a shutdown is requested. - """ - # first await the shutdown signal - while self._shutdown_event.is_set() is False: - await asyncio.sleep(5) - - # Wait for the ongoing_requests - while self.ongoing_request_count > 0: - self.logger.log_info( - "Awaiting remaining {} requests".format(self.ongoing_request_count) - ) - await asyncio.sleep(5) - - self.logger.log_info("Shutdown complete") - - def get_sampling_params_dict(self, params_json): - """ - This functions parses the dictionary values into their - expected format. - """ - - params_dict = json.loads(params_json) - - # Special parsing for the supported sampling parameters - bool_keys = ["ignore_eos", "skip_special_tokens", "use_beam_search"] - for k in bool_keys: - if k in params_dict: - params_dict[k] = bool(params_dict[k]) - - float_keys = [ - "frequency_penalty", - "length_penalty", - "presence_penalty", - "temperature", - "top_p", - ] - for k in float_keys: - if k in params_dict: - params_dict[k] = float(params_dict[k]) - - int_keys = ["best_of", "max_tokens", "n", "top_k"] - for k in int_keys: - if k in params_dict: - params_dict[k] = int(params_dict[k]) - - return params_dict - - def create_response(self, vllm_output): - """ - Parses the output from the vLLM engine into Triton - response. - """ - prompt = vllm_output.prompt - text_outputs = [ - (prompt + output.text).encode("utf-8") for output in vllm_output.outputs - ] - triton_output_tensor = pb_utils.Tensor( - "TEXT", np.asarray(text_outputs, dtype=self.output_dtype) - ) - return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) - - async def generate(self, request): - """ - Forwards single request to LLM engine and returns responses. - """ - response_sender = request.get_response_sender() - self.ongoing_request_count += 1 - try: - request_id = random_uuid() - - prompt = pb_utils.get_input_tensor_by_name(request, "PROMPT").as_numpy()[0] - if isinstance(prompt, bytes): - prompt = prompt.decode("utf-8") - - # stream is an optional input - stream = False - stream_input_tensor = pb_utils.get_input_tensor_by_name(request, "STREAM") - if stream_input_tensor: - stream = stream_input_tensor.as_numpy()[0] - - # Request parameters are not yet supported via - # BLS. 
Provide an optional mechanism to receive serialized - # parameters as an input tensor until support is added - parameters_input_tensor = pb_utils.get_input_tensor_by_name( - request, "SAMPLING_PARAMETERS" - ) - if parameters_input_tensor: - parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") - else: - parameters = request.parameters() - - sampling_params_dict = self.get_sampling_params_dict(parameters) - sampling_params = SamplingParams(**sampling_params_dict) - - last_output = None - async for output in self.llm_engine.generate( - prompt, sampling_params, request_id - ): - if stream: - response_sender.send(self.create_response(output)) - else: - last_output = output - - if not stream: - response_sender.send(self.create_response(last_output)) - - except Exception as e: - self.logger.log_info(f"Error generating stream: {e}") - error = pb_utils.TritonError(f"Error generating stream: {e}") - triton_output_tensor = pb_utils.Tensor( - "TEXT", np.asarray(["N/A"], dtype=self.output_dtype) - ) - response = pb_utils.InferenceResponse( - output_tensors=[triton_output_tensor], error=error - ) - response_sender.send(response) - raise e - finally: - response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) - self.ongoing_request_count -= 1 - - def execute(self, requests): - """ - Triton core issues requests to the backend via this method. - - When this method returns, new requests can be issued to the backend. Blocking - this function would prevent the backend from pulling additional requests from - Triton into the vLLM engine. This can be done if the kv cache within vLLM engine - is too loaded. - We are pushing all the requests on vllm and let it handle the full traffic. - """ - for request in requests: - self.create_task(self.generate(request)) - return None - - def finalize(self): - """ - Triton virtual method; called when the model is unloaded. - """ - self.logger.log_info("Issuing finalize to vllm backend") - self._shutdown_event.set() - if self._loop_thread is not None: - self._loop_thread.join() - self._loop_thread = None diff --git a/Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt b/Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt deleted file mode 100644 index 243491a6..00000000 --- a/Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -name: "vllm" -backend: "python" - -# Disabling batching in Triton, let vLLM handle the batching on its own. -max_batch_size: 0 - -# We need to use decoupled transaction policy for saturating -# vLLM engine for max throughtput. -# TODO [DLIS:5233]: Allow asynchronous execution to lift this -# restriction for cases there is exactly a single response to -# a single request. -model_transaction_policy { - decoupled: True -} - -input [ - { - name: "PROMPT" - data_type: TYPE_STRING - dims: [ 1 ] - }, - { - name: "STREAM" - data_type: TYPE_BOOL - dims: [ 1 ] - optional: true - }, - { - name: "SAMPLING_PARAMETERS" - data_type: TYPE_STRING - dims: [ 1 ] - optional: true - } -] - -output [ - { - name: "TEXT" - data_type: TYPE_STRING - dims: [ -1 ] - } -] - -# The usage of device is deferred to the vLLM engine -instance_group [ - { - count: 1 - kind: KIND_MODEL - } -] diff --git a/Quick_Deploy/vLLM/model_repository/vllm/vllm_engine_args.json b/Quick_Deploy/vLLM/model_repository/vllm/vllm_engine_args.json deleted file mode 100644 index e610c3cb..00000000 --- a/Quick_Deploy/vLLM/model_repository/vllm/vllm_engine_args.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "model":"facebook/opt-125m", - "disable_log_requests": "true", - "gpu_memory_utilization": 0.5 -} diff --git a/Quick_Deploy/vLLM/prompts.txt b/Quick_Deploy/vLLM/prompts.txt deleted file mode 100644 index 133800ec..00000000 --- a/Quick_Deploy/vLLM/prompts.txt +++ /dev/null @@ -1,4 +0,0 @@ -Hello, my name is -The most dangerous animal is -The capital of France is -The future of AI is