diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b56b7199..4f9c483b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -42,7 +42,7 @@ repos:
rev: 4.0.1
hooks:
- id: flake8
- args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401,mii/grpc_related/proto/modelresponse_pb2.py:F821,F401']
+ args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401,mii/grpc_related/proto/modelresponse_pb2.py:F821,F401,mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py:F821,F401']
- repo: local
hooks:
diff --git a/README.md b/README.md
index 93c2f50f..a388c989 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@
## Latest News
+* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen)
* [2022/11] [Stable Diffusion Image Generation under 1 second w. DeepSpeed MII](examples/benchmark/txt2img)
* [2022/10] [Announcing DeepSpeed Model Implementations for Inference (MII)](https://www.deepspeed.ai/2022/10/10/mii.html)
@@ -18,323 +19,219 @@
-- [DeepSpeed MII](#deepspeed-model-implementations-for-inference)
+- [DeepSpeed-MII](#deepspeed-model-implementations-for-inference)
+- [Key Technologies](#key-technologies)
- [How does MII work?](#how-does-mii-work)
-- [Supported Models and Tasks](#supported-models-and-tasks)
-- [MII-Public and MII-Azure](#mii-public-and-mii-azure)
-- [Getting started with MII](#getting-started-with-mii)
-- [Quantifying Latency and Cost Reduction](#quantifying-latency-and-cost-reduction)
-- [Community Tutorials](#community-tutorials)
+- [Supported Models](#supported-models)
+- [Getting Started](#getting-started-with-mii)
-# DeepSpeed Model Implementations for Inference
+# DeepSpeed Model Implementations for Inference (MII)
-![hero dark](docs/images/hero-dark.png#gh-dark-mode-only)
-![hero light](docs/images/hero-transparent.png#gh-light-mode-only)
+Introducing MII, an open-source Python library designed by DeepSpeed to democratize powerful model inference with a focus on high throughput, low latency, and cost-effectiveness.
-The Deep Learning (DL) open-source community has seen tremendous growth in the last few months. Incredibly powerful text generation models such as the Bloom 176B, or image generation model such as Stable Diffusion are now available to anyone with access to a handful or even a single GPU through platforms such as Hugging Face. While open sourcing has democratized access to AI capabilities, their application is still restricted by two critical factors: inference latency and cost.
+* MII v0.1 introduces several features such as blocked KV-caching, continuous batching, Dynamic SplitFuse, tensor parallelism, and high-performance CUDA kernels to support fast, high-throughput text generation for LLMs such as Llama-2-70B. MII delivers up to 2.3 times higher effective throughput compared to leading systems such as vLLM. For detailed performance results, please see our [DeepSpeed-FastGen blog](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen).
-There has been significant progress in system optimizations for DL model inference that can drastically reduce both latency and cost, but those are not easily accessible. A main reason for this limited accessibility is that the DL model inference landscape is diverse with models varying in size, architecture, system performance characteristics, hardware requirements, etc. Identifying the appropriate set of system optimizations applicable to a given model and applying them correctly is often beyond the scope of most data scientists, making low latency and low-cost inference mostly inaccessible.
+
+
+
+
+
+* We first [announced MII](https://www.deepspeed.ai/2022/10/10/mii.html) in 2022; that legacy version of MII covers all releases up to v0.0.9. In addition to language models, we also support accelerating [text2image models like Stable Diffusion](examples/benchmark/txt2img). For more details on our previous releases, please see our [legacy APIs](mii/legacy/).
+
+# Key Technologies
+
+## MII for High-Throughput Text Generation
+
+MII provides accelerated text-generation inference through the use of four key technologies:
+
+* Blocked KV Caching
+* Continuous Batching
+* Dynamic SplitFuse
+* High Performance CUDA Kernels
+
+For a deeper dive into these features, please [refer to our blog](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen), which also includes a detailed performance analysis.
+
+## MII Legacy
-DeepSpeed-MII is a new open-source python library from DeepSpeed, aimed towards making low-latency, low-cost inference of powerful models not only feasible but also easily accessible.
+In the past, MII introduced several [key performance optimizations](https://www.deepspeed.ai/2022/10/10/mii.html#inference-optimizations-with-mii) for low-latency serving scenarios:
+
+* DeepFusion for Transformers
+* Multi-GPU Inference with Tensor-Slicing
+* ZeRO-Inference for Resource Constrained Systems
+* Compiler Optimizations
-* MII offers access to highly optimized implementation of thousands of widely used DL models.
-* MII supported models achieve significantly lower latency and cost compared to their original implementation. For example, MII reduces the latency of Big-Science Bloom 176B model by 5.7x, while reducing the cost by over 40x. Similarly, it reduces the latency and cost of deploying Stable Diffusion by 1.9x. See more details for [an exhaustive latency and cost analysis of MII](#quantifying-latency-and-cost-reduction).
-* To enable low latency/cost inference, MII leverages an extensive set of optimizations from DeepSpeed-Inference such as deepfusion for transformers, automated tensor-slicing for multi-GPU inference, on-the-fly quantization with ZeroQuant, and several others (see our [blog post](https://www.deepspeed.ai/2022/10/10/mii.html) for more details).
-* With state-of-the-art performance, MII supports low-cost deployment of these models both on-premises and on Azure via AML with just a few lines of codes.
# How does MII work?
-![Text Generation Models](docs/images/mii-arch.png)
+
+
+
+
-*Figure 1: MII Architecture, showing how MII automatically optimizes OSS models using DS-Inference before deploying them on-premises using GRPC, or on Microsoft Azure using AML Inference.*
-Under-the-hood MII is powered by [DeepSpeed-Inference](https://arxiv.org/abs/2207.00032). Based on model type, model size, batch size, and available hardware resources, MII automatically applies the appropriate set of system optimizations from DeepSpeed-Inference to minimize latency and maximize throughput. It does so by using one of many pre-specified model injection policies, that allows MII and DeepSpeed-Inference to identify the underlying PyTorch model architecture and replace it with an optimized implementation (see *Figure A*). In doing so, MII makes the expansive set of optimizations in DeepSpeed-Inference automatically available for thousands of popular models that it supports.
+Figure 1: MII architecture, showing how MII automatically optimizes OSS models using DS-Inference before deploying them. The DeepSpeed-FastGen optimizations shown in the figure are described in [our blog post](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen).
+Under the hood, MII is powered by [DeepSpeed-Inference](https://github.com/microsoft/deepspeed). Based on the model architecture, model size, batch size, and available hardware resources, MII automatically applies the appropriate set of system optimizations to minimize latency and maximize throughput.
-# Supported Models and Tasks
-MII currently supports over 50,000 models across a range of tasks such as text-generation, question-answering, text-classification. The models accelerated by MII are available through multiple open-sourced model repositories such as Hugging Face, FairSeq, EluetherAI, etc. We support dense models based on Bert, Roberta or GPT architectures ranging from few hundred million parameters to tens of billions of parameters in size. We continue to expand the list with support for massive hundred billion plus parameter dense and sparse models coming soon.
+# Supported Models
-MII model support will continue to grow over time, check back for updates! Currently we support the following Hugging Face Transformers model families:
+MII currently supports over 13,000 models across three popular model architectures. We plan to add support for additional models in the near term; if there are specific model architectures you would like supported, please [file an issue](https://github.com/microsoft/DeepSpeed-MII/issues) and let us know. All current models leverage Hugging Face in our backend to provide both the model weights and the model's corresponding tokenizer. For our current release, we support the following model architectures:
model family | size range | ~model count
------ | ------ | ------
-[llama](https://huggingface.co/models?other=llama) | 7B - 65B | 1,500
-[bloom](https://huggingface.co/models?other=bloom) | 0.3B - 176B | 480
-[stable-diffusion](https://huggingface.co/models?other=stable-diffusion) | 1.1B | 3,700
-[opt](https://huggingface.co/models?other=opt) | 0.1B - 66B | 460
-[gpt\_neox](https://huggingface.co/models?other=gpt_neox) | 1.3B - 20B | 850
-[gptj](https://huggingface.co/models?other=gptj) | 1.4B - 6B | 420
-[gpt\_neo](https://huggingface.co/models?other=gpt_neo) | 0.1B - 2.7B | 700
-[gpt2](https://huggingface.co/models?other=gpt2) | 0.3B - 1.5B | 11,900
-[xlm-roberta](https://huggingface.co/models?other=xlm-roberta) | 0.1B - 0.3B | 4,100
-[roberta](https://huggingface.co/models?other=roberta) | 0.1B - 0.3B | 8,700
-[distilbert](https://huggingface.co/models?other=distilbert) | 0.1B - 0.3B | 4,700
-[bert](https://huggingface.co/models?other=bert) | 0.1B - 0.3B | 23,600
-
-
+[llama](https://huggingface.co/models?other=llama) | 7B - 65B | 11,000
+[llama-2](https://huggingface.co/models?other=llama-2) | 7B - 70B | 800
+[mistral](https://huggingface.co/models?other=mistral) | 7B | 1,100
+[opt](https://huggingface.co/models?other=opt) | 0.1B - 66B | 900
-
+## MII Legacy Model Support
-# MII-Public and MII-Azure
+MII Legacy APIs support over 50,000 different models, including BERT, RoBERTa, and Stable Diffusion, as well as text-generation models like Bloom, GPT-J, etc. For a full list, please see our [legacy supported models table](mii/legacy/#supported-models-and-tasks).
-MII can work with two variations of DeepSpeed-Inference. The first, referred to as ds-public, contains most of the DeepSpeed-Inference optimizations discussed here, is also available via our open-source DeepSpeed library. The second referred to as ds-azure, offers tighter integration with Azure, and is available via MII to all Microsoft Azure customers. We refer to MII running the two DeepSpeed-Inference variants as MII-Public and MII-Azure, respectively.
+# Getting Started with MII
-While both variants offers significant latency and cost reduction over the open-sourced PyTorch baseline, the latter, offers additional performance advantage for generation based workloads. The full latency and cost advantage comparison with PyTorch baseline and across these two versions is available [here](#quantifying-latency-and-cost-reduction).
+DeepSpeed-MII allows users to create non-persistent and persistent deployments for supported models in just a few lines of code.
-# Getting Started with MII
+- [Installation](#installation)
+- [Non-Persistent Pipeline](#non-persistent-pipeline)
+- [Persistent Deployment](#persistent-deployment)
## Installation
-We regularly push releases to [PyPI](https://pypi.org/project/deepspeed-mii/) and encourage users to install from there in most cases.
+The fastest way to get started is with our [PyPI release of DeepSpeed-MII](https://pypi.org/project/deepspeed-mii/), which means you can be up and running within minutes via:
```bash
pip install deepspeed-mii
```
-## Deploying MII-Public
+To ease installation and avoid the lengthy compile times that many projects in this space require, we distribute a pre-compiled Python wheel covering the majority of our custom kernels through a library called [DeepSpeed-Kernels](https://github.com/microsoft/DeepSpeed-Kernels). We have found this library to be very portable across environments with NVIDIA GPUs of compute capability 8.0+ (Ampere+), CUDA 11.6+, and Ubuntu 20+. In most cases you shouldn't even need to know this library exists, as it is a dependency of DeepSpeed-MII and will be installed with it. However, if for whatever reason you need to compile our kernels manually, please see our [advanced installation docs](https://github.com/microsoft/DeepSpeed-Kernels#source).
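+
+If you do need to build the kernels yourself, a minimal sketch of a source install (hypothetical; the linked advanced installation docs are the authoritative reference) might look like:
+
+```bash
+# Hypothetical manual build of DeepSpeed-Kernels; check the advanced installation docs for exact steps and flags
+git clone https://github.com/microsoft/DeepSpeed-Kernels
+cd DeepSpeed-Kernels
+pip install .
+```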
-MII-Public can be deployed on-premises or on any cloud offering with just a few lines of code. MII creates a lightweight GRPC server to support this form of deployment and provides a GRPC inference endpoint for queries.
+## Non-Persistent Pipeline
-Several deployment and query examples can be found here: [examples/local](https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/local)
+A non-persistent pipeline is a great way to try DeepSpeed-MII. Non-persistent pipelines exist only for the duration of the Python script you are running. The full example for running a non-persistent pipeline deployment is only 4 lines. Give it a try!
-As an example here is a deployment of the [bigscience/bloom-560m](https://huggingface.co/bigscience/bloom-560m) model from Hugging Face:
-
-**Deployment**
```python
import mii
-mii_configs = {"tensor_parallel": 1, "dtype": "fp16"}
-mii.deploy(task="text-generation",
- model="bigscience/bloom-560m",
- deployment_name="bloom560m_deployment",
- mii_config=mii_configs)
+pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")
+response = pipe("DeepSpeed is", max_new_tokens=128)
+print(response)
```
-This will deploy the model onto a single GPU and start the GRPC server that can later be queried.
+### Tensor Parallelism
-**Query**
-```python
-import mii
-generator = mii.mii_query_handle("bloom560m_deployment")
-result = generator.query({"query": ["DeepSpeed is", "Seattle is"]}, do_sample=True, max_new_tokens=30)
-print(result)
-```
+Taking advantage of multi-GPU systems for greater performance is easy with MII. When run with the `deepspeed` launcher, tensor parallelism is automatically controlled by the `--num_gpus` flag:
-The only required key is `"query"`, all other items outside the dictionary will be passed to `generate` as kwargs. For Hugging Face provided models you can find all possible arguments in their [documentation for generate](https://huggingface.co/docs/transformers/v4.20.1/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate).
+```bash
+# Run on a single GPU
+deepspeed --num_gpus 1 mii-example.py
-**Shutdown Deployment**
-```python
-import mii
-mii.terminate("bloom560m_deployment")
+# Run on multiple GPUs
+deepspeed --num_gpus 2 mii-example.py
```
+### Pipeline Options
+While only the model name or path is required to stand up a non-persistent pipeline deployment, we offer customization options to our users:
-**Load balancing over multiple replicas**
-
-You can launch a load balancer and multiple replica of MII servers.
-When you specify a value for `replica_num`, `mii.deploy()` launches the load balancer server and `replica_num` number of replicas.
-Note that each replica consists of `tensor_parallel` server processes that are deployed on the same server.
+**`mii.pipeline()` Options**:
+- `model_name_or_path: str` Name or local path to a [HuggingFace](https://huggingface.co/) model.
+- `max_length: int` Sets the default maximum token length for the prompt + response.
+- `all_rank_output: bool` When enabled, all ranks return the generated text. By default, only rank 0 will return text.
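+
+As an illustrative sketch (the option values below are arbitrary), these options are passed directly to `mii.pipeline()`:
+
+```python
+import mii
+
+# Larger default prompt + response budget; only rank 0 returns the generated text
+pipe = mii.pipeline(
+    "mistralai/Mistral-7B-v0.1",
+    max_length=512,
+    all_rank_output=False,
+)
+```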
-```python
-mii_configs = {
-...
- "tensor_parallel": tensor_parallel,
- "replica_num": replica_num,
- "hostfile": hostfile
-}
-mii.deploy(...
- mii_config=mii_configs,
- ...)
-```
+Users can also control the generation characteristics for individual prompts (i.e., when calling `pipe()`) with the following options:
-The client sends requests to the load balancer, which forwards them to the replicas, instead of sending requests to individual MII servers.
-Currently, the load balancer implements a simple round-robin algorithm.
-The load balancer acts as a simple proxy when `replica_num` is set to `1`.
+- `max_length: int` Sets the per-prompt maximum token length for prompt + response.
+- `max_new_tokens: int` Sets the maximum number of tokens generated in the response.
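+
+For example (again with arbitrary values), the per-prompt options are passed when the pipeline is called:
+
+```python
+# Cap this particular response at 64 new tokens within a 256-token overall budget
+response = pipe("DeepSpeed is", max_length=256, max_new_tokens=64)
+print(response)
+```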
-`hostfile` is the path to hostfile used by DeepSpeed's launcher.
-When hostfile is not specified, DeepSpeed-MII uses the default path `/job/hostfile`, which is defined for DeepSpeed.
-See the [DeepSpeed's document](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) for the details.
+## Persistent Deployment
-**RESTful API support**
-
-MII can enable users to call the inference service through RESTful APIs.
-By setting `enable_restful_api` to `True`, `mii.deploy()` launches a gateway that accepts RESTful API.
-The gateway can receive requests at `http://[HOST]:[PORT_FOR_RESTFUL_API]/mii/[DEPLOYMENT_NAME]`.
+A persistent deployment is ideal for use with long-running and production applications. The persistent model uses a lightweight GRPC server that can be queried by multiple clients at once. The full example for running a persistent model is only 5 lines. Give it a try!
```python
-mii_configs = {
-...
- "enable_restful_api": True,
- "restful_api_port": PORT_FOR_RESTFUL_API,
-...
-}
-mii.deploy(...
- deployment_name=DEPLOYMENT_NAME,
- mii_config=mii_configs)
+import mii
+client = mii.serve("mistralai/Mistral-7B-v0.1")
+response = client.generate("Deepspeed is", max_new_tokens=128)
+print(response.response)
```
-**Non-persistent Deployment**
-
-You can enable a non-persistent deployment which allows you to make queries without standing up a server. The non-persistent deployment acts as a simplified interface to DeepSpeed-inference for use cases that do not require creating a persistent model server process. Changing the `deployment_type` to `NON_PERSISTENT` in `mii.deploy(...)` will activate this option.
+If we want to generate text from other processes, we can do that too:
```python
-...
-mii.deploy(deployment_name = DEPLOYMENT_NAME,
- deployment_type=mii.constants.DeploymentType.NON_PERSISTENT
- ...
- )
-
-generator = mii.mii_query_handle(DEPLOYMENT_NAME)
-result = generator.query({"query": ["DeepSpeed is", "Seattle is"]}, do_sample=True, max_new_tokens=30})
-
+client = mii.client("mistralai/Mistral-7B-v0.1")
+response = client.generate("Deepspeed is", max_new_tokens=128)
```
-You can find a complete example [here]("https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/non_persistent")
-
-Any HTTP client can be used to call the APIs. An example of using curl is:
-```bash
-# Assume deployment_name and restful_api_port are set to bloom560m_deployment and 28080 respectively:
-$ curl --header "Content-Type: application/json" --request POST -d '{"request": {"query": ["Seattle is", "Bellevue is", "Redmond is"]}, "kwargs": {"do_sample": false, "max_new_tokens": 100}}' http://localhost:28080/mii/bloom560m_deployment
-```
-
-The code below is an example using Python.
+When we no longer need a persistent deployment, we can shutdown the server from any client:
```python
-import requests
-import json
-
-# text_generation
-url = 'http://localhost:28080/mii/bloom560m_deployment'
-params = {"request": {"query": ["Seattle is", "Bellevue is", "Redmond is"]},
- "kwargs": {"do_sample": False, "max_new_tokens": 100}}
-
-json_params = json.dumps(params)
-response = requests.post(url, data=json_params, headers={
- "Content-Type": "application/json"})
-print(response.json())
+client.terminate_server()
```
-## Deploying with MII-Azure
-
-MII supports deployment on Azure via AML Inference. To enable this, MII generates AML deployment assets for a given model that can be deployed using the Azure-CLI, as shown in the code below. Furthermore, deploying on Azure, allows MII to leverage DeepSpeed-Azure as its optimization backend, which offers better latency and cost reduction than DeepSpeed-Public.
-
-This deployment process is very similar to local deployments and we will modify the code from the local deployment example with the [bigscience/bloom-560m](https://huggingface.co/bigscience/bloom-560m) model.
-
----
-📌 **Note:** MII-Azure has the benefit of supporting DeepSpeed-Azure for better latency and cost than DeepSpeed-Public for certain workloads. We are working to enable DeepSpeed-Azure automatically for all MII-Azure deployments in a near-term MII update. In the meantime, we are offering DeepSpeed-Azure as a preview release to MII-Azure users. If you have a MII-Azure deployment and would like to try DeepSpeed-Azure, please reach out to us at deepspeed-mii@microsoft.com to get access.
-
----
-
-Several other AML deployment examples can be found here: [examples/aml](https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/aml)
+### Model Parallelism
-**Setup**
+Taking advantage of multi-GPU systems for better latency and throughput is also easy with persistent deployments. Model parallelism is controlled by the `tensor_parallel` input to `mii.serve`:
-To use MII on AML resources, you must have the Azure-CLI installed with an active login associated with your Azure resources. Follow the instructions below to get your local system ready for deploying on AML resources:
-
-1. Install Azure-CLI. Follow the official [installation instructions](https://learn.microsoft.com/en-us/cli/azure/install-azure-cli#install).
-2. Run `az login` and follow the instructions to login to your Azure account. This account should be linked to the resources you plan to deploy on.
-3. Set the default subscription with `az account set --subscription `. You can find your subscription ID in the "overview" tab on your resource group page from the Azure web portal.
-4. Set the default resource group and workspace name with `az config defaults.group defaults.workspace `
-5. Install the AML plugin for Azure-CLI with `az extension add --name ml`
-
-**Deployment**
```python
-import mii
-mii_configs = {"tensor_parallel": 1, "dtype": "fp16"}
-mii.deploy(task="text-generation",
- model="bigscience/bloom-560m",
- deployment_name="bloom560m-deployment",
- deployment_type=mii.constants.DeploymentType.AML,
- mii_config=mii_configs)
+client = mii.serve("mistralai/Mistral-7B-v0.1", tensor_parallel=2)
```
----
-📌 **Note:** Running the `mii.deploy` with `deployment_type=mii.constants.DeploymentType.AML` will only generate the scripts to launch an AML deployment. You must also run the generated `deploy.sh` script to run on AML resources.
+The resulting deployment will split the model across 2 GPUs to deliver faster inference and higher throughput than a single GPU.
----
+### Model Replicas
-This will generate the scripts and configuration files necessary to deploy the model on AML using a single GPU. You can find the generated output at `./bloom560m-deployment_aml/`
+We can also take advantage of multi-GPU (and multi-node) systems by setting up multiple model replicas and relying on the load balancing that DeepSpeed-MII provides:
-When you are ready to run your deployment on AML resources, navigate to the newly created directory and run the deployment script:
-```bash
-cd ./bloom560m-deployment_aml/
-bash deploy.sh
+```python
+client = mii.serve("mistralai/Mistral-7B-v0.1", replica_num=2)
```
-This script may take several minutes to run as it does the following:
-- Downloads the model locally
-- Creates a Docker Image with MII for your deployment
-- Creates an AML online-endpoint for running queries
-- Uploads and registers the model to AML
-- Starts your deployment
-
----
-📌 **Note:** Large models (e.g., `bigscience/bloom`) may cause a timeout when trying to upload and register the model to AML. In these cases, it is required to manually upload models to Azure blob storage with [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10). Instructions and automation of this step will be added soon.
+The resulting deployment will load 2 model replicas (one per GPU) and load-balance incoming requests between the 2 model instances.
----
+Model parallelism and replicas can also be combined to take advantage of systems with many more GPUs. In the example below, we run 2 model replicas, each split across 2 GPUs on a system with 4 GPUs:
-**Query**
-Once the deployment is running on AML, you can run queries by navigating to the online-endpoint that was created for this deployment (i.e., `bloom-560m-deployment-endpoint`) from the [AML web portal](https://ml.azure.com/endpoints). Select the "Test" tab at the top of the endpoint page and type your query into the text-box:
-```
-{"query": ["DeepSpeed is", "Seattle is"], "do_sample"=True, "max_new_tokens"=30}
+```python
+client = mii.serve("mistralai/Mistral-7B-v0.1", tensor_parallel=2, replica_num=2)
```
-The only required key is `"query"`, all other items in the dictionary will be passed to `generate` as kwargs. For Hugging Face provided models you can find all possible arguments in their [documentation for generate](https://huggingface.co/docs/transformers/v4.20.1/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate).
+The choice between model parallelism and model replicas for maximum performance will depend on the nature of the hardware, model, and workload. For example, with small models users may find that model replicas provide the lowest average latency for requests. Meanwhile, large models may achieve greater overall throughput when using only model parallelism.
-# Quantifying Latency and Cost Reduction
-
-Inference workloads can be either latency critical, where the primary objective is to minimize latency, or cost sensitive, where the primary objective is to minimize cost. In this section, we quantify the benefits of using MII for both latency-critical and cost-sensitive scenarios.
-
-## Latency Critical Scenarios
-
-For latency-critical scenarios, where a small batch size of 1 is often used, MII can reduce the latency by up to 6x for a wide range of open-source models, across multiple tasks. More specifically, we show model latency reduction of [^overhead_details]:
-
-1. Up to 5.7x for multi-GPU inference for text generation using massive models such as Big Science Bloom, Facebook OPT, and EluetherAI NeoX (*Figure 2 (left)*)
-
-2. Up to 1.9x for image generation tasks model using Stable Diffusion (*Figure 2 (right)*)
-
-3. Up to 3x for relatively smaller text generation models (up to 7B parameters) based on OPT, BLOOM, and GPT architectures, running on a single GPU (*Figures 3 and 4*)
-
-4. Up to 9x for various text representation tasks like fill-mask, text classification, question answering, and token classification using RoBERTa- and BERT- based models (*Figures 5 and 6*).
-
-[ ![multi gpu latency](/docs/images/llm-latency-sd-latency.png) ](/docs/images/llm-latency-sd-latency-zoom.png)
-*Figure 2: (Left) Best achievable latency for large models. MII-Azure (int8) offers 5.7X lower latency compared to Baseline for Bloom-176B. (Right) Stable Diffusion text to image generation latency comparison.*
-
-[ ![OPT and BLOOM Models](/docs/images/opt-bloom.png) ](/docs/images/opt-bloom.png)
-*Figure 3: Latency comparison for OPT and BLOOM models. MII-Azure is up to 2.8x faster than baseline.*
-
-[ ![GPT Models](/docs/images/gpt.png) ](/docs/images/mii/gpt.png)
-*Figure 4: Latency comparison for GPT models. MII-Azure is up to 3x faster than baseline.*
-
-[ ![Roberta Models](/docs/images/roberta.png) ](/docs/images/roberta.png)
-*Figure 5: Latency comparison for RoBERTa models. MII offers up to 9x lower model latency and up to 3x lower end-to-end latency than baseline on several tasks and RoBERTa variants [^overhead_details].*
+
-MII can significantly reduce the inference cost of very expensive language models like Bloom, OPT, etc. To get the lowest cost, we use a large batch size that maximizes throughput for both baseline and MII. Here we look at the cost reduction from MII using two different metrics: i) tokens generated per second per GPU, and ii) dollars per million tokens generated.
+### Persistent Deployment Options
+While only the model name or path is required to stand up a persistent deployment, we offer customization options to our users.
-*Figures 7 and 8* show that MII-Public offers over 10x throughput improvement and cost reduction compared to the baseline, respectively. Furthermore, MII-Azure offers over 30x improvement in throughput and cost compared to the baseline.
+**`mii.serve()` Options**:
+- `model_name_or_path: str` Name or local path to a [HuggingFace](https://huggingface.co/) model.
+- `max_length: int` Sets the default maximum token length for the prompt + response.
+- `deployment_name: str` A unique identifying string for the persistent model. If provided, client objects should be retrieved with `client = mii.client(deployment_name)`.
+- `tensor_parallel: int` Number of GPUs to split the model across.
+- `replica_num: int` The number of model replicas to stand up.
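+
+As a sketch of how these options fit together (the values and the deployment name below are illustrative), a server might be stood up as:
+
+```python
+import mii
+
+# Two replicas, each split across two GPUs, reachable under a custom deployment name
+client = mii.serve(
+    "mistralai/Mistral-7B-v0.1",
+    deployment_name="mistral-deployment",
+    max_length=4096,
+    tensor_parallel=2,
+    replica_num=2,
+)
+```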
-[ ![tput large models](/docs/images/tput-llms.png) ](/docs/images/tput-llms.png)
-*Figure 7: Throughput comparison per A100-80GB GPU for large models. MII-Public offers over 15x throughput improvement while MII-Azure offers over 40x throughput improvement.*
+**`mii.client()` Options**:
+- `model_or_deployment_name: str` Name of the model or `deployment_name` passed to `mii.serve()`
-[ ![azure cost](/docs/images/azure-cost.png) ](/docs/images/azure-cost.png)
-*Figure 8: Cost of generating 1 million tokens on Azure with different model types. MII-Azure reduces the cost of generation by over 40x.*
+Users can also control the generation characteristics for individual prompts (i.e., when calling `client.generate()`) with the following options:
-# Community Tutorials
+- `max_length: int` Sets the per-prompt maximum token length for prompt + response.
+- `max_new_tokens: int` Sets the maximum number of tokens generated in the response.
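+
+For example (illustrative values), these can be supplied on each call:
+
+```python
+# Limit this response to 64 new tokens within a 256-token prompt + response budget
+response = client.generate("DeepSpeed is", max_length=256, max_new_tokens=64)
+print(response.response)
+```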
-* [DeepSpeed Deep Dive — Model Implementations for Inference (MII) (Heiko Hotz)](https://towardsdatascience.com/deepspeed-deep-dive-model-implementations-for-inference-mii-b02aa5d5e7f7)
# Contributing
diff --git a/docs/images/fast-gen-overview.png b/docs/images/fast-gen-overview.png
new file mode 100644
index 00000000..47ec26c8
Binary files /dev/null and b/docs/images/fast-gen-overview.png differ
diff --git a/docs/images/fastgen-arch-dark.png b/docs/images/fastgen-arch-dark.png
new file mode 100755
index 00000000..4d3b09ce
Binary files /dev/null and b/docs/images/fastgen-arch-dark.png differ
diff --git a/docs/images/fastgen-arch-light.png b/docs/images/fastgen-arch-light.png
new file mode 100755
index 00000000..6dcc741f
Binary files /dev/null and b/docs/images/fastgen-arch-light.png differ
diff --git a/docs/images/fastgen-hero-dark.png b/docs/images/fastgen-hero-dark.png
new file mode 100755
index 00000000..6ac1a775
Binary files /dev/null and b/docs/images/fastgen-hero-dark.png differ
diff --git a/docs/images/fastgen-hero-light.png b/docs/images/fastgen-hero-light.png
new file mode 100755
index 00000000..af8f1def
Binary files /dev/null and b/docs/images/fastgen-hero-light.png differ
diff --git a/docs/images/fastgen-hero.png b/docs/images/fastgen-hero.png
new file mode 100644
index 00000000..33b10e7d
Binary files /dev/null and b/docs/images/fastgen-hero.png differ
diff --git a/docs/images/fastgen-overview-dark.png b/docs/images/fastgen-overview-dark.png
new file mode 100755
index 00000000..dde598a9
Binary files /dev/null and b/docs/images/fastgen-overview-dark.png differ
diff --git a/docs/images/fastgen-overview-light.png b/docs/images/fastgen-overview-light.png
new file mode 100755
index 00000000..bdb5f8df
Binary files /dev/null and b/docs/images/fastgen-overview-light.png differ
diff --git a/docs/images/mii-arch-dark.png b/docs/images/mii-arch-dark.png
new file mode 100755
index 00000000..9b90357a
Binary files /dev/null and b/docs/images/mii-arch-dark.png differ
diff --git a/docs/images/mii-arch-light.png b/docs/images/mii-arch-light.png
new file mode 100755
index 00000000..9e754abd
Binary files /dev/null and b/docs/images/mii-arch-light.png differ
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 00000000..4efb2155
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,2 @@
+# MII Examples
+Please see [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/mii) for a few examples on using MII.
diff --git a/mii/__init__.py b/mii/__init__.py
index 8409de97..86eb3792 100644
--- a/mii/__init__.py
+++ b/mii/__init__.py
@@ -2,19 +2,19 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
-import grpc
-from .server import MIIServer
-from .client import MIIClient, mii_query_handle
-from .deployment import deploy
-from .terminate import terminate
-from .constants import DeploymentType, TaskType
-from .aml_related.utils import aml_output_path
-from .config import MIIConfig, ModelConfig
-from .utils import get_supported_models
-from .grpc_related.proto import modelresponse_pb2_grpc
+try:
+ import grpc
+ from .pipeline import pipeline
+ from .server import serve
+ from .client import client
+except ImportError as e:
+ print("Warning: DeepSpeed-FastGen could not be imported:")
+ print(e)
+ pass
+
+from .legacy import MIIServer, MIIClient, mii_query_handle, deploy, terminate, DeploymentType, TaskType, aml_output_path, MIIConfig, ModelConfig, get_supported_models
__version__ = "0.0.0"
-non_persistent_models = {}
try:
from .version import __version__
except ImportError:
diff --git a/mii/batching/__init__.py b/mii/batching/__init__.py
new file mode 100644
index 00000000..1594ba77
--- /dev/null
+++ b/mii/batching/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .ragged_batching import MIIAsyncPipeline, MIIPipeline
diff --git a/mii/models/providers/__init__.py b/mii/batching/generation/__init__.py
similarity index 100%
rename from mii/models/providers/__init__.py
rename to mii/batching/generation/__init__.py
diff --git a/mii/batching/generation/logit_processors.py b/mii/batching/generation/logit_processors.py
new file mode 100644
index 00000000..219d4fb1
--- /dev/null
+++ b/mii/batching/generation/logit_processors.py
@@ -0,0 +1,111 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import abc
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+FLOAT_PAD = -float("inf")
+
+
+class BaseLogitProcessor(abc.ABC):
+ def __call__(self, logits: torch.Tensor) -> torch.Tensor:
+ return self.forward(logits)
+
+ @abc.abstractmethod
+ def forward(self, logits: torch.Tensor) -> torch.Tensor:
+ ...
+
+ def get_key(self) -> str:
+ return self.__class__.__name__
+
+
+class TopKLogitProcessor(BaseLogitProcessor):
+ def __init__(self, top_k: int) -> None:
+ self.top_k = top_k
+
+ def forward(self, logits: torch.Tensor) -> torch.Tensor:
+ # Remove all tokens with a probability less than the
+ # last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, self.top_k)[0][..., -1, None]
+ logits[indices_to_remove] = FLOAT_PAD
+ return logits
+
+ def get_key(self) -> str:
+ return super().get_key() + f"_top_k={self.top_k}"
+
+
+class TopPLogitProcessor(BaseLogitProcessor):
+ def __init__(self, top_p: float) -> None:
+ assert 0.0 <= top_p <= 1.0
+ self.top_p = top_p
+
+ def forward(self, logits: torch.Tensor) -> torch.Tensor:
+        # Sort logits in descending order and compute the cumulative probability mass
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+ # Remove tokens with cumulative probability above the threshold
+ sorted_indices_to_remove = cumulative_probs > self.top_p
+ # Shift the indices to the right to keep also the first token
+ # above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+ for i in range(sorted_indices.size(0)):
+ indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
+ logits[i][indices_to_remove] = FLOAT_PAD
+ return logits
+
+ def get_key(self) -> str:
+ return super().get_key() + f"_top_p={self.top_p}"
+
+
+class TemperatureLogitProcessor(BaseLogitProcessor):
+ def __init__(self, temperature: float) -> None:
+ self.temperature = temperature
+ assert self.temperature > 0.0
+
+ def forward(self, logits: torch.Tensor) -> torch.Tensor:
+ return logits / self.temperature
+
+ def get_key(self) -> str:
+ return super().get_key() + f"_temperature={self.temperature}"
+
+
+class PipelineLogitProcessor(BaseLogitProcessor):
+ def __init__(self, pipeline: List[BaseLogitProcessor]) -> None:
+ assert all(isinstance(step, BaseLogitProcessor) for step in pipeline)
+ self.pipeline = pipeline
+
+ def forward(self, logits: torch.Tensor) -> torch.Tensor:
+ for step in self.pipeline:
+ logits = step(logits)
+ return logits
+
+ def get_key(self) -> str:
+ return super().get_key(
+ ) + f"_{'_'.join(step.get_key() for step in self.pipeline)}"
+
+
+class NucleusSamplingLogitProcessor(BaseLogitProcessor):
+ def __init__(self,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None) -> None:
+ assert top_k is not None or top_p is not None
+ if top_k is None:
+ self._processor = TopPLogitProcessor(top_p)
+ elif top_p is None:
+ self._processor = TopKLogitProcessor(top_k)
+ else:
+ self._processor = PipelineLogitProcessor(
+ [TopKLogitProcessor(top_k),
+ TopPLogitProcessor(top_p)])
+
+ def forward(self, logits: torch.Tensor) -> torch.Tensor:
+ return self._processor(logits)
+
+ def get_key(self) -> str:
+ return super().get_key() + f"_{self._processor.get_key()}"
diff --git a/mii/batching/generation/samplers.py b/mii/batching/generation/samplers.py
new file mode 100644
index 00000000..d0126609
--- /dev/null
+++ b/mii/batching/generation/samplers.py
@@ -0,0 +1,57 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import abc
+from typing import Tuple
+
+import torch
+from torch.distributions import Categorical
+
+
+class BaseGenerationSampler(abc.ABC):
+ @abc.abstractmethod
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ ) -> Tuple[torch.LongTensor,
+ torch.Tensor]:
+ """
+ Given the logits, return the next token to add to the sequence, as well
+ as the log probability of the token
+
+ Args:
+ logits (torch.Tensor):
+ The logits from the model. Shape: (batch_size, vocab_size)
+
+ Returns:
+ Tuple[torch.LongTensor, torch.Tensor]:
+ The next token to add to the sequence, and the log probability
+ of the token. Shape: (batch_size,) and (batch_size,)
+ """
+ ...
+
+ def get_key(self) -> str:
+ return self.__class__.__name__
+
+
+class LogitsSampler(BaseGenerationSampler):
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ ) -> Tuple[torch.LongTensor,
+ torch.Tensor]:
+ logits = logits.float()
+ sampler = Categorical(logits=logits)
+ next_tokens = sampler.sample()
+ logprobs = sampler.log_prob(next_tokens)
+ return next_tokens, logprobs
+
+
+class GreedySampler(BaseGenerationSampler):
+ def __call__(self, logits: torch.Tensor) -> Tuple[torch.LongTensor, torch.Tensor]:
+ logits = logits.float()
+ sampler = Categorical(logits=logits)
+ next_tokens = logits.argmax(dim=-1)
+ logprobs = sampler.log_prob(next_tokens)
+ return next_tokens, logprobs
diff --git a/mii/batching/generation/stop_criterion.py b/mii/batching/generation/stop_criterion.py
new file mode 100644
index 00000000..7ea83608
--- /dev/null
+++ b/mii/batching/generation/stop_criterion.py
@@ -0,0 +1,97 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import abc
+from typing import List, Union
+
+import torch
+
+# from megatron import get_tokenizer
+# from megatron.tokenizer.tokenizer import AbstractTokenizer
+
+
+class BaseGenerationStopCriterion(abc.ABC):
+ def __init__(self, tokenizer):
+ self.tokenizer = tokenizer
+
+ def __call__(self, tokens: torch.LongTensor) -> torch.BoolTensor:
+ return self.forward(tokens)
+
+ @abc.abstractmethod
+ def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
+ ...
+
+ def get_key(self) -> str:
+ return self.__class__.__name__
+
+
+class TokenStopCriterion(BaseGenerationStopCriterion):
+ def __init__(self, token: Union[str, int], tokenizer) -> None:
+ super().__init__(tokenizer=tokenizer)
+ if isinstance(token, str):
+ token_id = self.tokenizer.tokenize(token)[0]
+ else:
+ token_id = token
+ self.stop_token_id = token_id
+
+ def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
+ retval = torch.zeros_like(tokens, dtype=torch.bool)
+ retval |= tokens == self.stop_token_id
+ return retval
+
+ def get_key(self) -> str:
+ return self.__class__.__name__ + f"_token_id={self.stop_token_id}"
+
+
+class EosGenerationStopCriterion(BaseGenerationStopCriterion):
+ def __init__(self, tokenizer):
+ super().__init__(tokenizer=tokenizer)
+ if hasattr(self.tokenizer, "eod"):
+ self.eos_id = self.tokenizer.eod
+ elif hasattr(self.tokenizer, "eos_token_id"):
+ self.eos_id = self.tokenizer.eos_token_id
+ elif hasattr(self.tokenizer, "eos_token"):
+ self.eos_id = self.tokenizer.eos_token
+ else:
+ raise ValueError(
+                "Tokenizer must have an `eod`, `eos_token_id`, or `eos_token` attribute.")
+
+ def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
+ return tokens == self.eos_id
+
+
+class NewLineDelimitedStopCriterion(BaseGenerationStopCriterion):
+ def __init__(self, tokenizer):
+ super().__init__(tokenizer=tokenizer)
+ self.stop_token_ids = list(
+ set([self.tokenizer.tokenize(x)[0] for x in ["\n",
+ "\r\n",
+ "\n\n",
+ ".\n\n"]]))
+
+ def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
+ retval = torch.zeros_like(tokens, dtype=torch.bool)
+ for stop_token_id in self.stop_token_ids:
+ retval |= tokens == stop_token_id
+ return retval
+
+
+class PipelinedCriterion(BaseGenerationStopCriterion):
+ def __init__(
+ self,
+ criteria: List[BaseGenerationStopCriterion],
+ tokenizer,
+ ):
+ super().__init__(tokenizer=tokenizer)
+ self.criteria = criteria
+
+ def forward(self, tokens: torch.LongTensor) -> torch.BoolTensor:
+ retval = torch.zeros_like(tokens, dtype=torch.bool)
+ for criterion in self.criteria:
+ retval |= criterion(tokens)
+ return retval
+
+ def get_key(self) -> str:
+ return super().get_key(
+ ) + f"_{'_'.join(criterion.get_key() for criterion in self.criteria)}"
diff --git a/mii/batching/postprocess.py b/mii/batching/postprocess.py
new file mode 100644
index 00000000..cbf491ec
--- /dev/null
+++ b/mii/batching/postprocess.py
@@ -0,0 +1,99 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import itertools
+from collections import defaultdict
+from typing import Any, Dict
+
+import torch
+
+from .generation.logit_processors import (
+ TopKLogitProcessor,
+ TopPLogitProcessor,
+ TemperatureLogitProcessor,
+ NucleusSamplingLogitProcessor,
+)
+from .generation.samplers import LogitsSampler, GreedySampler
+from .generation.stop_criterion import (
+ EosGenerationStopCriterion,
+ NewLineDelimitedStopCriterion,
+)
+
+LOGITS_PROCESSORS = {
+ "TopK": TopKLogitProcessor,
+ "TopP": TopPLogitProcessor,
+ "Temperature": TemperatureLogitProcessor,
+ "NucleusSampling": NucleusSamplingLogitProcessor,
+}
+
+SAMPLERS = {"Logits": LogitsSampler, "Greedy": GreedySampler}
+
+STOP_CRITERIA = {
+ "EosGeneration": EosGenerationStopCriterion,
+ "NewLineDelimited": NewLineDelimitedStopCriterion,
+}
+
+DEFAULT_LOGITS_PROCESSOR = {"name": "TopP", "args": {"top_p": 0.9}}
+DEFAULT_SAMPLER = {"name": "Logits"}
+DEFAULT_STOP_CRITERION = {"name": "EosGeneration"}
+
+
+def _create_postprocessor(config: Dict[str,
+ Any],
+ classes: Dict[str,
+ Any],
+ default_args: Dict[str,
+ Any] = {}):
+ assert "name" in config
+
+ name = config["name"]
+ if name not in classes:
+ raise ValueError(f"Unknown postprocessor {name}")
+ args = config["args"] if "args" in config else {}
+ args.update(default_args)
+ return classes[name](**args)
+
+
+def _run_batch_postprocess(input_tensor,
+ requests,
+ get_processor_fn,
+ get_result_fn=lambda x: x):
+ processor_map = {
+ get_processor_fn(r).get_key(): get_processor_fn(r)
+ for r in requests
+ }
+ processor_indices = defaultdict(list)
+
+ for i, r in enumerate(requests):
+ key = get_processor_fn(r).get_key()
+ processor_indices[key].append(i)
+
+ indice_list = []
+ outputs_list = []
+    for key, processor in processor_map.items():
+        indices = processor_indices[key]
+ input_filtered = input_tensor[indices]
+ output_filtered = get_result_fn(processor(input_filtered))
+ indice_list.append(indices)
+ outputs_list.append(output_filtered)
+
+ indice = list(itertools.chain.from_iterable(indice_list))
+ outputs = torch.cat(outputs_list, dim=0)
+ return outputs[torch.argsort(torch.tensor(indice))]
+
+
+def run_batch_logit_processor(input_tensor, requests):
+ return _run_batch_postprocess(input_tensor, requests, lambda r: r.logit_processor)
+
+
+def run_batch_sampler(input_tensor, requests):
+ return _run_batch_postprocess(input_tensor,
+ requests,
+ lambda r: r.sampler,
+ lambda x: x[0])
+
+
+def run_batch_stop_criterion(input_tensor, requests):
+ return _run_batch_postprocess(input_tensor, requests, lambda r: r.stop_criterion)
diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py
new file mode 100644
index 00000000..a960a564
--- /dev/null
+++ b/mii/batching/ragged_batching.py
@@ -0,0 +1,779 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import copy
+import queue
+import os
+import random
+import threading
+import time
+from collections import deque, defaultdict
+from dataclasses import dataclass, asdict, field
+from functools import cached_property
+from typing import Dict, Tuple, List, Any, Iterator, Union, DefaultDict, Set
+from typing_extensions import Self
+
+import torch
+import ujson
+import zmq
+from deepspeed.accelerator import get_accelerator
+from deepspeed.utils.timer import SynchronizedWallClockTimer
+
+from mii.batching.generation.logit_processors import BaseLogitProcessor
+from mii.batching.generation.samplers import BaseGenerationSampler
+from mii.batching.generation.stop_criterion import BaseGenerationStopCriterion
+from mii.batching.postprocess import (
+ _create_postprocessor,
+ run_batch_logit_processor,
+ run_batch_sampler,
+ run_batch_stop_criterion,
+ DEFAULT_LOGITS_PROCESSOR,
+ DEFAULT_SAMPLER,
+ DEFAULT_STOP_CRITERION,
+ LOGITS_PROCESSORS,
+ SAMPLERS,
+ STOP_CRITERIA,
+)
+from mii.batching.utils import sync_debug, profiler
+from mii.constants import GenerationFinishReason, ZMQ_RECV_TIMEOUT
+from mii.logging import logger
+
+
+@dataclass
+class Response:
+ generated_text: str
+ prompt_length: int
+ generated_length: int
+ finish_reason: GenerationFinishReason
+
+ @staticmethod
+ def from_msg(msg: Dict[str, Union[str, int]]) -> Self:
+ return Response(
+ generated_text=msg["generated_text"],
+ prompt_length=msg["prompt_length"],
+ generated_length=msg["generated_length"],
+ finish_reason=GenerationFinishReason(msg["finish_reason"]),
+ )
+
+ def get_msg(self) -> Dict[str, Union[str, int]]:
+ return {
+ "generated_text": self.generated_text,
+ "prompt_length": self.prompt_length,
+ "generated_length": self.generated_length,
+ "finish_reason": self.finish_reason.value
+ }
+
+ def __repr__(self) -> str:
+ return self.generated_text
+
+ def __str__(self) -> str:
+ return self.generated_text
+
+
+class ResponseBatch:
+ def __init__(self, responses: List[Response]) -> None:
+ self.responses = responses
+
+ def __iter__(self) -> Iterator[Response]:
+ return iter(self.responses)
+
+ def __repr__(self) -> str:
+ return "\n\n".join(str(r) for r in self.responses)
+
+ @property
+ def generated_texts(self) -> List[str]:
+ return [r.generated_text for r in self.responses]
+
+ @property
+ def prompt_lengths(self) -> List[int]:
+ return [r.prompt_length for r in self.responses]
+
+ @property
+ def generated_lengths(self) -> List[int]:
+ return [r.generated_length for r in self.responses]
+
+ @property
+ def finish_reasons(self) -> List[GenerationFinishReason]:
+ return [r.finish_reason for r in self.responses]
+
+ def append(self, response: Response) -> None:
+ self.responses.append(response)
+
+
+@dataclass
+class RaggedRequestMsg:
+ uid: int
+ input_tokens: Union[torch.Tensor, List[int]]
+
+ @property
+ def is_flush_request(self):
+ return self.input_tokens is None
+
+ @staticmethod
+ def from_msg(msg: Dict[str, int]) -> Self:
+ return RaggedRequestMsg(
+ uid=msg["uid"],
+ input_tokens=None
+ if msg["input_tokens"] is None else torch.tensor(msg["input_tokens"],
+ dtype=torch.int32,
+ device=torch.device("cpu")),
+ )
+
+
+@dataclass
+class RaggedRequest:
+ uid: int
+ input_tokens: torch.Tensor
+ prompt_length: int
+ seq_length: int
+ max_length: int
+ max_new_tokens: int
+ last_in_prompt: bool
+ logit_processor: BaseLogitProcessor
+ sampler: BaseGenerationSampler
+ stop_criterion: BaseGenerationStopCriterion
+ stream: bool = False
+
+ _next_token: Union[None, torch.Tensor] = None
+ _is_done: bool = False
+ _generated_tokens: List[torch.Tensor] = field(default_factory=list)
+ _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE
+
+ @property
+ def next_token(self) -> Union[None, torch.Tensor]:
+ return self._next_token
+
+ @next_token.setter
+ def next_token(self, next_token: Union[None, torch.Tensor]) -> None:
+ self._next_token = next_token
+
+ @property
+ def is_done(self) -> bool:
+ return self._is_done
+
+ @is_done.setter
+ def is_done(self, is_done: bool) -> None:
+ self._is_done = is_done
+
+ @property
+ def generated_tokens(self) -> List[torch.Tensor]:
+ return self._generated_tokens
+
+ @property
+ def finish_reason(self) -> GenerationFinishReason:
+ return self._finish_reason
+
+ @property
+ def is_flush_request(self):
+ return self.input_tokens is None
+
+ @property
+ def num_generated_tokens(self) -> int:
+ # We return zero while we are processing decomposed prompts
+ return self.seq_length - self.prompt_length + 1 if self.seq_length >= self.prompt_length else 0
+
+ @property
+ def stop_generation(self) -> bool:
+ if self.is_done:
+ self._finish_reason = GenerationFinishReason.STOP
+ return True
+ if (self.seq_length >= self.max_length) or (self.num_generated_tokens >=
+ self.max_new_tokens):
+ self._finish_reason = GenerationFinishReason.LENGTH
+ return True
+ return False
+
+ def get_msg(self) -> RaggedRequestMsg:
+ return RaggedRequestMsg(
+ uid=self.uid,
+ input_tokens=None
+ if self.input_tokens is None else self.input_tokens.tolist(),
+ )
+
+ def accumulate_generated_token(self) -> None:
+ if not self.is_done:
+ self._generated_tokens.append(self.next_token)
+
+ def set_next_as_input(self) -> None:
+ if self.next_token is not None:
+ self.input_tokens = self.next_token.unsqueeze(0)
+ self.last_in_prompt = True
+ self.next_token = None
+ self.is_done = False
+
+
+class RaggedRequestBatch:
+ def __init__(self, requests: List[RaggedRequest]) -> None:
+ self.requests = requests
+
+ def __len__(self) -> int:
+ return len(self.requests)
+
+ def __contains__(self, r: RaggedRequest) -> bool:
+ return r in self.requests
+
+ def __nonzero__(self) -> bool:
+ if len(self.requests) != 0:
+ return True
+ return False
+
+ def __iter__(self) -> Iterator[RaggedRequest]:
+ return iter(self.requests)
+
+ def __repr__(self) -> str:
+ return f"RaggedRequestBatch({self.requests})"
+
+ @property
+ def requests_to_run(self) -> Self:
+ return RaggedRequestBatch([r for r in self.requests if not r.is_flush_request])
+
+ @property
+ def requests_to_flush(self) -> Self:
+ return RaggedRequestBatch([r for r in self.requests if r.is_flush_request])
+
+ @property
+ def last_in_prompt(self) -> Self:
+ return RaggedRequestBatch([r for r in self.requests if r.last_in_prompt])
+
+ @property
+ def completed(self) -> Self:
+ return RaggedRequestBatch([r for r in self.requests if r.stop_generation])
+
+ @property
+ def uids(self) -> List[int]:
+ return [r.uid for r in self.requests]
+
+ @property
+ def lengths(self) -> List[int]:
+ return [len(r.input_tokens) for r in self.requests]
+
+ @property
+ def tokens(self) -> List[torch.Tensor]:
+ return [r.input_tokens for r in self.requests]
+
+ @property
+ def next_tokens(self) -> List[torch.Tensor]:
+ return [r.next_token for r in self.requests]
+
+ @property
+ def done_tokens(self) -> List[torch.Tensor]:
+ return [r.is_done for r in self.requests]
+
+ @next_tokens.setter
+ def next_tokens(self, next_tokens: List[torch.Tensor]) -> None:
+ assert len(next_tokens) == len(self.requests)
+ for idx, r in enumerate(self.requests):
+ r.next_token = next_tokens[idx]
+
+ @done_tokens.setter
+ def done_tokens(self, done_tokens: List[torch.Tensor]) -> None:
+ assert len(done_tokens) == len(self.requests)
+ for idx, r in enumerate(self.requests):
+ r.is_done = done_tokens[idx]
+
+ def prune(self, uids: List[int]) -> None:
+ self.requests = [r for r in self.requests if r.uid not in uids]
+
+ def append(self, r: RaggedRequest) -> None:
+ self.requests.append(r)
+
+ def update_seq_length(self) -> None:
+ for r in self.requests:
+ r.seq_length += r.input_tokens.size(0)
+
+
+class RaggedBatchBase:
+ def __init__(self, inference_engine, tokenizer, model_config):
+ self.inference_engine = inference_engine
+ self.tokenizer = tokenizer
+ self.vocab_size = tokenizer.vocab_size
+ self.model_config = model_config
+ self.zmq_port = model_config.zmq_port_number
+ if model_config.max_length is not None:
+ self.max_length = model_config.max_length
+ else:
+ self.max_length = inference_engine._policy._checkpoint_engine.model_config.max_seq_length
+ self.sync_debug = model_config.sync_debug
+ self.profile_model_time = model_config.profile_model_time
+
+ self.request_queue: queue.Queue = queue.Queue()
+ self.result_queues: Dict[int, queue.Queue] = {}
+ self.scheduled_requests: RaggedRequestBatch = RaggedRequestBatch([])
+ self.buffer = deque()
+ self.scheduled_length = 0
+ self.scheduled_seq_num = 0
+ self.scheduled_req_blocks = 0
+
+ self.logit_processor = run_batch_logit_processor
+ self.sampler = run_batch_sampler
+ self.stop_criterion = run_batch_stop_criterion
+
+ self._timers: SynchronizedWallClockTimer = SynchronizedWallClockTimer()
+ self._profiled_times: DefaultDict[str, List[int]] = defaultdict(list)
+ self._iters: int = 0
+ self._num_generated_tokens: int = 0
+
+ context = zmq.Context()
+ torch.cuda.synchronize()
+ if self.is_rank_0:
+ self.socket = context.socket(zmq.PUB)
+ self.socket.bind(f"tcp://*:{self.zmq_port}")
+            time.sleep(1)  # Give the subscriber a chance to connect
+ else:
+ self.socket = context.socket(zmq.SUB)
+ self.socket.connect(f"tcp://localhost:{self.zmq_port}")
+ self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
+ self.socket.setsockopt(zmq.RCVTIMEO, ZMQ_RECV_TIMEOUT)
+
+ @cached_property
+ def local_rank(self) -> int:
+ return get_accelerator().current_device()
+
+ @property
+ def is_rank_0(self) -> bool:
+ return self.local_rank == 0
+
+ @profiler
+ def generate(self) -> None:
+ # 1. Get a batch of requests, broadcast to all ranks
+ scheduled_requests = self._bcast_requests()
+
+ # 2. Flush for uids that are finished generating
+ self.flush(scheduled_requests.requests_to_flush.uids)
+
+ # 3. Put new tokens into inference engine
+ if scheduled_requests.requests_to_run:
+ next_token_logits = self.put(
+ scheduled_requests.requests_to_run.uids,
+ scheduled_requests.requests_to_run.tokens,
+ )
+
+ # short circuit if not rank 0, only rank 0 does scheduling and postprocessing of logits
+ if not self.is_rank_0:
+ return
+
+ # 4. Launch logit processing and token generation
+ running_requests = scheduled_requests.requests_to_run
+ running_requests.update_seq_length()
+ if running_requests:
+ next_tokens, done_tokens = self._process_logits(
+ next_token_logits, running_requests
+ )
+ running_requests.next_tokens = next_tokens
+ running_requests.done_tokens = done_tokens
+
+ # 5. Schedule requests while we wait for the forward pass to finish
+ self._reset_scheduler_bookkeeping()
+
+ # 6. Accumulate generated tokens, check completion, and generate output
+ for r in running_requests.last_in_prompt:
+ r.accumulate_generated_token()
+ self._num_generated_tokens += 1
+ if r.stop_generation or r.stream:
+ self._generate_output(r)
+ if not r.stop_generation:
+ r.set_next_as_input()
+ self.request_queue.put(r)
+
+ # 7. Update scheduled requests
+ self.scheduled_requests.prune(running_requests.completed.uids)
+ self.schedule_requests()
+
+ if self.profile_model_time:
+ self._print_profiled_times()
+
+ def _print_profiled_times(self) -> None:
+ self._iters += 1
+ if not (self._iters % 100 == 0):
+ return
+ for event, times in self._profiled_times.items():
+ mean_time = sum(times) / len(times)
+ log_msg = f"{event}: {mean_time}"
+ if event == "generate":
+ log_msg += f" ({self._num_generated_tokens / sum(times)} tokens/ms)"
+ logger.info(log_msg)
+ self._profiled_times.clear()
+ self._num_generated_tokens = 0
+
+ @sync_debug
+ def _bcast_requests(self, force=False) -> RaggedRequestBatch:
+ if self.is_rank_0:
+ if not self.scheduled_requests and not force:
+ return self.scheduled_requests
+ # Rank 0 gets batch of requests and broadcasts to other ranks
+ data_dicts = [asdict(r.get_msg()) for r in self.scheduled_requests]
+ json_data = ujson.dumps(data_dicts)
+ self.socket.send_string(json_data)
+ else:
+ try:
+ json_data = self.socket.recv_string()
+ data_dicts = ujson.loads(json_data)
+ self.scheduled_requests = RaggedRequestBatch(
+ [RaggedRequestMsg.from_msg(msg) for msg in data_dicts])
+ except zmq.Again:
+ self.scheduled_requests = RaggedRequestBatch([])
+
+ return self.scheduled_requests
+
+ def _reset_scheduler_bookkeeping(self) -> None:
+ self.scheduled_requests = RaggedRequestBatch([])
+ self.scheduled_length = 0
+ self.scheduled_seq_num = 0
+ self.scheduled_req_blocks = 0
+
+ @sync_debug
+ def _process_logits(
+ self,
+ next_token_logits: torch.Tensor,
+ running_requests: RaggedRequestBatch) -> Tuple[torch.Tensor,
+ torch.Tensor]:
+ next_token_logits = next_token_logits[:, :self.vocab_size]
+ next_token_logits = self.logit_processor(next_token_logits, running_requests)
+ next_tokens = self.sampler(next_token_logits, running_requests)
+ done_tokens = self.stop_criterion(next_tokens, running_requests)
+ next_tokens = next_tokens.to(torch.device("cpu"), non_blocking=False)
+ return next_tokens, done_tokens
+
+ @sync_debug
+ def _generate_output(self, r: RaggedRequest) -> bool:
+ outputs = []
+ if r.stream:
+ outputs.append((
+ [r.next_token],
+ r.prompt_length,
+ r.num_generated_tokens,
+ GenerationFinishReason.NONE,
+ ))
+ if r.finish_reason != GenerationFinishReason.NONE:
+ if r.stream or not r.generated_tokens:
+ output_tokens = []
+ else:
+ output_tokens = torch.cat([t.unsqueeze(0) for t in r.generated_tokens],
+ dim=0)
+ outputs.append((
+ output_tokens,
+ r.prompt_length,
+ r.num_generated_tokens,
+ r.finish_reason,
+ ))
+ for output in outputs:
+ self.result_queues[r.uid].put_nowait(output)
+
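+ # Scheduling respects three budgets: free KV-cache blocks, the engine's maximum
+ # ragged sequence count, and the maximum ragged batch size in tokens. Prompts that
+ # do not fit are decomposed: the prefix is scheduled now and the remainder is pushed
+ # back onto the buffer for a later iteration.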
+ def _do_schedule_requests(self, requests: List[RaggedRequest]) -> None:
+
+ free_blocks = self.inference_engine._state_manager.free_blocks
+ conf_manager = self.inference_engine._config.state_manager
+ for r in requests:
+ if r.max_length <= r.seq_length:
+ continue
+
+ # Make sure that the engine has enough capacity to process the batch
+ if len(self.scheduled_requests) > conf_manager.max_ragged_sequence_count:
+ break
+
+ max_batch_size = conf_manager.max_ragged_batch_size - self.scheduled_length
+ if max_batch_size <= 0:
+ break
+
+ max_blocks = free_blocks - self.scheduled_req_blocks
+ req_tokens = min(len(r.input_tokens), max_batch_size)
+ req_tokens, req_blocks = self.inference_engine.query(r.uid, req_tokens, max_blocks)
+
+ if req_tokens <= 0:
+ continue
+
+ # Decompose the prompt to fit to the max ragged batch size
+ decomposed = req_tokens < len(r.input_tokens)
+ remaining_tokens = r.input_tokens[req_tokens:]
+ r.input_tokens = r.input_tokens[:req_tokens]
+ r.last_in_prompt = not decomposed
+
+ # Schedule the request
+ self.scheduled_requests.append(r)
+
+ self.scheduled_req_blocks += req_blocks
+ self.scheduled_length += req_tokens
+
+ if decomposed:
+ req_remaining = copy.copy(r)
+ req_remaining.input_tokens = remaining_tokens
+ req_remaining.seq_length = r.seq_length + req_tokens
+ req_remaining.last_in_prompt = True
+
+ self.buffer.appendleft(req_remaining)
+
+ def schedule_requests(self) -> None:
+ while not self.request_queue.empty():
+ r = self.request_queue.get_nowait()
+ self.buffer.append(r)
+
+ # Run next token generation first
+ next_token_gen_reqs = []
+ prompt_reqs = []
+
+ for r in self.buffer:
+ if r.is_flush_request:
+ self.scheduled_requests.append(r)
+ else:
+ if len(r.input_tokens) == 1:
+ next_token_gen_reqs.append(r)
+ else:
+ prompt_reqs.append(r)
+
+ # We want to process next token generation first
+ self._do_schedule_requests(next_token_gen_reqs)
+ self._do_schedule_requests(prompt_reqs)
+
+ scheduled_requests_ids = set(id(r) for r in self.scheduled_requests)
+ self.buffer = deque(
+ [r for r in self.buffer if id(r) not in scheduled_requests_ids])
+
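+ # Builds the RaggedRequest(s) for one prompt. The optional `postprocess_config`
+ # kwarg selects the logit processor, sampler, and stop criterion; unrecognized keys
+ # and any leftover kwargs are rejected so typos fail loudly.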
+ def make_request(self,
+ uid: int,
+ input_tokens: torch.Tensor,
+ kwargs: Dict) -> List[RaggedRequest]:
+ max_length = kwargs.pop("max_length", self.max_length)
+ max_new_tokens = kwargs.pop("max_new_tokens", max_length - len(input_tokens))
+ stream = kwargs.pop("stream", False)
+ # TODO: Add back this check
+ # if self.policy.get_length(uid) + len(token_ids) >= max_length:
+ # raise ValueError(f"Session {uid} has reached max length {max_length}.")
+
+ postprocess_config = kwargs.pop("postprocess_config", {})
+ accepted_keys = ("logit_processor", "sampler", "stop_criterion")
+ for key in postprocess_config.keys():
+ if key not in accepted_keys:
+ raise ValueError(
+ f"Unknown postprocess_config keyword {key}. Accepted keywords are {accepted_keys}"
+ )
+ logit_processor = _create_postprocessor(
+ postprocess_config.get("logit_processor",
+ DEFAULT_LOGITS_PROCESSOR),
+ LOGITS_PROCESSORS,
+ )
+ sampler = _create_postprocessor(
+ postprocess_config.get("sampler",
+ DEFAULT_SAMPLER),
+ SAMPLERS)
+ stop_criterion = _create_postprocessor(
+ postprocess_config.get("stop_criterion",
+ DEFAULT_STOP_CRITERION),
+ STOP_CRITERIA,
+ {"tokenizer": self.tokenizer},
+ )
+
+ assert kwargs == {}, f"Unknown keyword arguments {kwargs}"
+
+ return [
+ RaggedRequest(
+ uid=uid,
+ input_tokens=input_tokens,
+ prompt_length=len(input_tokens),
+ seq_length=0,
+ max_length=max_length,
+ max_new_tokens=max_new_tokens,
+ last_in_prompt=True,
+ logit_processor=logit_processor,
+ sampler=sampler,
+ stop_criterion=stop_criterion,
+ stream=stream,
+ )
+ ]
+
+ def make_response(self,
+ generated_text: str,
+ prompt_length: int,
+ generated_length: int,
+ finish_reason: GenerationFinishReason) -> Response:
+ return Response(generated_text=generated_text,
+ prompt_length=prompt_length,
+ generated_length=generated_length,
+ finish_reason=finish_reason)
+
+ def put(self, uids: List[int], tokenized_input: List[torch.Tensor]) -> torch.Tensor:
+ return self.inference_engine.put(uids, tokenized_input)
+
+ def flush(self, uids: List[int]) -> None:
+ for uid in uids:
+ self.inference_engine.flush(uid)
+
+
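+# MIIPipeline is the synchronous, non-persistent pipeline: it tokenizes the inputs,
+# drives the generate() loop until every scheduled request finishes, flushes uids as
+# they complete, and then decodes one Response per input. Unless `all_rank_output` is
+# set, only rank 0 returns the responses.
+#
+# Usage sketch (hedged): assuming the `mii.pipeline()` factory added elsewhere in this
+# PR returns an MIIPipeline, a call might look like:
+#   pipe = mii.pipeline("meta-llama/Llama-2-7b-hf")
+#   responses = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=128)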
+class MIIPipeline(RaggedBatchBase):
+ def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch:
+ if isinstance(inputs, str):
+ inputs = [inputs]
+ outputs: ResponseBatch = ResponseBatch([])
+ uids: List[int] = list(range(len(inputs)))
+ flushed_uids: Set[int] = set()
+
+ for uid, input in zip(uids, inputs):
+ request_kwargs = kwargs.copy()
+ self._enqueue_request(uid, input, request_kwargs)
+
+ while self.scheduled_requests:
+ self.generate()
+ # Flush uids as soon as they finish generating by enqueueing an empty flush request
+ for uid, result_queue in self.result_queues.items():
+ if (not result_queue.empty()) and uid not in flushed_uids:
+ flushed_uids.add(uid)
+ self.request_queue.put_nowait(
+ RaggedRequest(
+ uid=uid,
+ input_tokens=None,
+ prompt_length=None,
+ seq_length=None,
+ max_length=None,
+ max_new_tokens=None,
+ last_in_prompt=None,
+ logit_processor=None,
+ sampler=None,
+ stop_criterion=None,
+ stream=None,
+ ))
+
+ if self.is_rank_0:
+ # To kick ranks 1 -> n out of the while loop
+ self._bcast_requests(force=True)
+
+ for uid in range(len(inputs)):
+ outputs.append(self._dequeue_response(uid))
+
+ if self.model_config.all_rank_output:
+ outputs = self._bcast_responses(outputs)
+
+ return outputs
+
+ def _enqueue_request(self, uid: int, input: str, kwargs: Dict[str, Any]) -> None:
+ self.result_queues[uid] = queue.Queue()
+ input_tokens = self.tokenizer.encode(input)
+ for r in self.make_request(uid, input_tokens, kwargs):
+ self.request_queue.put(r)
+ self.schedule_requests()
+
+ def _dequeue_response(self, uid: int) -> Response:
+ result = self.result_queues[uid].get()
+ generated_tokens = self.tokenizer.decode(result[0])
+ response = self.make_response(generated_tokens, result[1], result[2], result[3])
+ return response
+
+ def _bcast_responses(self, responses: ResponseBatch) -> ResponseBatch:
+ if self.is_rank_0:
+ data_dicts = [r.get_msg() for r in responses]
+ json_data = ujson.dumps(data_dicts)
+ self.socket.send_string(json_data)
+ else:
+ json_data = self.socket.recv_string()
+ data_dicts = ujson.loads(json_data)
+ responses = ResponseBatch([Response.from_msg(msg) for msg in data_dicts])
+ return responses
+
+
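+# MIIAsyncPipeline backs the persistent gRPC deployment: start() runs the generate()
+# loop on a daemon thread, put_request() assigns a random uid (reused for a session_id,
+# if one is given) and enqueues the tokenized prompt, and get_response() blocks on that
+# uid's result queue. shutdown() stops the loop once all queues drain.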
+class MIIAsyncPipeline(RaggedBatchBase):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.uids = set()
+ self.session_to_uid: Dict[str, int] = {}
+ self.lock = threading.Lock()
+ self.thread = None
+ self.stop_thread = False
+ self._is_shutdown = False
+ self.UID_RANGE_LB = 1
+ self.UID_RANGE_UB = 10000
+
+ def __call__(self) -> None:
+ # The CUDA device may be reset in this worker thread; set it again from LOCAL_RANK to avoid problems
+ get_accelerator().set_device(int(os.getenv("LOCAL_RANK", "0")))
+ while True:
+ self.generate()
+
+ if (self.stop_thread and self.request_queue.empty()
+ and all(q.empty() for q in self.result_queues.values())):
+ break
+
+ def _get_uid(self, session_id: Union[str, None]):
+ if session_id in self.session_to_uid:
+ return self.session_to_uid[session_id]
+
+ # Create a new uid
+ with self.lock:
+ uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
+ while uid in self.uids:
+ uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
+ self.uids.add(uid)
+
+ if session_id is not None:
+ self.session_to_uid[session_id] = uid
+
+ return uid
+
+ def put_request(self,
+ args: Tuple,
+ kwargs: Dict,
+ session_id: Union[str,
+ None] = None) -> int:
+ if self.stop_thread:
+ raise RuntimeError("The request queue was shutdown.")
+
+ uid = self._get_uid(session_id)
+
+ with self.lock:
+ if uid not in self.result_queues:
+ self.result_queues[uid] = queue.Queue()
+
+ for input in args[0]:
+ input_tokens = self.tokenizer.encode(input)
+ for r in self.make_request(uid, input_tokens, kwargs):
+ self.request_queue.put(r)
+
+ return uid
+
+ def get_response(self, uid: int) -> List[Response]:
+ result = self.result_queues[uid].get()
+ generated_token_ids = result[0]
+ if len(generated_token_ids) == 0:
+ generated_text = ""
+ else:
+ generated_text = self.tokenizer.decode(generated_token_ids)
+ response = self.make_response(generated_text, result[1], result[2], result[3])
+ return [response]
+
+ def start(self) -> None:
+ self.thread = threading.Thread(target=self, daemon=True)
+ self.thread.start()
+
+ def shutdown(self) -> None:
+ self.stop_thread = True
+ self.thread.join()
+ self._is_shutdown = True
+
+ def is_shutdown(self) -> bool:
+ return self._is_shutdown
+
+ def destroy_session(self,
+ session_id: Union[str,
+ None],
+ uid: Union[int,
+ None] = None) -> None:
+ with self.lock:
+ if session_id in self.session_to_uid:
+ uid = self.session_to_uid[session_id]
+ del self.session_to_uid[session_id]
+ if uid in self.result_queues:
+ del self.result_queues[uid]
+ if self.is_rank_0:
+ self.request_queue.put_nowait(
+ RaggedRequest(
+ uid=uid,
+ input_tokens=None,
+ prompt_length=None,
+ seq_length=None,
+ max_length=None,
+ max_new_tokens=None,
+ last_in_prompt=None,
+ logit_processor=None,
+ sampler=None,
+ stop_criterion=None,
+ stream=None,
+ ))
+ self.uids.remove(uid)
diff --git a/mii/batching/utils.py b/mii/batching/utils.py
new file mode 100644
index 00000000..eae0c875
--- /dev/null
+++ b/mii/batching/utils.py
@@ -0,0 +1,40 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from functools import wraps
+
+from deepspeed.accelerator import get_accelerator
+
+from mii.logging import logger
+
+
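+# Decorators used by RaggedBatchBase: `sync_debug` inserts accelerator synchronization
+# and debug logging around a call when `self.sync_debug` is set, and `profiler` records
+# per-call wall-clock times into `self._profiled_times` when `self.profile_model_time`
+# is set.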
+def sync_debug(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if self.sync_debug:
+ get_accelerator().synchronize()
+ logger.debug(f"Calling {func.__name__} with args: {args}, kwargs: {kwargs}")
+ result = func(self, *args, **kwargs)
+ if self.sync_debug:
+ get_accelerator().synchronize()
+ logger.debug(f"Finished calling {func.__name__}")
+ return result
+
+ return wrapper
+
+
+def profiler(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if not self.profile_model_time:
+ return func(self, *args, **kwargs)
+
+ self._timers(func.__name__).start()
+ result = func(self, *args, **kwargs)
+ self._timers(func.__name__).stop()
+ self._profiled_times[func.__name__].append(
+ self._timers(func.__name__).elapsed(reset=True))
+ return result
+
+ return wrapper
diff --git a/mii/client.py b/mii/client.py
index 5ebcf8e5..39aa3488 100644
--- a/mii/client.py
+++ b/mii/client.py
@@ -5,40 +5,12 @@
import asyncio
import grpc
import requests
-import mii
-from .grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
-from .constants import GRPC_MAX_MSG_SIZE, TaskType, DeploymentType
-from .method_table import GRPC_METHOD_TABLE
-from .config import MIIConfig
-from .utils import import_score_file
+from typing import Dict, Any, Callable
-
-def _get_mii_config(deployment_name):
- mii_config = import_score_file(deployment_name, DeploymentType.LOCAL).mii_config
- return MIIConfig(**mii_config)
-
-
-def mii_query_handle(deployment_name):
- """Get a query handle for a local deployment:
-
- mii/examples/local/gpt2-query-example.py
- mii/examples/local/roberta-qa-query-example.py
-
- Arguments:
- deployment_name: Name of the deployment. Used as an identifier for posting queries for ``LOCAL`` deployment.
-
- Returns:
- query_handle: A query handle with a single method `.query(request_dictionary)` using which queries can be sent to the model.
- """
-
- if deployment_name in mii.non_persistent_models:
- inference_pipeline, task = mii.non_persistent_models[deployment_name]
- return MIINonPersistentClient(task, deployment_name)
-
- mii_config = _get_mii_config(deployment_name)
- return MIIClient(mii_config.model_config.task,
- "localhost", # TODO: This can probably be removed
- mii_config.port_number)
+from mii.config import get_mii_config, MIIConfig
+from mii.constants import GRPC_MAX_MSG_SIZE, TaskType
+from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
+from mii.task_methods import TASK_METHODS_DICT
def create_channel(host, port):
@@ -57,32 +29,77 @@ class MIIClient:
"""
Client to send queries to a single endpoint.
"""
- def __init__(self, task, host, port):
+ def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None:
+ self.mii_config = mii_config
+ self.task = mii_config.model_config.task
+ self.port = mii_config.port_number
self.asyncio_loop = asyncio.get_event_loop()
- channel = create_channel(host, port)
+ channel = create_channel(host, self.port)
self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
- self.task = task
- async def _request_async_response(self, request_dict, **query_kwargs):
- if self.task not in GRPC_METHOD_TABLE:
- raise ValueError(f"unknown task: {self.task}")
+ def __call__(self, *args, **kwargs):
+ return self.generate(*args, **kwargs)
- task_methods = GRPC_METHOD_TABLE[self.task]
+ async def _request_async_response(self, request_dict, **query_kwargs):
+ task_methods = TASK_METHODS_DICT[self.task]
proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
proto_response = await getattr(self.stub, task_methods.method)(proto_request)
return task_methods.unpack_response_from_proto(proto_response)
- def query(self, request_dict, **query_kwargs):
+ async def _request_async_response_stream(self, request_dict, **query_kwargs):
+ task_methods = TASK_METHODS_DICT[self.task]
+ proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
+ assert hasattr(task_methods, "method_stream_out"), f"{self.task} does not support streaming response"
+ async for response in getattr(self.stub,
+ task_methods.method_stream_out)(proto_request):
+ yield task_methods.unpack_response_from_proto(response)
+
+ def generate(self,
+ prompt: str,
+ streaming_fn: Callable = None,
+ **query_kwargs: Dict[str,
+ Any]):
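+ """
+ Send a single prompt to the server and return the generated response.
+ When `streaming_fn` is provided, it is invoked with each partial response as
+ tokens are streamed back, and the call returns None.
+ """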
+ if not isinstance(prompt, str):
+ raise RuntimeError(
+ "MII client only supports a single query string, multi-string will be added soon"
+ )
+ request_dict = {"query": prompt}
+ if streaming_fn is not None:
+ return self._generate_stream(streaming_fn, request_dict, **query_kwargs)
+
return self.asyncio_loop.run_until_complete(
self._request_async_response(request_dict,
**query_kwargs))
+ def _generate_stream(self,
+ callback,
+ request_dict: Dict[str,
+ str],
+ **query_kwargs: Dict[str,
+ Any]):
+ async def put_result():
+ response_stream = self._request_async_response_stream(
+ request_dict,
+ **query_kwargs)
+
+ while True:
+ try:
+ response = await response_stream.__anext__()
+ callback(response)
+ except StopAsyncIteration:
+ break
+
+ self.asyncio_loop.run_until_complete(put_result())
+
async def terminate_async(self):
await self.stub.Terminate(
modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())
- def terminate(self):
+ def terminate_server(self):
self.asyncio_loop.run_until_complete(self.terminate_async())
+ if self.mii_config.enable_restful_api:
+ requests.get(
+ f"http://localhost:{self.mii_config.restful_api_port}/terminate")
async def create_session_async(self, session_id):
return await self.stub.CreateSession(
@@ -106,44 +123,7 @@ def destroy_session(self, session_id):
self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))
-class MIINonPersistentClient:
- def __init__(self, task, deployment_name):
- self.task = task
- self.deployment_name = deployment_name
+def client(model_or_deployment_name: str) -> MIIClient:
+ mii_config = get_mii_config(model_or_deployment_name)
- def query(self, request_dict, **query_kwargs):
- assert (
- self.deployment_name in mii.non_persistent_models
- ), f"deployment: {self.deployment_name} not found"
- task_methods = GRPC_METHOD_TABLE[self.task]
- inference_pipeline = mii.non_persistent_models[self.deployment_name][0]
-
- # TODO: refactor so this code is shared between non-persistent and
- # persistent deployments in method_table.py
- if self.task == TaskType.QUESTION_ANSWERING:
- if "question" not in request_dict or "context" not in request_dict:
- raise Exception(
- "Question Answering Task requires 'question' and 'context' keys")
- args = (request_dict["question"], request_dict["context"])
- kwargs = query_kwargs
-
- elif self.task == TaskType.CONVERSATIONAL:
- conv = task_methods.create_conversation(request_dict)
- args = (conv, )
- kwargs = query_kwargs
-
- else:
- args = (request_dict["query"], )
- kwargs = query_kwargs
-
- return task_methods.run_inference(inference_pipeline, args, query_kwargs)
-
- def terminate(self):
- print(f"Terminating {self.deployment_name}...")
- del mii.non_persistent_models[self.deployment_name]
-
-
-def terminate_restful_gateway(deployment_name):
- mii_config = _get_mii_config(deployment_name)
- if mii_config.enable_restful_api:
- requests.get(f"http://localhost:{mii_config.restful_api_port}/terminate")
+ return MIIClient(mii_config)
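+
+# Usage sketch (hedged): assuming a persistent deployment for this model was started
+# elsewhere in this PR (the server-side API is not shown in this hunk), a client call
+# might look like:
+#   client = mii.client("meta-llama/Llama-2-7b-hf")
+#   response = client.generate("DeepSpeed is", max_new_tokens=128)
+#   client.terminate_server()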
diff --git a/mii/config.py b/mii/config.py
index 3c5806d4..e92efb2a 100644
--- a/mii/config.py
+++ b/mii/config.py
@@ -2,17 +2,19 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
-import torch
import os
import string
-from typing import List, Optional, Dict, Any
-import mii
-from .constants import DeploymentType, TaskType, MII_MODEL_PATH_DEFAULT
-from .pydantic_v1 import validator, root_validator, Field
+from typing import List, Optional, Union, Dict, Any
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
-from deepspeed.inference.config import DtypeEnum
from deepspeed.launcher.runner import DLTS_HOSTFILE, fetch_hostfile
+from deepspeed.inference import RaggedInferenceEngineConfig
+
+from mii.constants import DeploymentType, TaskType, ModelProvider
+from mii.errors import DeploymentNotFoundError
+from mii.pydantic_v1 import Field, root_validator
+from mii.tokenizers import MIITokenizerWrapper
+from mii.utils import generate_deployment_name, get_default_task, import_score_file
class ReplicaConfig(DeepSpeedConfigModel):
@@ -20,66 +22,54 @@ class ReplicaConfig(DeepSpeedConfigModel):
tensor_parallel_ports: List[int] = []
torch_dist_port: int = None
gpu_indices: List[int] = []
+ zmq_port: int = None
class ModelConfig(DeepSpeedConfigModel):
- model: str
- """
- Name of a supported model for the task. Models in MII are sourced from
- multiple open-source projects such as Huggingface Transformer, FairSeq,
- EluetherAI etc. For the list of supported models for each task, please see
- here [TODO].
- """
-
- task: TaskType
+ model_name_or_path: str
"""
- Name of the machine learning task to be deployed.Currently MII supports the
- following list of tasks ``['text-generation', 'text-classification',
- 'question-answering', 'fill-mask', 'token-classification',
- 'conversational', 'text-to-image']``
+ Model name or path of the HuggingFace model to be deployed.
"""
- dtype: DtypeEnum = DtypeEnum.fp32
+ tokenizer: Optional[Union[str, MIITokenizerWrapper]] = None
"""
- Desired model data type, will convert model to this type. Supported target
- types: `torch.half`, `torch.float`, `torch.int8` (for BLOOM models)
+ Tokenizer wrapped with `MIITokenizerWrapper`, or the name or path of the
+ HuggingFace tokenizer to be used. Defaults to `model_name_or_path` when not provided.
"""
- model_path: str = ""
+ task: Optional[TaskType] = TaskType.TEXT_GENERATION
"""
- In LOCAL deployments this is the local path where model checkpoints are
- available. In AML deployments this is an optional relative path with
- AZURE_MODEL_DIR for the deployment.
+ Name of the task to be performed by the model.
"""
- load_with_sys_mem: bool = False
+ tensor_parallel: int = int(os.getenv("WORLD_SIZE", "1"))
"""
- Loads the model onto system memory instead of GPU memory. This can help
- avoid OOM errors when sharding a model across several GPUs because MII will
- try to load a full copy of each model onto each GPU initially.
+ Tensor parallelism to use for a model (i.e., how many GPUs to shard a model
+ across). This defaults to the `WORLD_SIZE` environment variable, or a value
+ of 1 if that variable is not set. This value is also propagated to the
+ `inference_engine_config`.
"""
- meta_tensor: bool = False
+ inference_engine_config: RaggedInferenceEngineConfig = {}
"""
- Loads the initial HuggingFace model using Meta Tensors that use no memory.
- Can dramatically improve load time and reduce memory requirements on
- supported models. Supported for GPT-J, GPT-NeoX, OPT, and BLOOM when kernel
- injection is enabled. Supported for all models when kernel injection is
- disabled.
+ DeepSpeed inference engine config. This is automatically generated, but you
+ can provide a set of custom configs.
"""
- deploy_rank: Optional[List[int]] = None
+ torch_dist_port: int = 29500
"""
- GPU indices a model is deployed on. Note that CUDA_VISIBLE_DEVICES does not
- work with DeepSpeed-MII.
+ Torch distributed port to be used. This also serves as a base port when
+ multiple replicas are deployed. For example, if there are 2 replicas, the
+ first will use port 29500 and the second will use port 29600.
"""
- torch_dist_port: int = 29500
+ zmq_port_number: int = 25555
"""
- Torch distributed port.
+ Port number to use for the ZMQ communication (for broadcasting requests and
+ responses among all ranks in ragged batching).
"""
- replica_num: int = 1
+ replica_num: int = Field(1, gt=0)
"""
Number of model replicas. Enables easy data parallelism.
"""
@@ -90,219 +80,64 @@ class ModelConfig(DeepSpeedConfigModel):
generated, but you can provide a set of custom configs.
"""
- profile_model_time: bool = False
- """
- Enable profiling of model times (i.e., without communication overhead).
- """
-
- skip_model_check: bool = False
- """
- Skip validation that a model supports a given task.
- """
-
- hf_auth_token: Optional[str] = Field(
- None,
- deprecated=True,
- deprecated_msg=
- "Parameter will be removed. Please use the `pipeline_kwargs` field to pass kwargs to the HuggingFace pipeline creation.",
- )
- """
- HuggingFace authentication token for accessing models. Will be propagated
- to all ModelConfig if none are provided there.
- """
-
- trust_remote_code: bool = Field(
- False,
- deprecated=True,
- deprecated_msg=
- "Parameter will be removed. Please use the `pipeline_kwargs` field to pass kwargs to the HuggingFace pipeline creation.",
- )
- """
- HuggingFace `tranformer.pipeline` option for `trust_remote_code`.
- """
-
- pipeline_kwargs: Dict[str, Any] = {}
- """
- kwargs to be passed to HuggingFace's `transformer.pipeline`.
- """
-
- # TODO: Replace with DeepSpeedInferenceConfig
- enable_deepspeed: bool = True
- """
- Enable DeepSpeed-Inference.
- """
-
- enable_zero: bool = False
- """
- Enable Zero-Inference.
- """
-
- ds_config: Dict[str, Any] = {}
- """
- DeepSpeed config to use when Zero-Inference is enabled.
- """
-
- tensor_parallel: int = 1
- """
- Tensor parallelism to use for a model (i.e., how many GPUs to shard a model across).
+ max_length: Optional[int] = None
"""
-
- enable_cuda_graph: bool = False
- """
- Enables CUDA Graph captures with DeepSpeed-Inference.
+ The maximum number of tokens DeepSpeed-Inference can work with, including
+ the input and output tokens.
"""
- replace_with_kernel_inject: bool = True
+ all_rank_output: bool = False
"""
- Enable custom kernel injection with DeepSpeed-Inference.
+ Whether to return output on all ranks for `mii.pipeline`. Default behavior
+ is to return output only on rank 0.
"""
- checkpoint_dict: Optional[Dict[str, Any]] = None
+ sync_debug: bool = False
"""
- DeepSpeed model checkpoint dict.
+ Inserts additional synchronization points for debugging purposes.
"""
- max_tokens: int = 1024
+ profile_model_time: bool = False
"""
- The maximum number of tokens DeepSpeed-Inference can work with, including
- the input and output tokens. Please consider increasing it to the required
- token-length required for your use-case.
+ Log performance information about model inference with very little overhead.
"""
- class Config:
- json_encoders = {torch.dtype: lambda x: str(x)}
-
@property
- def provider(self):
- return mii.utils.get_provider(self.model, self.task)
-
- @validator("checkpoint_dict")
- def checkpoint_dict_valid(cls, field_value, values):
- if field_value is None:
- return field_value
- for k in ["checkpoints", "version", "type", "base_dir"]:
- if not field_value.get(k, ""):
- raise ValueError(f"Missing key={k} in checkpoint_dict")
- return field_value
-
- @validator("deploy_rank", pre=True)
- def deploy_rank_to_list(cls, field_value, values):
- if field_value and not isinstance(field_value, list):
- field_value = [field_value]
- return field_value
-
- @root_validator
- def zero_or_meta(cls, values):
- if values.get("enable_zero"):
- assert not values.get(
- "meta_tensor"
- ), "ZeRO-Inference does not support meta tensors."
- return values
+ def provider(self) -> ModelProvider:
+ return ModelProvider.HUGGING_FACE
@root_validator
- def bloom_model_valid(cls, values):
- if "bigscience/bloom" in values.get("model"):
- # TODO: SHould be albe to use DtypeEnum here
- assert values.get("dtype") in [
- torch.int8,
- torch.float16,
- ], "Bloom models only support fp16/int8."
- assert not values.get(
- "enable_cuda_graph"
- ), "Bloom models do not support CUDA Graph."
+ def auto_fill_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ if not values.get("tokenizer"):
+ values["tokenizer"] = values.get("model_name_or_path")
+ if not values.get("task"):
+ values["task"] = get_default_task(values.get("model_name_or_path"))
return values
@root_validator
- def deploy_rank_valid(cls, values):
+ def propagate_tp_size(cls, values: Dict[str, Any]) -> Dict[str, Any]:
tensor_parallel = values.get("tensor_parallel")
- deploy_rank = values.get("deploy_rank")
-
- # if deploy rank is not given, default to align with TP value
- if deploy_rank is None:
- deploy_rank = list(range(tensor_parallel))
-
- # number of ranks provided must be equal to TP size, DP is handled outside MII currently
- assert tensor_parallel == len(
- deploy_rank
- ), f"{len(deploy_rank)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {tensor_parallel}"
-
- values["deploy_rank"] = deploy_rank
+ values.get("inference_engine_config").tensor_parallel.tp_size = tensor_parallel
return values
@root_validator
- def set_model_path(cls, values):
- model_path = values.get("model_path")
- if not model_path:
- if values.get("deployment_type") == DeploymentType.AML:
- model_path = "model"
- else:
- model_path = MII_MODEL_PATH_DEFAULT
- aml_model_dir = os.environ.get("AZUREML_MODEL_DIR", None)
- if aml_model_dir and not model_path.startswith(aml_model_dir):
- assert os.path.isabs(
- aml_model_dir
- ), "AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path."
- assert not os.path.isabs(
- model_path
- ), f"model_path={model_path} must be relative to append w/ AML path."
- model_path = os.path.join(aml_model_dir, model_path)
-
- values["model_path"] = model_path
- return values
-
- @root_validator
- def validate_model_and_task(cls, values):
- task = values.get("task")
- model = values.get("model")
- if not values.get("skip_model_check"):
- mii.utils.check_if_task_and_model_is_valid(task, model)
- if values.get("enable_deepspeed"):
- mii.utils.check_if_task_and_model_is_supported(task, model)
- # Skip any future checks
- values["skip_model_check"] = True
- return values
-
- @root_validator
- def meta_tensor_or_sys_mem(cls, values):
- if values.get("meta_tensor") and values.get("load_with_sys_mem"):
- raise ValueError(
- "`meta_tensor` and `load_with_sys_mem` cannot be active at the same time."
- )
- return values
-
- @root_validator
- def zero_dtype_valid(cls, values):
- if values.get("enable_zero"):
- if values.get("ds_config").get("fp16", {}).get("enabled", False):
- # TODO: We should be able to use DtypeEnum instead of torch.float
- assert (
- values.get("dtype") == torch.float16
- ), "ZeRO FP16 enabled, `dtype` must be set to `torch.float16`"
- else:
- assert (
- values.get("dtype") == torch.float32
- ), "ZeRO FP16 disabled, `dtype` must be set to `torch.float32`"
- return values
-
- @root_validator
- def deepspeed_or_zero(cls, values):
- assert not (
- values.get("enable_deepspeed") and values.get("enable_zero")
- ), "DeepSpeed and ZeRO cannot both be enabled, select only one"
+ def check_replica_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ num_replica_config = len(values.get("replica_configs"))
+ if num_replica_config > 0:
+ assert num_replica_config == values.get("replica_num"), "Number of replica configs must match replica_num"
return values
class MIIConfig(DeepSpeedConfigModel):
- deployment_name: str
+ deployment_name: str = ""
"""
Name of the deployment. Used as an identifier for obtaining a inference
- server client and posting queries.
+ server client and posting queries. Automatically generated if it is not provided.
"""
deployment_type: DeploymentType = DeploymentType.LOCAL
"""
- One of the `enum mii.DeploymentTypes: [LOCAL]`.
+ One of the `enum mii.DeploymentTypes`:
* `LOCAL` uses a grpc server to create a local deployment.
- * `NON_PERSISTENT` creates a local deployment that will end when the process exits.
* `AML` will generate the assets necessary to deploy on AML resources.
"""
@@ -342,7 +177,7 @@ class MIIConfig(DeepSpeedConfigModel):
AML instance type to use when create AML deployment assets.
"""
@root_validator(skip_on_failure=True)
- def AML_name_valid(cls, values):
+ def AML_name_valid(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get("deployment_type") == DeploymentType.AML:
allowed_chars = set(string.ascii_lowercase + string.ascii_uppercase +
string.digits + "-")
@@ -351,12 +186,25 @@ def AML_name_valid(cls, values):
), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'."
return values
- def generate_replica_configs(self):
+ @root_validator(skip_on_failure=True)
+ def check_deployment_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ deployment_name = values.get("deployment_name")
+ if not deployment_name:
+ model_name_or_path = values.get("model_config").model_name_or_path
+ deployment_name = generate_deployment_name(
+ model_name_or_path=model_name_or_path)
+ values["deployment_name"] = deployment_name
+ return values
+
+ def generate_replica_configs(self) -> None:
+ if self.model_config.replica_configs:
+ return
# TODO: refactor this function
hostfile = self.hostfile
port_number = self.port_number
torch_dist_port = self.model_config.torch_dist_port
tensor_parallel = self.model_config.tensor_parallel
+ zmq_port = self.model_config.zmq_port_number
replica_num = self.model_config.replica_num
replica_pool = _allocate_processes(hostfile, tensor_parallel, replica_num)
replica_configs = []
@@ -372,12 +220,13 @@ def generate_replica_configs(self):
tensor_parallel_ports=tensor_parallel_ports,
torch_dist_port=replica_torch_dist_port,
gpu_indices=gpu_indices,
+ zmq_port=zmq_port + i,
))
self.model_config.replica_configs = replica_configs
-def _allocate_processes(hostfile_path, tensor_parallel, replica_num):
+def _allocate_processes(hostfile_path: str, tensor_parallel: int, replica_num: int):
resource_pool = fetch_hostfile(hostfile_path)
assert (
resource_pool is not None and len(resource_pool) > 0
@@ -415,3 +264,20 @@ def _allocate_processes(hostfile_path, tensor_parallel, replica_num):
)
return replica_pool
+
+
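+# Resolve a MIIConfig from either a deployment name or a model name/path: first try the
+# argument as a deployment name, then fall back to the deployment name generated from
+# the model name. Raises DeploymentNotFoundError if neither score file can be imported.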
+def get_mii_config(model_or_deployment_name: str) -> MIIConfig:
+ try:
+ deployment_name = model_or_deployment_name
+ mii_config = import_score_file(deployment_name, DeploymentType.LOCAL).mii_config
+ except Exception:
+ try:
+ deployment_name = generate_deployment_name(
+ model_name_or_path=model_or_deployment_name)
+ mii_config = import_score_file(deployment_name,
+ DeploymentType.LOCAL).mii_config
+ except Exception:
+ raise DeploymentNotFoundError(
+ f"Could not find a deployment named {model_or_deployment_name} or {deployment_name}"
+ )
+ return MIIConfig(**mii_config)
diff --git a/mii/constants.py b/mii/constants.py
index ea90b87a..729bba2b 100644
--- a/mii/constants.py
+++ b/mii/constants.py
@@ -13,49 +13,25 @@ class DeploymentType(str, Enum):
class TaskType(str, Enum):
TEXT_GENERATION = "text-generation"
- TEXT_CLASSIFICATION = "text-classification"
- QUESTION_ANSWERING = "question-answering"
- FILL_MASK = "fill-mask"
- TOKEN_CLASSIFICATION = "token-classification"
- CONVERSATIONAL = "conversational"
- TEXT2IMG = "text-to-image"
class ModelProvider(str, Enum):
HUGGING_FACE = "hugging-face"
- ELEUTHER_AI = "eleuther-ai"
- DIFFUSERS = "diffusers"
+
+
+class GenerationFinishReason(str, Enum):
+ STOP = "stop"
+ LENGTH = "length"
+ NONE = "none"
SUPPORTED_MODEL_TYPES = {
- 'roberta': ModelProvider.HUGGING_FACE,
- 'xlm-roberta': ModelProvider.HUGGING_FACE,
- 'gpt2': ModelProvider.HUGGING_FACE,
- 'distilbert': ModelProvider.HUGGING_FACE,
- 'bert': ModelProvider.HUGGING_FACE,
- 'gpt_neo': ModelProvider.HUGGING_FACE,
- 'gptj': ModelProvider.HUGGING_FACE,
'opt': ModelProvider.HUGGING_FACE,
- 'bloom': ModelProvider.HUGGING_FACE,
- 'gpt-neox': ModelProvider.ELEUTHER_AI,
- 'stable-diffusion': ModelProvider.DIFFUSERS,
'llama': ModelProvider.HUGGING_FACE
}
REQUIRED_KEYS_PER_TASK = {
TaskType.TEXT_GENERATION: ["query"],
- TaskType.TEXT_CLASSIFICATION: ["query"],
- TaskType.QUESTION_ANSWERING: ["context",
- "question"],
- TaskType.FILL_MASK: ["query"],
- TaskType.TOKEN_CLASSIFICATION: ["query"],
- TaskType.CONVERSATIONAL: [
- "text",
- "conversation_id",
- "past_user_inputs",
- "generated_responses",
- ],
- TaskType.TEXT2IMG: ["query"],
}
MII_CACHE_PATH = "MII_CACHE_PATH"
@@ -80,9 +56,12 @@ class ModelProvider(str, Enum):
CREATE_SESSION_METHOD = "CreateSession"
DESTROY_SESSION_METHOD = "DestroySession"
-LB_MAX_WORKER_THREADS = 32
+LB_MAX_WORKER_THREADS = 256
SERVER_SHUTDOWN_TIMEOUT = 10
RESTFUL_GATEWAY_SHUTDOWN_TIMEOUT = 1
RESTFUL_API_PATH = "mii"
+
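+# Timeout units: STREAM_RESPONSE_QUEUE_TIMEOUT is in seconds (used with queue.get),
+# ZMQ_RECV_TIMEOUT is in milliseconds (used with zmq's RCVTIMEO socket option).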
+STREAM_RESPONSE_QUEUE_TIMEOUT = 600
+ZMQ_RECV_TIMEOUT = 5 * 1000
diff --git a/mii/errors.py b/mii/errors.py
new file mode 100644
index 00000000..43050c53
--- /dev/null
+++ b/mii/errors.py
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+
+class DeploymentNotFoundError(Exception):
+ pass
diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py
index 0ffd4610..03719d53 100644
--- a/mii/grpc_related/modelresponse_server.py
+++ b/mii/grpc_related/modelresponse_server.py
@@ -7,11 +7,13 @@
import logging
import grpc
+
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
from .proto import modelresponse_pb2_grpc
import sys
import threading
import time
+import queue
from mii.constants import (
GRPC_MAX_MSG_SIZE,
@@ -20,12 +22,15 @@
TERMINATE_METHOD,
LB_MAX_WORKER_THREADS,
SERVER_SHUTDOWN_TIMEOUT,
+ STREAM_RESPONSE_QUEUE_TIMEOUT,
TaskType,
)
-from mii.method_table import GRPC_METHOD_TABLE
+from mii.task_methods import TASK_METHODS_DICT
from mii.client import create_channel
from mii.utils import unpack_proto_query_kwargs
+from mii.constants import GenerationFinishReason
+
class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer):
"""
@@ -46,37 +51,19 @@ class ModelResponse(ServiceBase):
"""
Implementation class of an MII inference server
"""
- def __init__(self, inference_pipeline):
+ def __init__(self, async_pipeline=None):
super().__init__()
- self.inference_pipeline = inference_pipeline
- self.method_name_to_task = {m.method: t for t, m in GRPC_METHOD_TABLE.items()}
+ self.inference_pipeline = async_pipeline
+ self.method_name_to_task = {m.method: t for t, m in TASK_METHODS_DICT.items()}
self.lock = threading.Lock()
- def _get_model_time(self, model, sum_times=False):
- model_times = []
- # Only grab model times if profiling was enabled/exists
- if getattr(model, "model_profile_enabled", False):
- model_times = model.model_times()
-
- if len(model_times) > 0:
- if sum_times:
- model_time = sum(model_times)
- else:
- # Unclear how to combine values, so just grab the most recent one
- model_time = model_times[-1]
- else:
- # no model times were captured
- model_time = -1
- return model_time
-
def CreateSession(self, request, context):
- task_methods = GRPC_METHOD_TABLE[TaskType.TEXT_GENERATION]
+ task_methods = TASK_METHODS_DICT[TaskType.TEXT_GENERATION]
task_methods.create_session(request.session_id)
return google_dot_protobuf_dot_empty__pb2.Empty()
def DestroySession(self, request, context):
- task_methods = GRPC_METHOD_TABLE[TaskType.TEXT_GENERATION]
- task_methods.destroy_session(request.session_id)
+ self.inference_pipeline.destroy_session(request.session_id)
return google_dot_protobuf_dot_empty__pb2.Empty()
def _run_inference(self, method_name, request_proto):
@@ -84,44 +71,54 @@ def _run_inference(self, method_name, request_proto):
raise ValueError(f"unknown method: {method_name}")
task = self.method_name_to_task[method_name]
- if task not in GRPC_METHOD_TABLE:
+ if task not in TASK_METHODS_DICT:
raise ValueError(f"unknown task: {task}")
- task_methods = GRPC_METHOD_TABLE[task]
+ task_methods = TASK_METHODS_DICT[task]
args, kwargs = task_methods.unpack_request_from_proto(request_proto)
+ session_id = kwargs.pop("session_id", None)
+
start = time.time()
- with self.lock:
- response = task_methods.run_inference(self.inference_pipeline, args, kwargs)
+ uid = self.inference_pipeline.put_request(args, kwargs, session_id)
+ response = self.inference_pipeline.get_response(uid)
end = time.time()
- model_time = (self._get_model_time(self.inference_pipeline.model,
- sum_times=True) if hasattr(
- self.inference_pipeline,
- "model") else -1)
+ if session_id is None:
+ self.inference_pipeline.destroy_session(session_id, uid)
- return task_methods.pack_response_to_proto(response, end - start, model_time)
+ return task_methods.pack_response_to_proto(response, end - start, -1)
def GeneratorReply(self, request, context):
return self._run_inference("GeneratorReply", request)
- def Txt2ImgReply(self, request, context):
- return self._run_inference("Txt2ImgReply", request)
-
- def ClassificationReply(self, request, context):
- return self._run_inference("ClassificationReply", request)
+ def _run_inference_stream(self, method_name, request_proto) -> int:
+ task = self.method_name_to_task[method_name]
+ task_methods = TASK_METHODS_DICT[task]
+ args, kwargs = task_methods.unpack_request_from_proto(request_proto)
- def QuestionAndAnswerReply(self, request, context):
- return self._run_inference("QuestionAndAnswerReply", request)
+ session_id = kwargs.pop("session_id", None)
+ kwargs["stream"] = True
+ return self.inference_pipeline.put_request(args, kwargs, session_id)
- def FillMaskReply(self, request, context):
- return self._run_inference("FillMaskReply", request)
+ def GeneratorReplyStream(self, request, context):
+ method_name = "GeneratorReply"
+ task = self.method_name_to_task[method_name]
+ task_methods = TASK_METHODS_DICT[task]
+ _, kwargs = task_methods.unpack_request_from_proto(request)
+ session_id = kwargs.pop("session_id", None)
- def TokenClassificationReply(self, request, context):
- return self._run_inference("TokenClassificationReply", request)
+ uid = self._run_inference_stream(method_name, request)
+ while True:
+ r = self.inference_pipeline.get_response(uid)
+ done = r[0].finish_reason != GenerationFinishReason.NONE
+ response = task_methods.pack_response_to_proto(r, 0.0, 0.0)
+ yield response
+ if done:
+ break
- def ConversationalReply(self, request, context):
- return self._run_inference("ConversationalReply", request)
+ if session_id is None:
+ self.inference_pipeline.destroy_session(session_id, uid)
class AtomicCounter:
@@ -135,6 +132,10 @@ def get_and_increment(self):
self.value += 1
return current_value
+ def get(self):
+ with self.lock:
+ return self.value
+
def _get_grpc_method_name(method):
return method.split("/")[-1]
@@ -170,6 +171,18 @@ def invoke(self, method_name, proto_request):
proto_request),
self.asyncio_loop).result()
+ def invoke_stream(self, method_name, proto_request, result_queue):
+ async def invoke_stream_async():
+ stub = self.stubs[0] # Only the first stub is used for streaming
+ method = getattr(stub, method_name)
+ response = method(proto_request)
+
+ async for r in response:
+ result_queue.put(r)
+
+ return asyncio.run_coroutine_threadsafe(invoke_stream_async(),
+ self.asyncio_loop).result()
+
class LoadBalancingInterceptor(grpc.ServerInterceptor):
def __init__(self, model_config):
@@ -197,7 +210,9 @@ def choose_stub(self, call_count):
def intercept_service(self, continuation, handler_call_details):
next_handler = continuation(handler_call_details)
- assert next_handler.unary_unary is not None
+
+ call_count = self.counter.get_and_increment()
+ replica_index = call_count % len(self.stubs)
def invoke_intercept_method(request_proto, context):
method_name = _get_grpc_method_name(handler_call_details.method)
@@ -209,7 +224,7 @@ def invoke_intercept_method(request_proto, context):
self.asyncio_loop.call_soon_threadsafe(self.asyncio_loop.stop)
return next_handler.unary_unary(request_proto, context)
- call_count = self.counter.get_and_increment()
+ call_count = self.counter.get()
replica_index = call_count % len(self.stubs)
if method_name == CREATE_SESSION_METHOD:
@@ -235,25 +250,53 @@ def invoke_intercept_method(request_proto, context):
ret = self.stubs[replica_index].invoke(method_name, request_proto)
return ret
- return grpc.unary_unary_rpc_method_handler(
- invoke_intercept_method,
- request_deserializer=next_handler.request_deserializer,
- response_serializer=next_handler.response_serializer,
- )
+ if next_handler.unary_unary is not None:
+ return grpc.unary_unary_rpc_method_handler(
+ invoke_intercept_method,
+ request_deserializer=next_handler.request_deserializer,
+ response_serializer=next_handler.response_serializer)
+ else:
+ method_name = _get_grpc_method_name(handler_call_details.method)
+ result_queue = queue.Queue()
+
+ def call_invoker(request_proto, context):
+ self.stubs[replica_index].invoke_stream(method_name,
+ request_proto,
+ result_queue)
+
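+ # Streaming RPCs are forwarded to the chosen replica on a background thread; partial
+ # responses are relayed through a local queue and yielded to the caller until a
+ # terminal finish_reason is seen or the queue times out.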
+ def invoke_intercept_method_stream(request_proto, context):
+ threading.Thread(target=call_invoker,
+ args=(request_proto,
+ context)).start()
+ while True:
+ try:
+ response_proto = result_queue.get(
+ timeout=STREAM_RESPONSE_QUEUE_TIMEOUT)
+ yield response_proto
+ if response_proto.details[0].finish_reason != str(
+ GenerationFinishReason.NONE):
+ break
+ except queue.Empty:
+ print(
+ f"Haven't received a streaming response in {STREAM_RESPONSE_QUEUE_TIMEOUT} second(s)"
+ )
+ break
+
+ return grpc.unary_stream_rpc_method_handler(
+ invoke_intercept_method_stream,
+ request_deserializer=next_handler.request_deserializer,
+ response_serializer=next_handler.response_serializer)
def _do_serve(service_impl, port, interceptors=[]):
stop_event = service_impl.get_stop_event()
- server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=LB_MAX_WORKER_THREADS),
- interceptors=interceptors,
- options=[
- ("grpc.max_send_message_length",
- GRPC_MAX_MSG_SIZE),
- ("grpc.max_receive_message_length",
- GRPC_MAX_MSG_SIZE),
- ],
- )
+
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=LB_MAX_WORKER_THREADS),
+ interceptors=interceptors,
+ options=[("grpc.max_send_message_length",
+ GRPC_MAX_MSG_SIZE),
+ ("grpc.max_receive_message_length",
+ GRPC_MAX_MSG_SIZE)])
modelresponse_pb2_grpc.add_ModelResponseServicer_to_server(service_impl, server)
server.add_insecure_port(f"[::]:{port}")
print(f"About to start server")
@@ -263,8 +306,10 @@ def _do_serve(service_impl, port, interceptors=[]):
server.stop(SERVER_SHUTDOWN_TIMEOUT)
-def serve_inference(inference_pipeline, port):
- _do_serve(ModelResponse(inference_pipeline), port)
+def serve_inference(async_pipeline, port):
+ async_pipeline.start()
+ _do_serve(ModelResponse(async_pipeline=async_pipeline), port)
+ async_pipeline.shutdown()
def serve_load_balancing(model_config, lb_port):
diff --git a/mii/grpc_related/proto/modelresponse.proto b/mii/grpc_related/proto/modelresponse.proto
index c8de6d14..c2d0899f 100644
--- a/mii/grpc_related/proto/modelresponse.proto
+++ b/mii/grpc_related/proto/modelresponse.proto
@@ -34,6 +34,12 @@ service ModelResponse {
rpc TokenClassificationReply(SingleStringRequest) returns (SingleStringReply) {}
rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {}
rpc Txt2ImgReply(MultiStringRequest) returns (ImageReply) {}
+
+ rpc GeneratorReplyStream (MultiStringRequest) returns (stream GenerationReply) {}
+}
+
+message Dictionary {
+ map<string, Value> values = 1;
}
message Value {
@@ -42,6 +48,7 @@ message Value {
int64 ivalue = 2;
float fvalue = 3;
bool bvalue = 4;
+ Dictionary mvalue = 5;
}
}
@@ -71,6 +78,23 @@ message MultiStringReply {
float model_time_taken = 3;
}
+message GenerationDetails {
+ string finish_reason = 1;
+ int64 prompt_tokens = 2;
+ int64 generated_tokens = 3;
+}
+
+message GenerationReply {
+ repeated string response = 1;
+ // A request may contain multiple prompts and they produce different number of tokens.
+ // When streaming output is enabled, a response may contain generated tokens only for some prompts.
+ // `indices` represents the indices of prompts for which `response` and `details` are provided.
+ repeated int64 indices = 2;
+ repeated GenerationDetails details = 3;
+ float time_taken = 4;
+ float model_time_taken = 5;
+}
+
message QARequest {
string question = 1;
string context = 2;
diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py
index 4c442c49..6b5294f7 100644
--- a/mii/grpc_related/proto/modelresponse_pb2.py
+++ b/mii/grpc_related/proto/modelresponse_pb2.py
@@ -2,15 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
-
+# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: modelresponse.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import message as _message
-from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -18,192 +17,16 @@
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x88\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a 
.modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3'
+ b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"\x88\x01\n\nDictionary\x12\x35\n\x06values\x18\x01 \x03(\x0b\x32%.modelresponse.Dictionary.ValuesEntry\x1a\x43\n\x0bValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x8c\x01\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x12+\n\x06mvalue\x18\x05 \x01(\x0b\x32\x19.modelresponse.DictionaryH\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"[\n\x11GenerationDetails\x12\x15\n\rfinish_reason\x18\x01 \x01(\t\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x03\x12\x18\n\x10generated_tokens\x18\x03 \x01(\x03\"\x95\x01\n\x0fGenerationReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x0f\n\x07indices\x18\x02 \x03(\x03\x12\x31\n\x07\x64\x65tails\x18\x03 \x03(\x0b\x32 .modelresponse.GenerationDetails\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x88\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 
\x01(\x02\x32\xb3\x07\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x12]\n\x14GeneratorReplyStream\x12!.modelresponse.MultiStringRequest\x1a\x1e.modelresponse.GenerationReply\"\x00\x30\x01\x62\x06proto3'
)
-_VALUE = DESCRIPTOR.message_types_by_name['Value']
-_SESSIONID = DESCRIPTOR.message_types_by_name['SessionID']
-_SINGLESTRINGREQUEST = DESCRIPTOR.message_types_by_name['SingleStringRequest']
-_SINGLESTRINGREQUEST_QUERYKWARGSENTRY = _SINGLESTRINGREQUEST.nested_types_by_name[
- 'QueryKwargsEntry']
-_MULTISTRINGREQUEST = DESCRIPTOR.message_types_by_name['MultiStringRequest']
-_MULTISTRINGREQUEST_QUERYKWARGSENTRY = _MULTISTRINGREQUEST.nested_types_by_name[
- 'QueryKwargsEntry']
-_SINGLESTRINGREPLY = DESCRIPTOR.message_types_by_name['SingleStringReply']
-_MULTISTRINGREPLY = DESCRIPTOR.message_types_by_name['MultiStringReply']
-_QAREQUEST = DESCRIPTOR.message_types_by_name['QARequest']
-_QAREQUEST_QUERYKWARGSENTRY = _QAREQUEST.nested_types_by_name['QueryKwargsEntry']
-_CONVERSATIONREQUEST = DESCRIPTOR.message_types_by_name['ConversationRequest']
-_CONVERSATIONREQUEST_QUERYKWARGSENTRY = _CONVERSATIONREQUEST.nested_types_by_name[
- 'QueryKwargsEntry']
-_CONVERSATIONREPLY = DESCRIPTOR.message_types_by_name['ConversationReply']
-_IMAGEREPLY = DESCRIPTOR.message_types_by_name['ImageReply']
-Value = _reflection.GeneratedProtocolMessageType(
- 'Value',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _VALUE,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.Value)
- })
-_sym_db.RegisterMessage(Value)
-
-SessionID = _reflection.GeneratedProtocolMessageType(
- 'SessionID',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _SESSIONID,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.SessionID)
- })
-_sym_db.RegisterMessage(SessionID)
-
-SingleStringRequest = _reflection.GeneratedProtocolMessageType(
- 'SingleStringRequest',
- (_message.Message,
- ),
- {
- 'QueryKwargsEntry':
- _reflection.GeneratedProtocolMessageType(
- 'QueryKwargsEntry',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _SINGLESTRINGREQUEST_QUERYKWARGSENTRY,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.SingleStringRequest.QueryKwargsEntry)
- }),
- 'DESCRIPTOR':
- _SINGLESTRINGREQUEST,
- '__module__':
- 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.SingleStringRequest)
- })
-_sym_db.RegisterMessage(SingleStringRequest)
-_sym_db.RegisterMessage(SingleStringRequest.QueryKwargsEntry)
-
-MultiStringRequest = _reflection.GeneratedProtocolMessageType(
- 'MultiStringRequest',
- (_message.Message,
- ),
- {
- 'QueryKwargsEntry':
- _reflection.GeneratedProtocolMessageType(
- 'QueryKwargsEntry',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _MULTISTRINGREQUEST_QUERYKWARGSENTRY,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.MultiStringRequest.QueryKwargsEntry)
- }),
- 'DESCRIPTOR':
- _MULTISTRINGREQUEST,
- '__module__':
- 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.MultiStringRequest)
- })
-_sym_db.RegisterMessage(MultiStringRequest)
-_sym_db.RegisterMessage(MultiStringRequest.QueryKwargsEntry)
-
-SingleStringReply = _reflection.GeneratedProtocolMessageType(
- 'SingleStringReply',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _SINGLESTRINGREPLY,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.SingleStringReply)
- })
-_sym_db.RegisterMessage(SingleStringReply)
-
-MultiStringReply = _reflection.GeneratedProtocolMessageType(
- 'MultiStringReply',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _MULTISTRINGREPLY,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.MultiStringReply)
- })
-_sym_db.RegisterMessage(MultiStringReply)
-
-QARequest = _reflection.GeneratedProtocolMessageType(
- 'QARequest',
- (_message.Message,
- ),
- {
- 'QueryKwargsEntry':
- _reflection.GeneratedProtocolMessageType(
- 'QueryKwargsEntry',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _QAREQUEST_QUERYKWARGSENTRY,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.QARequest.QueryKwargsEntry)
- }),
- 'DESCRIPTOR':
- _QAREQUEST,
- '__module__':
- 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.QARequest)
- })
-_sym_db.RegisterMessage(QARequest)
-_sym_db.RegisterMessage(QARequest.QueryKwargsEntry)
-
-ConversationRequest = _reflection.GeneratedProtocolMessageType(
- 'ConversationRequest',
- (_message.Message,
- ),
- {
- 'QueryKwargsEntry':
- _reflection.GeneratedProtocolMessageType(
- 'QueryKwargsEntry',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _CONVERSATIONREQUEST_QUERYKWARGSENTRY,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.ConversationRequest.QueryKwargsEntry)
- }),
- 'DESCRIPTOR':
- _CONVERSATIONREQUEST,
- '__module__':
- 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.ConversationRequest)
- })
-_sym_db.RegisterMessage(ConversationRequest)
-_sym_db.RegisterMessage(ConversationRequest.QueryKwargsEntry)
-
-ConversationReply = _reflection.GeneratedProtocolMessageType(
- 'ConversationReply',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _CONVERSATIONREPLY,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.ConversationReply)
- })
-_sym_db.RegisterMessage(ConversationReply)
-
-ImageReply = _reflection.GeneratedProtocolMessageType(
- 'ImageReply',
- (_message.Message,
- ),
- {
- 'DESCRIPTOR': _IMAGEREPLY,
- '__module__': 'modelresponse_pb2'
- # @@protoc_insertion_point(class_scope:modelresponse.ImageReply)
- })
-_sym_db.RegisterMessage(ImageReply)
-
-_MODELRESPONSE = DESCRIPTOR.services_by_name['ModelResponse']
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
-
DESCRIPTOR._options = None
+ _DICTIONARY_VALUESENTRY._options = None
+ _DICTIONARY_VALUESENTRY._serialized_options = b'8\001'
_SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None
_SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
_MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None
@@ -212,34 +35,42 @@
_QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
_CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None
_CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
- _VALUE._serialized_start = 67
- _VALUE._serialized_end = 162
- _SESSIONID._serialized_start = 164
- _SESSIONID._serialized_end = 195
- _SINGLESTRINGREQUEST._serialized_start = 198
- _SINGLESTRINGREQUEST._serialized_end = 385
- _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_start = 313
- _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_end = 385
- _MULTISTRINGREQUEST._serialized_start = 388
- _MULTISTRINGREQUEST._serialized_end = 573
- _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_start = 313
- _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_end = 385
- _SINGLESTRINGREPLY._serialized_start = 575
- _SINGLESTRINGREPLY._serialized_end = 658
- _MULTISTRINGREPLY._serialized_start = 660
- _MULTISTRINGREPLY._serialized_end = 742
- _QAREQUEST._serialized_start = 745
- _QAREQUEST._serialized_end = 930
- _QAREQUEST_QUERYKWARGSENTRY._serialized_start = 313
- _QAREQUEST_QUERYKWARGSENTRY._serialized_end = 385
- _CONVERSATIONREQUEST._serialized_start = 933
- _CONVERSATIONREQUEST._serialized_end = 1197
- _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_start = 313
- _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_end = 385
- _CONVERSATIONREPLY._serialized_start = 1200
- _CONVERSATIONREPLY._serialized_end = 1345
- _IMAGEREPLY._serialized_start = 1347
- _IMAGEREPLY._serialized_end = 1472
- _MODELRESPONSE._serialized_start = 1475
- _MODELRESPONSE._serialized_end = 2327
+ _globals['_DICTIONARY']._serialized_start = 68
+ _globals['_DICTIONARY']._serialized_end = 204
+ _globals['_DICTIONARY_VALUESENTRY']._serialized_start = 137
+ _globals['_DICTIONARY_VALUESENTRY']._serialized_end = 204
+ _globals['_VALUE']._serialized_start = 207
+ _globals['_VALUE']._serialized_end = 347
+ _globals['_SESSIONID']._serialized_start = 349
+ _globals['_SESSIONID']._serialized_end = 380
+ _globals['_SINGLESTRINGREQUEST']._serialized_start = 383
+ _globals['_SINGLESTRINGREQUEST']._serialized_end = 570
+ _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498
+ _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570
+ _globals['_MULTISTRINGREQUEST']._serialized_start = 573
+ _globals['_MULTISTRINGREQUEST']._serialized_end = 758
+ _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498
+ _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570
+ _globals['_SINGLESTRINGREPLY']._serialized_start = 760
+ _globals['_SINGLESTRINGREPLY']._serialized_end = 843
+ _globals['_MULTISTRINGREPLY']._serialized_start = 845
+ _globals['_MULTISTRINGREPLY']._serialized_end = 927
+ _globals['_GENERATIONDETAILS']._serialized_start = 929
+ _globals['_GENERATIONDETAILS']._serialized_end = 1020
+ _globals['_GENERATIONREPLY']._serialized_start = 1023
+ _globals['_GENERATIONREPLY']._serialized_end = 1172
+ _globals['_QAREQUEST']._serialized_start = 1175
+ _globals['_QAREQUEST']._serialized_end = 1360
+ _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start = 498
+ _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end = 570
+ _globals['_CONVERSATIONREQUEST']._serialized_start = 1363
+ _globals['_CONVERSATIONREQUEST']._serialized_end = 1627
+ _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 498
+ _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 570
+ _globals['_CONVERSATIONREPLY']._serialized_start = 1630
+ _globals['_CONVERSATIONREPLY']._serialized_end = 1775
+ _globals['_IMAGEREPLY']._serialized_start = 1777
+ _globals['_IMAGEREPLY']._serialized_end = 1902
+ _globals['_MODELRESPONSE']._serialized_start = 1905
+ _globals['_MODELRESPONSE']._serialized_end = 2852
# @@protoc_insertion_point(module_scope)
diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py
index 2735b034..4f16a368 100644
--- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py
+++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py
@@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
-
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
@@ -70,6 +69,11 @@ def __init__(self, channel):
request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString,
response_deserializer=modelresponse__pb2.ImageReply.FromString,
)
+ self.GeneratorReplyStream = channel.unary_stream(
+ '/modelresponse.ModelResponse/GeneratorReplyStream',
+ request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString,
+ response_deserializer=modelresponse__pb2.GenerationReply.FromString,
+ )
class ModelResponseServicer(object):
@@ -136,6 +140,12 @@ def Txt2ImgReply(self, request, context):
context.set_details(self.ERROR_MSG)
raise NotImplementedError(self.ERROR_MSG)
+ def GeneratorReplyStream(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
def add_ModelResponseServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -202,6 +212,12 @@ def add_ModelResponseServicer_to_server(servicer, server):
request_deserializer=modelresponse__pb2.MultiStringRequest.FromString,
response_serializer=modelresponse__pb2.ImageReply.SerializeToString,
),
+ 'GeneratorReplyStream':
+ grpc.unary_stream_rpc_method_handler(
+ servicer.GeneratorReplyStream,
+ request_deserializer=modelresponse__pb2.MultiStringRequest.FromString,
+ response_serializer=modelresponse__pb2.GenerationReply.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse',
rpc_method_handlers)
@@ -470,3 +486,29 @@ def Txt2ImgReply(request,
wait_for_ready,
timeout,
metadata)
+
+ @staticmethod
+ def GeneratorReplyStream(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(
+ request,
+ target,
+ '/modelresponse.ModelResponse/GeneratorReplyStream',
+ modelresponse__pb2.MultiStringRequest.SerializeToString,
+ modelresponse__pb2.GenerationReply.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
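The hunk above adds a server-streaming RPC, `GeneratorReplyStream`, alongside the existing unary methods. As a rough illustration of how a client could consume it (not part of this diff; the host, port, and prompt are placeholders, and the exact payload carried by each streamed `GenerationReply` depends on the server implementation):

```python
import grpc

from mii.grpc_related.proto import modelresponse_pb2
from mii.grpc_related.proto import modelresponse_pb2_grpc


def stream_generation(prompt: str, host: str = "localhost", port: int = 50050) -> None:
    # Open a synchronous channel to a running MII gRPC server (address is illustrative).
    with grpc.insecure_channel(f"{host}:{port}") as channel:
        stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
        request = modelresponse_pb2.MultiStringRequest(request=[prompt])
        # GeneratorReplyStream is unary-stream: one request in, a stream of
        # GenerationReply messages out; iterate until the server closes the stream.
        for reply in stub.GeneratorReplyStream(request):
            print(reply.response[0], end="", flush=True)


if __name__ == "__main__":
    stream_generation("DeepSpeed is")
```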
diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py
index 590a4ed0..15814b07 100644
--- a/mii/launch/multi_gpu_server.py
+++ b/mii/launch/multi_gpu_server.py
@@ -2,18 +2,18 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
-import os
import argparse
import base64
import json
+import os
from mii.config import ModelConfig
-from mii.models.load_models import load_models
from mii.grpc_related.modelresponse_server import serve_inference, serve_load_balancing
from mii.grpc_related.restful_gateway import RestfulGatewayThread
+from mii.pipeline import async_pipeline
-def b64_encoded_config(config_str):
+def b64_encoded_config(config_str: str) -> ModelConfig:
# str -> bytes
b64_bytes = config_str.encode()
# decode b64 bytes -> json bytes
@@ -24,7 +24,7 @@ def b64_encoded_config(config_str):
return ModelConfig(**config_dict)
-def main():
+def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--deployment-name", type=str, help="Name of deployment")
parser.add_argument(
@@ -38,6 +38,7 @@ def main():
default=0,
help="Port to user for DeepSpeed inference server.",
)
+ parser.add_argument("--zmq-port", type=int, default=0, help="Port to use for ZMQ.")
parser.add_argument("--load-balancer",
action="store_true",
help="Launch load balancer process.")
@@ -85,9 +86,8 @@ def main():
assert args.server_port, "--server-port must be provided."
local_rank = int(os.getenv("LOCAL_RANK", "0"))
port = args.server_port + local_rank
-
- inference_pipeline = load_models(args.model_config)
-
+ args.model_config.zmq_port_number = args.zmq_port
+ inference_pipeline = async_pipeline(args.model_config)
print(f"Starting server on port: {port}")
serve_inference(inference_pipeline, port)
diff --git a/mii/legacy/README.md b/mii/legacy/README.md
new file mode 100644
index 00000000..ed949a1c
--- /dev/null
+++ b/mii/legacy/README.md
@@ -0,0 +1,359 @@
+
+[![Formatting](https://github.com/microsoft/DeepSpeed-MII/actions/workflows/formatting.yml/badge.svg)](https://github.com/microsoft/DeepSpeed-MII/actions/workflows/formatting.yml)
+[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
+[![PyPI version](https://badge.fury.io/py/deepspeed-mii.svg)](https://pypi.org/project/deepspeed-mii/)
+
+
+
+
+
+
+
+## Latest News
+
+* [2022/11] [Stable Diffusion Image Generation under 1 second w. DeepSpeed MII](examples/benchmark/txt2img)
+* [2022/10] [Announcing DeepSpeed Model Implementations for Inference (MII)](https://www.deepspeed.ai/2022/10/10/mii.html)
+
+# Contents
+
+
+
+- [DeepSpeed MII](#deepspeed-model-implementations-for-inference)
+- [How does MII work?](#how-does-mii-work)
+- [Supported Models and Tasks](#supported-models-and-tasks)
+- [MII-Public and MII-Azure](#mii-public-and-mii-azure)
+- [Getting started with MII](#getting-started-with-mii)
+- [Quantifying Latency and Cost Reduction](#quantifying-latency-and-cost-reduction)
+- [Community Tutorials](#community-tutorials)
+
+
+
+# DeepSpeed Model Implementations for Inference
+
+![hero dark](docs/images/hero-dark.png#gh-dark-mode-only)
+![hero light](docs/images/hero-transparent.png#gh-light-mode-only)
+
+The Deep Learning (DL) open-source community has seen tremendous growth in the last few months. Incredibly powerful text generation models such as Bloom 176B, or image generation models such as Stable Diffusion, are now available to anyone with access to a handful of GPUs, or even a single one, through platforms such as Hugging Face. While open sourcing has democratized access to AI capabilities, their application is still restricted by two critical factors: inference latency and cost.
+
+There has been significant progress in system optimizations for DL model inference that can drastically reduce both latency and cost, but those are not easily accessible. A main reason for this limited accessibility is that the DL model inference landscape is diverse with models varying in size, architecture, system performance characteristics, hardware requirements, etc. Identifying the appropriate set of system optimizations applicable to a given model and applying them correctly is often beyond the scope of most data scientists, making low latency and low-cost inference mostly inaccessible.
+
+DeepSpeed-MII is a new open-source Python library from DeepSpeed, aimed at making low-latency, low-cost inference of powerful models not only feasible but also easily accessible.
+
+* MII offers access to highly optimized implementations of thousands of widely used DL models.
+* MII-supported models achieve significantly lower latency and cost compared to their original implementations. For example, MII reduces the latency of the Big Science Bloom 176B model by 5.7x, while reducing the cost by over 40x. Similarly, it reduces the latency and cost of deploying Stable Diffusion by 1.9x. See [an exhaustive latency and cost analysis of MII](#quantifying-latency-and-cost-reduction) for more details.
+* To enable low latency/cost inference, MII leverages an extensive set of optimizations from DeepSpeed-Inference such as DeepFusion for transformers, automated tensor-slicing for multi-GPU inference, on-the-fly quantization with ZeroQuant, and several others (see our [blog post](https://www.deepspeed.ai/2022/10/10/mii.html) for more details).
+* With state-of-the-art performance, MII supports low-cost deployment of these models both on-premises and on Azure via AML with just a few lines of code.
+
+# How does MII work?
+
+![Text Generation Models](docs/images/mii-arch.png)
+
+*Figure 1: MII Architecture, showing how MII automatically optimizes OSS models using DS-Inference before deploying them on-premises using GRPC, or on Microsoft Azure using AML Inference.*
+
+Under the hood, MII is powered by [DeepSpeed-Inference](https://arxiv.org/abs/2207.00032). Based on the model type, model size, batch size, and available hardware resources, MII automatically applies the appropriate set of system optimizations from DeepSpeed-Inference to minimize latency and maximize throughput. It does so by using one of many pre-specified model injection policies that allow MII and DeepSpeed-Inference to identify the underlying PyTorch model architecture and replace it with an optimized implementation (see *Figure 1*). In doing so, MII makes the expansive set of optimizations in DeepSpeed-Inference automatically available to the thousands of popular models it supports.
+
+
+# Supported Models and Tasks
+
+MII currently supports over 50,000 models across a range of tasks such as text generation, question answering, and text classification. The models accelerated by MII are available through multiple open-source model repositories such as Hugging Face, FairSeq, EleutherAI, etc. We support dense models based on BERT, RoBERTa, or GPT architectures, ranging from a few hundred million parameters to tens of billions of parameters in size. We continue to expand the list, with support for massive hundred-billion-plus-parameter dense and sparse models coming soon.
+
+MII model support will continue to grow over time, check back for updates! Currently we support the following Hugging Face Transformers model families:
+
+model family | size range | ~model count
+------ | ------ | ------
+[llama](https://huggingface.co/models?other=llama) | 7B - 65B | 1,500
+[bloom](https://huggingface.co/models?other=bloom) | 0.3B - 176B | 480
+[stable-diffusion](https://huggingface.co/models?other=stable-diffusion) | 1.1B | 3,700
+[opt](https://huggingface.co/models?other=opt) | 0.1B - 66B | 460
+[gpt\_neox](https://huggingface.co/models?other=gpt_neox) | 1.3B - 20B | 850
+[gptj](https://huggingface.co/models?other=gptj) | 1.4B - 6B | 420
+[gpt\_neo](https://huggingface.co/models?other=gpt_neo) | 0.1B - 2.7B | 700
+[gpt2](https://huggingface.co/models?other=gpt2) | 0.3B - 1.5B | 11,900
+[xlm-roberta](https://huggingface.co/models?other=xlm-roberta) | 0.1B - 0.3B | 4,100
+[roberta](https://huggingface.co/models?other=roberta) | 0.1B - 0.3B | 8,700
+[distilbert](https://huggingface.co/models?other=distilbert) | 0.1B - 0.3B | 4,700
+[bert](https://huggingface.co/models?other=bert) | 0.1B - 0.3B | 23,600
+
+
+
+
+
+# MII-Public and MII-Azure
+
+MII can work with two variations of DeepSpeed-Inference. The first, referred to as ds-public, contains most of the DeepSpeed-Inference optimizations discussed here and is available via our open-source DeepSpeed library. The second, referred to as ds-azure, offers tighter integration with Azure and is available via MII to all Microsoft Azure customers. We refer to MII running the two DeepSpeed-Inference variants as MII-Public and MII-Azure, respectively.
+
+While both variants offer significant latency and cost reductions over the open-source PyTorch baseline, the latter offers an additional performance advantage for generation-based workloads. The full latency and cost comparison with the PyTorch baseline across these two versions is available [here](#quantifying-latency-and-cost-reduction).
+
+# Getting Started with MII
+
+## Installation
+
+We regularly push releases to [PyPI](https://pypi.org/project/deepspeed-mii/) and encourage users to install from there in most cases.
+
+```bash
+pip install deepspeed-mii
+```
+
+## Deploying MII-Public
+
+MII-Public can be deployed on-premises or on any cloud offering with just a few lines of code. MII creates a lightweight GRPC server to support this form of deployment and provides a GRPC inference endpoint for queries.
+
+Several deployment and query examples can be found here: [examples/local](examples/local)
+
+As an example here is a deployment of the [bigscience/bloom-560m](https://huggingface.co/bigscience/bloom-560m) model from Hugging Face:
+
+**Deployment**
+```python
+import mii
+mii_configs = {"tensor_parallel": 1, "dtype": "fp16"}
+mii.deploy(task="text-generation",
+ model="bigscience/bloom-560m",
+ deployment_name="bloom560m_deployment",
+ mii_config=mii_configs)
+```
+
+This will deploy the model onto a single GPU and start the GRPC server that can later be queried.
+
+**Query**
+```python
+import mii
+generator = mii.mii_query_handle("bloom560m_deployment")
+result = generator.query({"query": ["DeepSpeed is", "Seattle is"]}, do_sample=True, max_new_tokens=30)
+print(result)
+```
+
+The only required key is `"query"`; all other items outside the dictionary will be passed to `generate` as kwargs. For Hugging Face provided models, you can find all possible arguments in their [documentation for generate](https://huggingface.co/docs/transformers/v4.20.1/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate).
+
+**Shutdown Deployment**
+```python
+import mii
+mii.terminate("bloom560m_deployment")
+```
+
+**Load balancing over multiple replicas**
+
+You can launch a load balancer and multiple replicas of the MII server.
+When you specify a value for `replica_num`, `mii.deploy()` launches the load balancer and `replica_num` replicas.
+Note that each replica consists of `tensor_parallel` server processes deployed on the same server.
+
+```python
+mii_configs = {
+...
+ "tensor_parallel": tensor_parallel,
+ "replica_num": replica_num,
+ "hostfile": hostfile
+}
+mii.deploy(...
+ mii_config=mii_configs,
+ ...)
+```
+
+The client sends requests to the load balancer, which forwards them to the replicas, instead of sending requests to individual MII servers.
+Currently, the load balancer implements a simple round-robin algorithm.
+The load balancer acts as a simple proxy when `replica_num` is set to `1`.
+
+`hostfile` is the path to the hostfile used by DeepSpeed's launcher.
+When no hostfile is specified, DeepSpeed-MII falls back to DeepSpeed's default path, `/job/hostfile`.
+See the [DeepSpeed documentation](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) for details.
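+
+Putting these pieces together, a minimal sketch of a multi-replica deployment (the model, replica count, and hostfile path below are illustrative):
+
+```python
+import mii
+
+mii_configs = {
+    "tensor_parallel": 1,         # GPUs per replica
+    "replica_num": 4,             # replicas launched behind the load balancer
+    "hostfile": "/job/hostfile",  # DeepSpeed-style hostfile listing the available nodes
+}
+mii.deploy(task="text-generation",
+           model="bigscience/bloom-560m",
+           deployment_name="bloom560m_deployment",
+           mii_config=mii_configs)
+```
+
+Queries still go through `mii.mii_query_handle("bloom560m_deployment")`; the handle talks to the load balancer, which distributes requests across the replicas.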
+
+**RESTful API support**
+
+MII allows users to call the inference service through a RESTful API.
+By setting `enable_restful_api` to `True`, `mii.deploy()` launches a gateway that accepts RESTful API calls.
+The gateway receives requests at `http://[HOST]:[PORT_FOR_RESTFUL_API]/mii/[DEPLOYMENT_NAME]`.
+
+```python
+mii_configs = {
+...
+ "enable_restful_api": True,
+ "restful_api_port": PORT_FOR_RESTFUL_API,
+...
+}
+mii.deploy(...
+ deployment_name=DEPLOYMENT_NAME,
+ mii_config=mii_configs)
+```
+
+**Non-persistent Deployment**
+
+You can enable a non-persistent deployment, which allows you to make queries without standing up a server. The non-persistent deployment acts as a simplified interface to DeepSpeed-Inference for use cases that do not require a persistent model server process. Changing the `deployment_type` to `NON_PERSISTENT` in `mii.deploy(...)` activates this option.
+
+```python
+...
+mii.deploy(deployment_name=DEPLOYMENT_NAME,
+           deployment_type=mii.constants.DeploymentType.NON_PERSISTENT,
+ ...
+ )
+
+generator = mii.mii_query_handle(DEPLOYMENT_NAME)
+result = generator.query({"query": ["DeepSpeed is", "Seattle is"]}, do_sample=True, max_new_tokens=30)
+
+```
+
+You can find a complete example [here](https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/non_persistent).
+
+Any HTTP client can be used to call the RESTful API described above. An example using curl:
+```bash
+# Assume deployment_name and restful_api_port are set to bloom560m_deployment and 28080 respectively:
+$ curl --header "Content-Type: application/json" --request POST -d '{"request": {"query": ["Seattle is", "Bellevue is", "Redmond is"]}, "kwargs": {"do_sample": false, "max_new_tokens": 100}}' http://localhost:28080/mii/bloom560m_deployment
+```
+
+The code below is an example using Python.
+
+```python
+import requests
+import json
+
+# text_generation
+url = 'http://localhost:28080/mii/bloom560m_deployment'
+params = {"request": {"query": ["Seattle is", "Bellevue is", "Redmond is"]},
+ "kwargs": {"do_sample": False, "max_new_tokens": 100}}
+
+json_params = json.dumps(params)
+response = requests.post(url, data=json_params, headers={
+ "Content-Type": "application/json"})
+print(response.json())
+```
+
+## Deploying with MII-Azure
+
+MII supports deployment on Azure via AML Inference. To enable this, MII generates AML deployment assets for a given model that can be deployed using the Azure-CLI, as shown in the code below. Furthermore, deploying on Azure allows MII to leverage DeepSpeed-Azure as its optimization backend, which offers better latency and cost reduction than DeepSpeed-Public.
+
+This deployment process is very similar to local deployments and we will modify the code from the local deployment example with the [bigscience/bloom-560m](https://huggingface.co/bigscience/bloom-560m) model.
+
+---
+📌 **Note:** MII-Azure has the benefit of supporting DeepSpeed-Azure for better latency and cost than DeepSpeed-Public for certain workloads. We are working to enable DeepSpeed-Azure automatically for all MII-Azure deployments in a near-term MII update. In the meantime, we are offering DeepSpeed-Azure as a preview release to MII-Azure users. If you have a MII-Azure deployment and would like to try DeepSpeed-Azure, please reach out to us at deepspeed-mii@microsoft.com to get access.
+
+---
+
+Several other AML deployment examples can be found here: [examples/aml](examples/aml)
+
+**Setup**
+
+To use MII on AML resources, you must have the Azure-CLI installed with an active login associated with your Azure resources. Follow the instructions below to get your local system ready for deploying on AML resources:
+
+1. Install Azure-CLI. Follow the official [installation instructions](https://learn.microsoft.com/en-us/cli/azure/install-azure-cli#install).
+2. Run `az login` and follow the instructions to login to your Azure account. This account should be linked to the resources you plan to deploy on.
+3. Set the default subscription with `az account set --subscription <your-subscription-id>`. You can find your subscription ID in the "overview" tab on your resource group page from the Azure web portal.
+4. Set the default resource group and workspace name with `az config set defaults.group=<your-resource-group> defaults.workspace=<your-workspace>`
+5. Install the AML plugin for Azure-CLI with `az extension add --name ml`
+
+**Deployment**
+```python
+import mii
+mii_configs = {"tensor_parallel": 1, "dtype": "fp16"}
+mii.deploy(task="text-generation",
+ model="bigscience/bloom-560m",
+ deployment_name="bloom560m-deployment",
+ deployment_type=mii.constants.DeploymentType.AML,
+ mii_config=mii_configs)
+```
+
+---
+📌 **Note:** Running `mii.deploy` with `deployment_type=mii.constants.DeploymentType.AML` will only generate the scripts needed to launch an AML deployment. You must also run the generated `deploy.sh` script to deploy on AML resources.
+
+---
+
+This will generate the scripts and configuration files necessary to deploy the model on AML using a single GPU. You can find the generated output at `./bloom560m-deployment_aml/`
+
+When you are ready to run your deployment on AML resources, navigate to the newly created directory and run the deployment script:
+```bash
+cd ./bloom560m-deployment_aml/
+bash deploy.sh
+```
+
+This script may take several minutes to run as it does the following:
+- Downloads the model locally
+- Creates a Docker Image with MII for your deployment
+- Creates an AML online-endpoint for running queries
+- Uploads and registers the model to AML
+- Starts your deployment
+
+---
+📌 **Note:** Large models (e.g., `bigscience/bloom`) may cause a timeout when trying to upload and register the model to AML. In these cases, it is required to manually upload models to Azure blob storage with [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10). Instructions and automation of this step will be added soon.
+
+---
+
+**Query**
+Once the deployment is running on AML, you can run queries by navigating to the online-endpoint that was created for this deployment (i.e., `bloom560m-deployment-endpoint`) from the [AML web portal](https://ml.azure.com/endpoints). Select the "Test" tab at the top of the endpoint page and type your query into the text box:
+```
+{"query": ["DeepSpeed is", "Seattle is"], "do_sample"=True, "max_new_tokens"=30}
+```
+
+The only required key is `"query"`; all other items in the dictionary will be passed to `generate` as kwargs. For Hugging Face provided models, you can find all possible arguments in their [documentation for generate](https://huggingface.co/docs/transformers/v4.20.1/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate).
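+
+Outside the portal, the AML endpoint can also be queried over HTTPS from any client. Below is a rough sketch using Python `requests`; the scoring URI and key are placeholders that should be copied from the endpoint's "Consume" tab in the AML portal:
+
+```python
+import json
+import requests
+
+scoring_uri = "https://<your-endpoint>.<region>.inference.ml.azure.com/score"  # placeholder
+api_key = "<your-endpoint-key>"  # placeholder
+
+payload = {"query": ["DeepSpeed is", "Seattle is"], "do_sample": True, "max_new_tokens": 30}
+headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+
+response = requests.post(scoring_uri, data=json.dumps(payload), headers=headers)
+print(response.json())
+```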
+
+# Quantifying Latency and Cost Reduction
+
+Inference workloads can be either latency critical, where the primary objective is to minimize latency, or cost sensitive, where the primary objective is to minimize cost. In this section, we quantify the benefits of using MII for both latency-critical and cost-sensitive scenarios.
+
+## Latency Critical Scenarios
+
+For latency-critical scenarios, where a small batch size of 1 is often used, MII can reduce the latency by up to 6x for a wide range of open-source models, across multiple tasks. More specifically, we show model latency reduction of [^overhead_details]:
+
+1. Up to 5.7x for multi-GPU inference for text generation using massive models such as Big Science Bloom, Facebook OPT, and EleutherAI NeoX (*Figure 2 (left)*)
+
+2. Up to 1.9x for image generation tasks using Stable Diffusion (*Figure 2 (right)*)
+
+3. Up to 3x for relatively smaller text generation models (up to 7B parameters) based on OPT, BLOOM, and GPT architectures, running on a single GPU (*Figures 3 and 4*)
+
+4. Up to 9x for various text representation tasks like fill-mask, text classification, question answering, and token classification using RoBERTa- and BERT-based models (*Figures 5 and 6*).
+
+[ ![multi gpu latency](docs/images/llm-latency-sd-latency.png) ](docs/images/llm-latency-sd-latency-zoom.png)
+*Figure 2: (Left) Best achievable latency for large models. MII-Azure (int8) offers 5.7X lower latency compared to Baseline for Bloom-176B. (Right) Stable Diffusion text to image generation latency comparison.*
+
+[ ![OPT and BLOOM Models](docs/images/opt-bloom.png) ](docs/images/opt-bloom.png)
+*Figure 3: Latency comparison for OPT and BLOOM models. MII-Azure is up to 2.8x faster than baseline.*
+
+[ ![GPT Models](docs/images/gpt.png) ](docs/images/gpt.png)
+*Figure 4: Latency comparison for GPT models. MII-Azure is up to 3x faster than baseline.*
+
+[ ![Roberta Models](docs/images/roberta.png) ](docs/images/roberta.png)
+*Figure 5: Latency comparison for RoBERTa models. MII offers up to 9x lower model latency and up to 3x lower end-to-end latency than baseline on several tasks and RoBERTa variants [^overhead_details].*
+
+[ ![Bert Models](docs/images/bert.png) ](docs/images/bert.png)
+*Figure 6: Latency comparison for BERT models. MII offers up to 8.9x lower model latency and up to 4.5x end-to-end latency across several tasks and BERT variants[^overhead_details].*
+
+[^overhead_details]: The end-to-end latency of an inference workload comprises two components: i) actual model execution, and ii) pre-/post-processing before and after the model execution. MII optimizes the actual model execution but leaves the pre-/post-processing pipeline for future optimizations. We notice that text representation tasks have significant pre-/post-processing overhead (*Figures 5 and 6*). We plan to address this in a future update.
+
+## Cost Sensitive Scenarios
+
+MII can significantly reduce the inference cost of very expensive language models like Bloom, OPT, etc. To get the lowest cost, we use a large batch size that maximizes throughput for both baseline and MII. Here we look at the cost reduction from MII using two different metrics: i) tokens generated per second per GPU, and ii) dollars per million tokens generated.
+
+*Figures 7 and 8* show that MII-Public offers over 10x throughput improvement and cost reduction compared to the baseline, respectively. Furthermore, MII-Azure offers over 30x improvement in throughput and cost compared to the baseline.
+
+[ ![tput large models](docs/images/tput-llms.png) ](docs/images/tput-llms.png)
+*Figure 7: Throughput comparison per A100-80GB GPU for large models. MII-Public offers over 15x throughput improvement while MII-Azure offers over 40x throughput improvement.*
+
+[ ![azure cost](docs/images/azure-cost.png) ](docs/images/azure-cost.png)
+*Figure 8: Cost of generating 1 million tokens on Azure with different model types. MII-Azure reduces the cost of generation by over 40x.*
+
+# Community Tutorials
+
+* [DeepSpeed Deep Dive — Model Implementations for Inference (MII) (Heiko Hotz)](https://towardsdatascience.com/deepspeed-deep-dive-model-implementations-for-inference-mii-b02aa5d5e7f7)
+
+# Contributing
+
+This project welcomes contributions and suggestions. Most contributions require you to agree to a
+Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
+the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
+
+When you submit a pull request, a CLA bot will automatically determine whether you need to provide
+a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
+provided by the bot. You will only need to do this once across all repos using our CLA.
+
+This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
+contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
+
+# Trademarks
+
+This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
+trademarks or logos is subject to and must follow
+[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
+Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
+Any use of third-party trademarks or logos are subject to those third-party's policies.
diff --git a/mii/legacy/__init__.py b/mii/legacy/__init__.py
new file mode 100644
index 00000000..9f7ba998
--- /dev/null
+++ b/mii/legacy/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import grpc
+from .server import MIIServer
+from .client import MIIClient, mii_query_handle
+from .deployment import deploy
+from .terminate import terminate
+from .constants import DeploymentType, TaskType
+from .aml_related.utils import aml_output_path
+from .config import MIIConfig, ModelConfig
+from .utils import get_supported_models
+from .grpc_related.proto import legacymodelresponse_pb2_grpc as modelresponse_pb2_grpc
+
+__version__ = "0.0.0"
+non_persistent_models = {}
+try:
+ from .version import __version__
+except ImportError:
+ pass
diff --git a/mii/legacy/aml_related/__init__.py b/mii/legacy/aml_related/__init__.py
new file mode 100644
index 00000000..fa2121e0
--- /dev/null
+++ b/mii/legacy/aml_related/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from .templates import *
+from .utils import get_acr_name, generate_aml_scripts, aml_output_path
diff --git a/mii/legacy/aml_related/templates.py b/mii/legacy/aml_related/templates.py
new file mode 100644
index 00000000..66de070c
--- /dev/null
+++ b/mii/legacy/aml_related/templates.py
@@ -0,0 +1,398 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+deployment = \
+"""$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
+name:
+endpoint_name:
+model:
+ path:
+model_mount_path: /var/azureml-model
+code_configuration:
+ code:
+ scoring_script: score.py
+environment: azureml::
+environment_variables:
+ AML_APP_ROOT: /var/azureml-model/code
+ WORKER_TIMEOUT: 2400
+ WORKER_COUNT:
+ AZUREML_LOG_LEVEL: DEBUG
+ LOG_IO: 1
+instance_type:
+request_settings:
+ request_timeout_ms: 90000
+ max_concurrent_requests_per_instance:
+liveness_probe:
+ initial_delay: 300
+ timeout: 1
+ period: 60
+ success_threshold: 1
+ failure_threshold: 40
+readiness_probe:
+ initial_delay: 300
+ timeout: 1
+ period: 60
+ success_threshold: 1
+ failure_threshold: 40
+instance_count: 1
+"""
+
+endpoint = \
+"""$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineEndpoint.schema.json
+name:
+auth_mode: key
+"""
+
+environment = \
+"""$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json
+name:
+version:
+image: .azurecr.io/:
+inference_config:
+ liveness_route:
+ path: /
+ port: 5001
+ readiness_route:
+ path: /
+ port: 5001
+ scoring_route:
+ path: /score
+ port: 5001
+"""
+
+model_download = \
+"""import os
+import glob
+import shutil
+
+# Path and model params
+model_path = ""
+tmp_download_path = "./tmp/"
+snapshot_rel_path = "*/snapshots/*/*"
+model = ""
+task = ""
+
+# Must set cache location before loading transformers
+os.environ["TRANSFORMERS_CACHE"] = tmp_download_path
+
+from transformers import pipeline
+from huggingface_hub import snapshot_download
+
+# Download model
+try:
+ _ = pipeline(task=task, model=model)
+except OSError:
+ # Sometimes the model cannot be downloaded and we need to grab the snapshot
+ snapshot_download(model, cache_dir=tmp_download_path)
+
+# We need to resolve symlinks and move files to model_path dir
+os.mkdir(model_path)
+for f_path in glob.glob(os.path.join(tmp_download_path, snapshot_rel_path)):
+ f_name = os.path.basename(f_path)
+ real_file = os.path.realpath(f_path)
+ new_file = os.path.join(model_path, f_name)
+ os.rename(real_file, new_file)
+
+shutil.rmtree(tmp_download_path)
+"""
+
+deploy = \
+"""set -e
+python3 model_download.py
+az acr build -r --build-arg no-cache=True -t ":" build
+az ml environment create -f environment.yml
+az ml online-endpoint create -n "" -f endpoint.yml
+az ml online-deployment create -n "" -f deployment.yml
+"""
+
+dockerfile = \
+ """FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
+
+ENV AML_APP_ROOT=/var/azureml-model/code \
+ BUILD_DIR=/tmp/build \
+ LANG=C.UTF-8 \
+ LC_ALL=C.UTF-8 \
+ DEBIAN_FRONTEND=noninteractive \
+ AZUREML_MODEL_DIR=/var/azureml-model \
+ MII_MODEL_DIR=/var/azureml-model \
+ AZUREML_ENTRY_SCRIPT=score.py \
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ MII_CACHE_PATH=/tmp/mii_cache
+
+COPY . $BUILD_DIR
+
+RUN mkdir -p $BUILD_DIR && \
+ apt-get update && \
+ apt-get install -y --no-install-recommends nginx-light wget sudo runit rsyslog libcurl4 unzip git-all && \
+ apt-get autoremove -y && \
+ apt-get clean -y && \
+ rm -rf /usr/share/man/* /var/lib/apt/lists/* && \
+ mv "$BUILD_DIR/gunicorn_app" /etc/nginx/sites-available/ && \
+ rm /etc/nginx/sites-enabled/default && \
+ ln -s /etc/nginx/sites-available/gunicorn_app /etc/nginx/sites-enabled/ && \
+ useradd --create-home dockeruser && \
+ usermod -aG sudo dockeruser && \
+ echo "dockeruser ALL=(ALL:ALL) NOPASSWD:/usr/sbin/service nginx start" >> /etc/sudoers.d/dockeruser && \
+ mkdir -p /opt/miniconda /var/azureml-logger /var/azureml-util && \
+ chown -R dockeruser:root /opt/miniconda && \
+ cp -r "$BUILD_DIR/runit" /var && \
+ chown -R dockeruser:root /var/runit && \
+ mkdir -p {$AZUREML_MODEL_DIR,$MII_CACHE_PATH} && chmod 775 {$AZUREML_MODEL_DIR,$MII_CACHE_PATH} && chown -R dockeruser:root {$AZUREML_MODEL_DIR,$MII_CACHE_PATH}
+
+ENV PATH=/opt/miniconda/envs/amlenv/bin:/opt/miniconda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin \
+ AZUREML_CONDA_ENVIRONMENT_PATH=/opt/miniconda/envs/amlenv \
+ LD_LIBRARY_PATH=/usr/local/:/usr/local/lib:/usr/local/cuda:/usr/local/nvidia/lib:$LD_LIBRARY_PATH \
+ SVDIR=/var/runit \
+ AZUREML_INFERENCE_SERVER_HTTP_ENABLED=True
+
+USER dockeruser
+
+SHELL ["/bin/bash", "-c"]
+
+RUN cd ~ && \
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+ chmod +x Miniconda3-latest-Linux-x86_64.sh && \
+ bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
+ conda create -n amlenv python=3.10 -y
+
+ENV PATH="/opt/miniconda/envs/amlenv/bin:$AML_APP_ROOT:$PATH" \
+ CUDA_HOME=/usr/local/cuda \
+ LD_LIBRARY_PATH="/opt/miniconda/envs/amlenv/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
+ CONDA_DEFAULT_ENV=amlenv \
+ PATH=$PATH:/usr/local/cuda/bin
+
+RUN /opt/miniconda/envs/amlenv/bin/pip install -r "$BUILD_DIR/requirements.txt" --extra-index-url https://download.pytorch.org/whl/cu113 && \
+ /opt/miniconda/envs/amlenv/bin/pip install azureml-inference-server-http && \
+ /opt/miniconda/envs/amlenv/bin/pip install git+https://github.com/microsoft/DeepSpeed.git && \
+ /opt/miniconda/envs/amlenv/bin/pip install git+https://github.com/microsoft/DeepSpeed-MII.git && \
+ /opt/miniconda/envs/amlenv/bin/pip install git+https://github.com/huggingface/transformers.git
+
+
+EXPOSE 5001
+
+WORKDIR $AZUREML_MODEL_DIR/code
+
+CMD sudo service nginx start && cd $AZUREML_MODEL_DIR/code && azmlinfsrv --model_dir $AZUREML_MODEL_DIR --entry_script $AZUREML_MODEL_DIR/code/score.py --port 31311
+"""
+
+gunicorn = \
+"""upstream gunicorn {
+ server 127.0.0.1:31311;
+}
+
+server {
+listen *:5001;
+ location / {
+ proxy_set_header Host $http_host;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+ proxy_read_timeout 300;
+ proxy_connect_timeout 300;
+ proxy_send_timeout 300;
+ proxy_redirect off;
+ proxy_buffering off;
+ proxy_pass http://gunicorn;
+ }
+}
+
+ map $http_upgrade $connection_upgrade {
+ default upgrade;
+ '' close;
+ }
+"""
+
+gunicorn_run = \
+"""#!/bin/bash
+
+
+SCRIPT_PATH=$(dirname $(realpath -s "$0"))
+
+# Error handling that sleeps so logs are properly sent
+handle_error () {
+ echo "Error occurred. Sleeping to send error logs."
+ # Sleep 45 seconds
+ sleep 45
+ exit 95
+}
+
+format_print () {
+ echo "$(date -uIns) | gunicorn/run | $1"
+}
+
+echo "`date -uIns` - gunicorn/run $@"
+
+format_print ""
+format_print "###############################################"
+format_print "AzureML Container Runtime Information"
+format_print "###############################################"
+format_print ""
+
+
+if [[ -z "${AZUREML_CONDA_ENVIRONMENT_PATH}" ]]; then
+    # If AZUREML_CONDA_ENVIRONMENT_PATH is not set, fall back to the conda root's lib directory for LD_LIBRARY_PATH
+ export LD_LIBRARY_PATH="$(conda info --root)/lib:$LD_LIBRARY_PATH"
+else
+    # Otherwise, prepend the conda environment's lib directory to LD_LIBRARY_PATH
+ export LD_LIBRARY_PATH="$AZUREML_CONDA_ENVIRONMENT_PATH/lib:$LD_LIBRARY_PATH"
+fi
+
+if [[ -f "/IMAGE_INFORMATION" ]]; then
+ format_print ""
+ format_print "AzureML image information: $(cat /IMAGE_INFORMATION)"
+ format_print ""
+fi
+
+format_print ""
+format_print "PATH environment variable: $PATH"
+format_print "PYTHONPATH environment variable: $PYTHONPATH"
+format_print ""
+format_print "Pip Dependencies (before dynamic installation)"
+echo
+pip freeze
+echo
+
+if [[ -n "$AZUREML_INFERENCE_SERVER_HTTP_ENABLED" ]]; then
+ # Currently locking this feature to inference images.
+
+ if [[ -n "$AZUREML_ENTRY_SCRIPT" ]]; then
+ # Remove leading forward slash if it exists and then append the directory to the AML_APP_ROOT
+ export ENTRY_SCRIPT_DIR="${AML_APP_ROOT:-/var/azureml-app}/$(dirname "${AZUREML_ENTRY_SCRIPT#/}")"
+ else
+ export ENTRY_SCRIPT_DIR=${AML_APP_ROOT:-/var/azureml-app}
+ fi
+
+ format_print ""
+ format_print "Entry script directory: $ENTRY_SCRIPT_DIR"
+ format_print ""
+ format_print "###############################################"
+ format_print "Dynamic Python Package Installation"
+ format_print "###############################################"
+ format_print ""
+
+
+ if [[ -n "$AZUREML_EXTRA_PYTHON_LIB_PATH" ]]; then
+ # Pre-installed mounted dependencies, check for the variable and if the folder exists.
+
+ export EXTRA_PYTHON_LIB_FULL_PATH="${ENTRY_SCRIPT_DIR}/${AZUREML_EXTRA_PYTHON_LIB_PATH}"
+
+ if [[ -d $EXTRA_PYTHON_LIB_FULL_PATH ]]; then
+ format_print "Adding ${EXTRA_PYTHON_LIB_FULL_PATH} in PYTHONPATH"
+ export PYTHONPATH="${EXTRA_PYTHON_LIB_FULL_PATH}:$PYTHONPATH"
+ else
+ format_print "Expected folder with pre-installed packages not found: ${EXTRA_PYTHON_LIB_FULL_PATH}. Exiting with error ..."
+ exit 97
+ fi
+ elif [[ -n "$AZUREML_EXTRA_CONDA_YAML_ABS_PATH" || -n "$AZUREML_EXTRA_CONDA_YAML" ]]; then
+ # Dynamic installation conda.yml, check for the variable and if the file exists for relative and absolute paths.
+ # Need the absolute path for the MLFlow scenario where yaml could exist outside of azureml-app folder.
+
+ if [[ -n "$AZUREML_EXTRA_CONDA_YAML_ABS_PATH" ]]; then
+ export CONDA_FULL_PATH="$AZUREML_EXTRA_CONDA_YAML_ABS_PATH"
+ else
+ export CONDA_FULL_PATH="${ENTRY_SCRIPT_DIR}/${AZUREML_EXTRA_CONDA_YAML}"
+ fi
+
+ # NOTE: This may take a very long time if existing dependencies are added!
+ # Source: https://stackoverflow.com/questions/53250933/conda-takes-20-minutes-to-solve-environment-when-package-is-already-installed
+ if [[ -f $CONDA_FULL_PATH ]]; then
+ format_print "Updating conda environment from ${CONDA_FULL_PATH} !"
+
+ # Extract version from amlenv
+ # If this is not installed, the value is empty. There will be a Warning output that states that the package is not installed.
+ SERVER_VERSION="$(pip show azureml-inference-server-http | grep Version | sed -e 's/.*: //')"
+
+ if [ -z "$SERVER_VERSION" ]; then
+ format_print "azureml-inference-server-http not installed"
+ exit 96
+ fi
+
+ # Copy user conda.yml to tmp folder since we don't have write access to user folder
+ # Write access to folder is required for conda env create, and tmp folder has write access
+ export CONDA_FILENAME="${TMPDIR:=/tmp}/copied_env_$(date +%s%N).yaml"
+
+ cp "${CONDA_FULL_PATH}" "${CONDA_FILENAME}"
+
+ # Create a userenv from the conda yaml that replaces the existing amlenv
+ conda env create -n userenv -f "${CONDA_FILENAME}" || { handle_error ; }
+
+ export AZUREML_CONDA_ENVIRONMENT_PATH="/opt/miniconda/envs/userenv"
+ export PATH="/opt/miniconda/envs/userenv/bin:$PATH"
+ export LD_LIBRARY_PATH="$AZUREML_CONDA_ENVIRONMENT_PATH/lib:$LD_LIBRARY_PATH"
+
+ # Install the same version of the http server
+ pip install azureml-inference-server-http=="$SERVER_VERSION" || { handle_error ; }
+
+ else
+ format_print "Dynamic Python packages installation is enabled but expected conda yaml file not found: ${CONDA_FULL_PATH}. Exiting with error ..."
+ exit 98
+ fi
+ elif [[ -n "$AZUREML_EXTRA_REQUIREMENTS_TXT" ]]; then
+ # Dynamic installation requirements.txt, check for the variable and if the file exists for relative and absolute paths.
+
+ export REQUIREMENTS_TXT_FULL_PATH="${ENTRY_SCRIPT_DIR}/${AZUREML_EXTRA_REQUIREMENTS_TXT}"
+
+ if [[ -f $REQUIREMENTS_TXT_FULL_PATH ]]; then
+ format_print "Installing Python packages from ${REQUIREMENTS_TXT_FULL_PATH} !"
+ pip install -r "$REQUIREMENTS_TXT_FULL_PATH" || { handle_error ; }
+ else
+ format_print "Dynamic Python packages installation is enabled but expected requirements file not found: ${REQUIREMENTS_TXT_FULL_PATH}. Exiting with error ..."
+ exit 99
+ fi
+ else
+ format_print "Dynamic Python package installation is disabled."
+ fi
+fi
+
+format_print ""
+format_print "###############################################"
+format_print "AzureML Inference Server"
+format_print "###############################################"
+format_print ""
+
+cd "${AML_APP_ROOT:-/var/azureml-app}"
+
+# Check the result of $(pip show ...) instead of $(which azmlinfsrv). If we launch azmlinfsrv we need to make sure it is
+# from the active python environment. $(which azmlinfsrv) may point to the azmlinfsrv in a different virtual env.
+if [[ -n "$AZUREML_INFERENCE_SERVER_HTTP_ENABLED" || -n "$(pip show azureml-inference-server-http 2>/dev/null)" ]]; then
+ format_print "Starting AzureML Inference Server HTTP."
+
+ # Ensure the presence of debugpy if the user enabled local debugging. See ensure_debugpy.py for more details.
+ if [[ -n $AZUREML_DEBUG_PORT ]]; then
+ python $SCRIPT_PATH/ensure_debugpy.py
+ if [[ $? -ne 0 ]]; then
+            format_print "Exiting because debugpy cannot be injected into entry.py."
+ exit 94
+ fi
+ fi
+
+ exec azmlinfsrv --entry_script "${AZUREML_ENTRY_SCRIPT:-main.py}" --port 31311
+else
+ format_print ""
+ format_print "Starting HTTP server"
+ format_print ""
+
+ export PYTHONPATH="${AML_SERVER_ROOT:-/var/azureml-server}:$PYTHONPATH"
+ exec gunicorn -c "${AML_SERVER_ROOT:-/var/azureml-server}/gunicorn_conf.py" "entry:app"
+fi
+"""
+
+gunicorn_finish = \
+"""#!/bin/bash
+
+exit_code="$1" # The exit code from gunicorn
+signal="$2" # The signal which caused gunicorn to exit (or 0)
+
+echo "`date -uIns` - gunicorn/finish $@"
+echo "`date -uIns` - Exit code $exit_code is not normal. Killing image."
+
+killall -SIGHUP runsvdir
+"""
+
+requirements = \
+"""torch>=2.0.0
+grpcio
+grpcio-tools
+pydantic
+asyncio
+"""
diff --git a/mii/legacy/aml_related/utils.py b/mii/legacy/aml_related/utils.py
new file mode 100644
index 00000000..24e1bd62
--- /dev/null
+++ b/mii/legacy/aml_related/utils.py
@@ -0,0 +1,151 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import os
+import subprocess
+import yaml
+import mii.legacy as mii
+
+
+def get_acr_name():
+ try:
+ acr_name = subprocess.check_output(
+ ["az",
+ "ml",
+ "workspace",
+ "show",
+ "--query",
+ "container_registry"],
+ text=True)
+ return acr_name.strip().replace('"', '').rsplit('/', 1)[-1]
+ except subprocess.CalledProcessError as e:
+ print("\n", "-" * 30, "\n")
+ print("Unable to obtain ACR name from Azure-CLI. Please verify that you:")
+ print(
+ "\t- Have Azure-CLI installed (https://learn.microsoft.com/en-us/cli/azure/install-azure-cli)"
+ )
+ print("\t- Are logged in to an active account on Azure-CLI ($az login)")
+ print("\t- Have Azure-CLI ML plugin installed ($az extension add --name ml)")
+ print("\t- You have the default subscription, resource group, and workspace set")
+ print("\t\t- az account set --subscription YOUR_SUBSCRIPTION")
+ print("\t\t- az config set defaults.group=YOUR_GROUP")
+ print("\t\t- az config set defaults.workspace=YOUR_WORKSPACE")
+ print("\n", "-" * 30, "\n")
+ raise (e)
+
+
+def aml_output_path(deployment_name):
+ output_path = os.path.join(os.getcwd(), f"{deployment_name}_aml")
+ os.makedirs(output_path, exist_ok=True)
+ return output_path
+
+
+def fill_template(template, replace_dict):
+ for var, val in replace_dict.items():
+ template = template.replace(var, val)
+ return template
+
+
+def write_out_script(output_file, script):
+ dir_path = os.path.dirname(output_file)
+ os.makedirs(dir_path, exist_ok=True)
+ with open(output_file, "w") as f:
+ f.write(script)
+
+
+def write_out_yaml(output_file, yaml_data):
+ dir_path = os.path.dirname(output_file)
+ os.makedirs(dir_path, exist_ok=True)
+ with open(output_file, "w") as f:
+ yaml.dump(yaml.safe_load(yaml_data), f)
+
+
+def generate_aml_scripts(acr_name,
+ deployment_name,
+ model_name,
+ task_name,
+ replica_num,
+ instance_type,
+ version):
+ output_dir = aml_output_path(deployment_name)
+ code_path = os.path.join(output_dir, "code")
+ model_path = os.path.join(output_dir, "model")
+ endpoint_name = deployment_name + "-endpoint"
+ environment_name = deployment_name + "-environment"
+ image_name = deployment_name + "-image"
+
+ # Dictionary to fill template values
+ replace_dict = {
+ "": acr_name,
+ "": deployment_name,
+ "": model_name,
+ "": task_name,
+ "": str(replica_num),
+ "": instance_type,
+ "": str(version),
+ "": code_path,
+ "": model_path,
+ "": endpoint_name,
+ "": environment_name,
+ "": image_name,
+ }
+
+ # Docker files
+ write_out_script(os.path.join(output_dir,
+ "build",
+ "Dockerfile"),
+ fill_template(mii.aml_related.templates.dockerfile,
+ replace_dict))
+ write_out_script(os.path.join(output_dir,
+ "build",
+ "gunicorn_app"),
+ fill_template(mii.aml_related.templates.gunicorn,
+ replace_dict))
+ write_out_script(os.path.join(output_dir,
+ "build",
+ "runit",
+ "gunicorn",
+ "run"),
+ fill_template(mii.aml_related.templates.gunicorn_run,
+ replace_dict))
+ write_out_script(
+ os.path.join(output_dir,
+ "build",
+ "runit",
+ "gunicorn",
+ "finish"),
+ fill_template(mii.aml_related.templates.gunicorn_finish,
+ replace_dict))
+ write_out_script(os.path.join(output_dir,
+ "build",
+ "requirements.txt"),
+ fill_template(mii.aml_related.templates.requirements,
+ replace_dict))
+
+ # Model download script
+ write_out_script(
+ os.path.join(output_dir,
+ "model_download.py"),
+ fill_template(mii.aml_related.templates.model_download,
+ replace_dict))
+
+ # Deployment script
+ write_out_script(os.path.join(output_dir,
+ "deploy.sh"),
+ fill_template(mii.aml_related.templates.deploy,
+ replace_dict))
+
+ # Yaml configs
+ write_out_yaml(os.path.join(output_dir,
+ "deployment.yml"),
+ fill_template(mii.aml_related.templates.deployment,
+ replace_dict))
+ write_out_yaml(os.path.join(output_dir,
+ "endpoint.yml"),
+ fill_template(mii.aml_related.templates.endpoint,
+ replace_dict))
+ write_out_yaml(os.path.join(output_dir,
+ "environment.yml"),
+ fill_template(mii.aml_related.templates.environment,
+ replace_dict))
diff --git a/mii/legacy/client.py b/mii/legacy/client.py
new file mode 100644
index 00000000..2d8ca81d
--- /dev/null
+++ b/mii/legacy/client.py
@@ -0,0 +1,150 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import asyncio
+import grpc
+import requests
+import mii.legacy as mii
+from .grpc_related.proto import legacymodelresponse_pb2 as modelresponse_pb2
+from .grpc_related.proto import legacymodelresponse_pb2_grpc as modelresponse_pb2_grpc
+from .constants import GRPC_MAX_MSG_SIZE, TaskType, DeploymentType
+from .method_table import GRPC_METHOD_TABLE
+from .config import MIIConfig
+from .utils import import_score_file
+
+
+def _get_mii_config(deployment_name):
+ mii_config = import_score_file(deployment_name, DeploymentType.LOCAL).mii_config
+ return MIIConfig(**mii_config)
+
+
+def mii_query_handle(deployment_name):
+ """Get a query handle for a local deployment:
+
+ mii/examples/local/gpt2-query-example.py
+ mii/examples/local/roberta-qa-query-example.py
+
+ Arguments:
+ deployment_name: Name of the deployment. Used as an identifier for posting queries for ``LOCAL`` deployment.
+
+ Returns:
+ query_handle: A query handle with a single method `.query(request_dictionary)` using which queries can be sent to the model.
+ """
+
+ if deployment_name in mii.non_persistent_models:
+ inference_pipeline, task = mii.non_persistent_models[deployment_name]
+ return MIINonPersistentClient(task, deployment_name)
+
+ mii_config = _get_mii_config(deployment_name)
+ return MIIClient(mii_config.model_config.task,
+ "localhost", # TODO: This can probably be removed
+ mii_config.port_number)
+
+
+def create_channel(host, port):
+ return grpc.aio.insecure_channel(
+ f"{host}:{port}",
+ options=[
+ ("grpc.max_send_message_length",
+ GRPC_MAX_MSG_SIZE),
+ ("grpc.max_receive_message_length",
+ GRPC_MAX_MSG_SIZE),
+ ],
+ )
+
+
+class MIIClient:
+ """
+ Client to send queries to a single endpoint.
+ """
+ def __init__(self, task, host, port):
+ self.asyncio_loop = asyncio.get_event_loop()
+ channel = create_channel(host, port)
+ self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
+ self.task = task
+
+ async def _request_async_response(self, request_dict, **query_kwargs):
+ if self.task not in GRPC_METHOD_TABLE:
+ raise ValueError(f"unknown task: {self.task}")
+
+ task_methods = GRPC_METHOD_TABLE[self.task]
+ proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
+ proto_response = await getattr(self.stub, task_methods.method)(proto_request)
+ return task_methods.unpack_response_from_proto(proto_response)
+
+ def query(self, request_dict, **query_kwargs):
+ return self.asyncio_loop.run_until_complete(
+ self._request_async_response(request_dict,
+ **query_kwargs))
+
+ async def terminate_async(self):
+ await self.stub.Terminate(
+ modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())
+
+ def terminate(self):
+ self.asyncio_loop.run_until_complete(self.terminate_async())
+
+ async def create_session_async(self, session_id):
+ return await self.stub.CreateSession(
+ modelresponse_pb2.SessionID(session_id=session_id))
+
+ def create_session(self, session_id):
+ assert (
+ self.task == TaskType.TEXT_GENERATION
+ ), f"Session creation only available for task '{TaskType.TEXT_GENERATION}'."
+ return self.asyncio_loop.run_until_complete(
+ self.create_session_async(session_id))
+
+ async def destroy_session_async(self, session_id):
+ await self.stub.DestroySession(modelresponse_pb2.SessionID(session_id=session_id)
+ )
+
+ def destroy_session(self, session_id):
+ assert (
+ self.task == TaskType.TEXT_GENERATION
+ ), f"Session deletion only available for task '{TaskType.TEXT_GENERATION}'."
+ self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))
+
+
+class MIINonPersistentClient:
+ def __init__(self, task, deployment_name):
+ self.task = task
+ self.deployment_name = deployment_name
+
+ def query(self, request_dict, **query_kwargs):
+ assert (
+ self.deployment_name in mii.non_persistent_models
+ ), f"deployment: {self.deployment_name} not found"
+ task_methods = GRPC_METHOD_TABLE[self.task]
+ inference_pipeline = mii.non_persistent_models[self.deployment_name][0]
+
+ # TODO: refactor so this code is shared between non-persistent and
+ # persistent deployments in method_table.py
+ if self.task == TaskType.QUESTION_ANSWERING:
+ if "question" not in request_dict or "context" not in request_dict:
+ raise Exception(
+ "Question Answering Task requires 'question' and 'context' keys")
+ args = (request_dict["question"], request_dict["context"])
+ kwargs = query_kwargs
+
+ elif self.task == TaskType.CONVERSATIONAL:
+ conv = task_methods.create_conversation(request_dict)
+ args = (conv, )
+ kwargs = query_kwargs
+
+ else:
+ args = (request_dict["query"], )
+ kwargs = query_kwargs
+
+ return task_methods.run_inference(inference_pipeline, args, kwargs)
+
+ def terminate(self):
+ print(f"Terminating {self.deployment_name}...")
+ del mii.non_persistent_models[self.deployment_name]
+
+
+def terminate_restful_gateway(deployment_name):
+ mii_config = _get_mii_config(deployment_name)
+ if mii_config.enable_restful_api:
+ requests.get(f"http://localhost:{mii_config.restful_api_port}/terminate")
diff --git a/mii/legacy/config.py b/mii/legacy/config.py
new file mode 100644
index 00000000..26b10487
--- /dev/null
+++ b/mii/legacy/config.py
@@ -0,0 +1,417 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import torch
+import os
+import string
+from typing import List, Optional, Dict, Any
+import mii.legacy as mii
+from .constants import DeploymentType, TaskType, MII_MODEL_PATH_DEFAULT
+from .pydantic_v1 import validator, root_validator, Field
+
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
+from deepspeed.inference.config import DtypeEnum
+from deepspeed.launcher.runner import DLTS_HOSTFILE, fetch_hostfile
+
+
+class ReplicaConfig(DeepSpeedConfigModel):
+ hostname: str = ""
+ tensor_parallel_ports: List[int] = []
+ torch_dist_port: int = None
+ gpu_indices: List[int] = []
+
+
+class ModelConfig(DeepSpeedConfigModel):
+ model: str
+ """
+ Name of a supported model for the task. Models in MII are sourced from
+ multiple open-source projects such as Hugging Face Transformers, FairSeq,
+ EleutherAI, etc. For the list of supported models for each task, please see
+ here [TODO].
+ """
+
+ task: TaskType
+ """
+ Name of the machine learning task to be deployed. Currently MII supports the
+ following list of tasks ``['text-generation', 'text-classification',
+ 'question-answering', 'fill-mask', 'token-classification',
+ 'conversational', 'text-to-image']``
+ """
+
+ dtype: DtypeEnum = DtypeEnum.fp32
+ """
+ Desired model data type, will convert model to this type. Supported target
+ types: `torch.half`, `torch.float`, `torch.int8` (for BLOOM models)
+ """
+
+ model_path: str = ""
+ """
+ In LOCAL deployments this is the local path where model checkpoints are
+ available. In AML deployments this is an optional relative path appended to
+ AZUREML_MODEL_DIR for the deployment.
+ """
+
+ load_with_sys_mem: bool = False
+ """
+ Loads the model onto system memory instead of GPU memory. This can help
+ avoid OOM errors when sharding a model across several GPUs because MII will
+ try to load a full copy of each model onto each GPU initially.
+ """
+
+ meta_tensor: bool = False
+ """
+ Loads the initial HuggingFace model using Meta Tensors that use no memory.
+ Can dramatically improve load time and reduce memory requirements on
+ supported models. Supported for GPT-J, GPT-NeoX, OPT, and BLOOM when kernel
+ injection is enabled. Supported for all models when kernel injection is
+ disabled.
+ """
+
+ deploy_rank: Optional[List[int]] = None
+ """
+ GPU indices a model is deployed on. Note that CUDA_VISIBLE_DEVICES does not
+ work with DeepSpeed-MII.
+ """
+
+ torch_dist_port: int = 29500
+ """
+ Torch distributed port.
+ """
+
+ replica_num: int = 1
+ """
+ Number of model replicas. Enables easy data parallelism.
+ """
+
+ replica_configs: List[ReplicaConfig] = []
+ """
+ Configuration details for each replica. This will be automatically
+ generated, but you can provide a set of custom configs.
+ """
+
+ profile_model_time: bool = False
+ """
+ Enable profiling of model times (i.e., without communication overhead).
+ """
+
+ skip_model_check: bool = False
+ """
+ Skip validation that a model supports a given task.
+ """
+
+ hf_auth_token: Optional[str] = Field(
+ None,
+ deprecated=True,
+ deprecated_msg=
+ "Parameter will be removed. Please use the `pipeline_kwargs` field to pass kwargs to the HuggingFace pipeline creation.",
+ )
+ """
+ HuggingFace authentication token for accessing models. Will be propagated
+ to all ModelConfig if none are provided there.
+ """
+
+ trust_remote_code: bool = Field(
+ False,
+ deprecated=True,
+ deprecated_msg=
+ "Parameter will be removed. Please use the `pipeline_kwargs` field to pass kwargs to the HuggingFace pipeline creation.",
+ )
+ """
+ HuggingFace `transformers.pipeline` option for `trust_remote_code`.
+ """
+
+ pipeline_kwargs: Dict[str, Any] = {}
+ """
+ kwargs to be passed to HuggingFace's `transformers.pipeline`.
+ """
+
+ # TODO: Replace with DeepSpeedInferenceConfig
+ enable_deepspeed: bool = True
+ """
+ Enable DeepSpeed-Inference.
+ """
+
+ enable_zero: bool = False
+ """
+ Enable Zero-Inference.
+ """
+
+ ds_config: Dict[str, Any] = {}
+ """
+ DeepSpeed config to use when Zero-Inference is enabled.
+ """
+
+ tensor_parallel: int = 1
+ """
+ Tensor parallelism to use for a model (i.e., how many GPUs to shard a model across).
+ """
+
+ enable_cuda_graph: bool = False
+ """
+ Enables CUDA Graph captures with DeepSpeed-Inference.
+ """
+
+ replace_with_kernel_inject: bool = True
+ """
+ Enable custom kernel injection with DeepSpeed-Inference.
+ """
+
+ checkpoint_dict: Optional[Dict[str, Any]] = None
+ """
+ DeepSpeed model checkpoint dict.
+ """
+
+ max_tokens: int = 1024
+ """
+ The maximum number of tokens DeepSpeed-Inference can work with, including
+ the input and output tokens. Please consider increasing it to the token
+ length required for your use case.
+ """
+ class Config:
+ json_encoders = {torch.dtype: lambda x: str(x)}
+
+ @property
+ def provider(self):
+ return mii.utils.get_provider(self.model, self.task)
+
+ @validator("checkpoint_dict")
+ def checkpoint_dict_valid(cls, field_value, values):
+ if field_value is None:
+ return field_value
+ for k in ["checkpoints", "version", "type", "base_dir"]:
+ if not field_value.get(k, ""):
+ raise ValueError(f"Missing key={k} in checkpoint_dict")
+ return field_value
+
+ @validator("deploy_rank", pre=True)
+ def deploy_rank_to_list(cls, field_value, values):
+ if field_value and not isinstance(field_value, list):
+ field_value = [field_value]
+ return field_value
+
+ @root_validator
+ def zero_or_meta(cls, values):
+ if values.get("enable_zero"):
+ assert not values.get(
+ "meta_tensor"
+ ), "ZeRO-Inference does not support meta tensors."
+ return values
+
+ @root_validator
+ def bloom_model_valid(cls, values):
+ if "bigscience/bloom" in values.get("model"):
+ # TODO: Should be able to use DtypeEnum here
+ assert values.get("dtype") in [
+ torch.int8,
+ torch.float16,
+ ], "Bloom models only support fp16/int8."
+ assert not values.get(
+ "enable_cuda_graph"
+ ), "Bloom models do not support CUDA Graph."
+ return values
+
+ @root_validator
+ def deploy_rank_valid(cls, values):
+ tensor_parallel = values.get("tensor_parallel")
+ deploy_rank = values.get("deploy_rank")
+
+ # if deploy rank is not given, default to align with TP value
+ if deploy_rank is None:
+ deploy_rank = list(range(tensor_parallel))
+
+ # number of ranks provided must be equal to TP size, DP is handled outside MII currently
+ assert tensor_parallel == len(
+ deploy_rank
+ ), f"{len(deploy_rank)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {tensor_parallel}"
+
+ values["deploy_rank"] = deploy_rank
+ return values
+
+ @root_validator
+ def set_model_path(cls, values):
+ model_path = values.get("model_path")
+ if not model_path:
+ if values.get("deployment_type") == DeploymentType.AML:
+ model_path = "model"
+ else:
+ model_path = MII_MODEL_PATH_DEFAULT
+ aml_model_dir = os.environ.get("AZUREML_MODEL_DIR", None)
+ if aml_model_dir and not model_path.startswith(aml_model_dir):
+ assert os.path.isabs(
+ aml_model_dir
+ ), "AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path."
+ assert not os.path.isabs(
+ model_path
+ ), f"model_path={model_path} must be relative to append w/ AML path."
+ model_path = os.path.join(aml_model_dir, model_path)
+
+ values["model_path"] = model_path
+ return values
+
+ @root_validator
+ def validate_model_and_task(cls, values):
+ task = values.get("task")
+ model = values.get("model")
+ if not values.get("skip_model_check"):
+ mii.utils.check_if_task_and_model_is_valid(task, model)
+ if values.get("enable_deepspeed"):
+ mii.utils.check_if_task_and_model_is_supported(task, model)
+ # Skip any future checks
+ values["skip_model_check"] = True
+ return values
+
+ @root_validator
+ def meta_tensor_or_sys_mem(cls, values):
+ if values.get("meta_tensor") and values.get("load_with_sys_mem"):
+ raise ValueError(
+ "`meta_tensor` and `load_with_sys_mem` cannot be active at the same time."
+ )
+ return values
+
+ @root_validator
+ def zero_dtype_valid(cls, values):
+ if values.get("enable_zero"):
+ if values.get("ds_config").get("fp16", {}).get("enabled", False):
+ # TODO: We should be able to use DtypeEnum instead of torch.float
+ assert (
+ values.get("dtype") == torch.float16
+ ), "ZeRO FP16 enabled, `dtype` must be set to `torch.float16`"
+ else:
+ assert (
+ values.get("dtype") == torch.float32
+ ), "ZeRO FP16 disabled, `dtype` must be set to `torch.float32`"
+ return values
+
+ @root_validator
+ def deepspeed_or_zero(cls, values):
+ assert not (
+ values.get("enable_deepspeed") and values.get("enable_zero")
+ ), "DeepSpeed and ZeRO cannot both be enabled, select only one"
+ return values
+
+
+class MIIConfig(DeepSpeedConfigModel):
+ deployment_name: str
+ """
+ Name of the deployment. Used as an identifier for obtaining an inference
+ server client and posting queries.
+ """
+
+ deployment_type: DeploymentType = DeploymentType.LOCAL
+ """
+ One of the `enum mii.DeploymentTypes: [LOCAL, NON_PERSISTENT, AML]`.
+ * `LOCAL` uses a grpc server to create a local deployment.
+ * `NON_PERSISTENT` creates a local deployment that will end when the process exits.
+ * `AML` will generate the assets necessary to deploy on AML resources.
+ """
+
+ model_config: ModelConfig
+ """
+ Configuration for the deployed model(s).
+ """
+
+ port_number: int = 50050
+ """
+ Port number to use for the load balancer process.
+ """
+
+ enable_restful_api: bool = False
+ """
+ Enables a RESTful API that can be queried via the HTTP POST method.
+ """
+
+ restful_api_port: int = 51080
+ """
+ Port number to use for the RESTful API.
+ """
+
+ hostfile: str = DLTS_HOSTFILE
+ """
+ DeepSpeed hostfile. Will be autogenerated if None is provided.
+ """
+
+ # TODO: Place AML-related configs in subconfig
+ version: int = 1
+ """
+ Version number to pass to AML deployments.
+ """
+
+ instance_type: str = "Standard_NC12s_v3"
+ """
+ AML instance type to use when creating AML deployment assets.
+ """
+ @root_validator(skip_on_failure=True)
+ def AML_name_valid(cls, values):
+ if values.get("deployment_type") == DeploymentType.AML:
+ allowed_chars = set(string.ascii_lowercase + string.ascii_uppercase +
+ string.digits + "-")
+ assert (
+ set(values.get("deployment_name")) <= allowed_chars
+ ), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'."
+ return values
+
+ def generate_replica_configs(self):
+ # TODO: refactor this function
+ hostfile = self.hostfile
+ port_number = self.port_number
+ torch_dist_port = self.model_config.torch_dist_port
+ tensor_parallel = self.model_config.tensor_parallel
+ replica_num = self.model_config.replica_num
+ replica_pool = _allocate_processes(hostfile, tensor_parallel, replica_num)
+ replica_configs = []
+ for i, (hostname, gpu_indices) in enumerate(replica_pool):
+ # Reserve a port for the LB proxy when replication is enabled
+ port_offset = 1
+ base_port = port_number + i * tensor_parallel + port_offset
+ tensor_parallel_ports = list(range(base_port, base_port + tensor_parallel))
+ replica_torch_dist_port = torch_dist_port + (100 * i)
+ replica_configs.append(
+ ReplicaConfig(
+ hostname=hostname,
+ tensor_parallel_ports=tensor_parallel_ports,
+ torch_dist_port=replica_torch_dist_port,
+ gpu_indices=gpu_indices,
+ ))
+
+ self.model_config.replica_configs = replica_configs
+
+
+def _allocate_processes(hostfile_path, tensor_parallel, replica_num):
+ resource_pool = fetch_hostfile(hostfile_path)
+ assert (
+ resource_pool is not None and len(resource_pool) > 0
+ ), f"No hosts found in {hostfile_path}"
+
+ replica_pool = []
+ allocated_num = 0
+ for host, slots in resource_pool.items():
+ available_on_host = slots
+ while available_on_host >= tensor_parallel:
+ if allocated_num >= replica_num:
+ break
+ if slots < tensor_parallel:
+ raise ValueError(
+ f"Host {host} has {slots} slot(s), but {tensor_parallel} slot(s) are required"
+ )
+
+ allocated_num_on_host = slots - available_on_host
+ replica_pool.append((
+ host,
+ [
+ i for i in range(
+ allocated_num_on_host,
+ allocated_num_on_host + tensor_parallel,
+ )
+ ],
+ ))
+ allocated_num += 1
+
+ available_on_host -= tensor_parallel
+
+ if allocated_num < replica_num:
+ raise ValueError(
+ f"Not sufficient GPUs for {replica_num} replica(s), only {allocated_num} replica(s) can be deployed"
+ )
+
+ return replica_pool
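
To make the replica and port bookkeeping in generate_replica_configs() and _allocate_processes() concrete, here is a small sketch under assumed values: a hostfile exposing at least 4 GPU slots, an illustrative model/deployment name, and skip_model_check set to bypass the model/task validation step.

    from mii.legacy.config import MIIConfig

    mii_config = MIIConfig(
        deployment_name="bloom560m-deploy",  # illustrative
        port_number=50050,
        model_config={
            "model": "bigscience/bloom-560m",
            "task": "text-generation",
            "tensor_parallel": 2,
            "replica_num": 2,
            "skip_model_check": True,
        },
    )
    mii_config.generate_replica_configs()
    # replica 0 -> tensor_parallel_ports [50051, 50052], torch_dist_port 29500
    # replica 1 -> tensor_parallel_ports [50053, 50054], torch_dist_port 29600
    # Port 50050 itself stays free for the load-balancer proxy.
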
diff --git a/mii/legacy/constants.py b/mii/legacy/constants.py
new file mode 100644
index 00000000..ea90b87a
--- /dev/null
+++ b/mii/legacy/constants.py
@@ -0,0 +1,88 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from enum import Enum
+
+
+class DeploymentType(str, Enum):
+ LOCAL = "local"
+ AML = "aml"
+ NON_PERSISTENT = "non-persistent"
+
+
+class TaskType(str, Enum):
+ TEXT_GENERATION = "text-generation"
+ TEXT_CLASSIFICATION = "text-classification"
+ QUESTION_ANSWERING = "question-answering"
+ FILL_MASK = "fill-mask"
+ TOKEN_CLASSIFICATION = "token-classification"
+ CONVERSATIONAL = "conversational"
+ TEXT2IMG = "text-to-image"
+
+
+class ModelProvider(str, Enum):
+ HUGGING_FACE = "hugging-face"
+ ELEUTHER_AI = "eleuther-ai"
+ DIFFUSERS = "diffusers"
+
+
+SUPPORTED_MODEL_TYPES = {
+ 'roberta': ModelProvider.HUGGING_FACE,
+ 'xlm-roberta': ModelProvider.HUGGING_FACE,
+ 'gpt2': ModelProvider.HUGGING_FACE,
+ 'distilbert': ModelProvider.HUGGING_FACE,
+ 'bert': ModelProvider.HUGGING_FACE,
+ 'gpt_neo': ModelProvider.HUGGING_FACE,
+ 'gptj': ModelProvider.HUGGING_FACE,
+ 'opt': ModelProvider.HUGGING_FACE,
+ 'bloom': ModelProvider.HUGGING_FACE,
+ 'gpt-neox': ModelProvider.ELEUTHER_AI,
+ 'stable-diffusion': ModelProvider.DIFFUSERS,
+ 'llama': ModelProvider.HUGGING_FACE
+}
+
+REQUIRED_KEYS_PER_TASK = {
+ TaskType.TEXT_GENERATION: ["query"],
+ TaskType.TEXT_CLASSIFICATION: ["query"],
+ TaskType.QUESTION_ANSWERING: ["context",
+ "question"],
+ TaskType.FILL_MASK: ["query"],
+ TaskType.TOKEN_CLASSIFICATION: ["query"],
+ TaskType.CONVERSATIONAL: [
+ "text",
+ "conversation_id",
+ "past_user_inputs",
+ "generated_responses",
+ ],
+ TaskType.TEXT2IMG: ["query"],
+}
+
+MII_CACHE_PATH = "MII_CACHE_PATH"
+MII_CACHE_PATH_DEFAULT = "/tmp/mii_cache"
+
+MII_HF_CACHE_EXPIRATION = "MII_HF_CACHE_EXPIRATION"
+MII_HF_CACHE_EXPIRATION_DEFAULT = 60 * 60 # 1 hour
+
+MII_DEBUG_MODE = "MII_DEBUG_MODE"
+MII_DEBUG_MODE_DEFAULT = "0"
+
+MII_DEBUG_DEPLOY_KEY = "MII_DEBUG_DEPLOY_KEY"
+
+MII_DEBUG_BRANCH = "MII_DEBUG_BRANCH"
+MII_DEBUG_BRANCH_DEFAULT = "main"
+
+MII_MODEL_PATH_DEFAULT = "/tmp/mii_models"
+
+GRPC_MAX_MSG_SIZE = 2**27 # ~100MB
+
+TERMINATE_METHOD = "Terminate"
+CREATE_SESSION_METHOD = "CreateSession"
+DESTROY_SESSION_METHOD = "DestroySession"
+
+LB_MAX_WORKER_THREADS = 32
+
+SERVER_SHUTDOWN_TIMEOUT = 10
+
+RESTFUL_GATEWAY_SHUTDOWN_TIMEOUT = 1
+RESTFUL_API_PATH = "mii"
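
The REQUIRED_KEYS_PER_TASK table above defines the shape of the request dictionary each task expects. A few illustrative payloads (all values invented for the example):

    text_generation_request = {"query": ["DeepSpeed is"]}
    question_answering_request = {
        "question": "What is DeepSpeed-MII?",
        "context": "DeepSpeed-MII is an open-source inference library.",
    }
    conversational_request = {
        "text": "Hello, how are you?",
        "conversation_id": "0",
        "past_user_inputs": [],
        "generated_responses": [],
    }
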
diff --git a/mii/deployment.py b/mii/legacy/deployment.py
similarity index 99%
rename from mii/deployment.py
rename to mii/legacy/deployment.py
index 15509aec..59954901 100644
--- a/mii/deployment.py
+++ b/mii/legacy/deployment.py
@@ -3,7 +3,7 @@
# DeepSpeed Team
import os
-import mii
+import mii.legacy as mii
from .logging import logger
from .models.score import create_score_file
diff --git a/mii/legacy/docs/CNAME b/mii/legacy/docs/CNAME
new file mode 100644
index 00000000..e69de29b
diff --git a/docs/GPT-NeoX.md b/mii/legacy/docs/GPT-NeoX.md
similarity index 100%
rename from docs/GPT-NeoX.md
rename to mii/legacy/docs/GPT-NeoX.md
diff --git a/docs/images/azure-cost.png b/mii/legacy/docs/images/azure-cost.png
similarity index 100%
rename from docs/images/azure-cost.png
rename to mii/legacy/docs/images/azure-cost.png
diff --git a/docs/images/bert.png b/mii/legacy/docs/images/bert.png
similarity index 100%
rename from docs/images/bert.png
rename to mii/legacy/docs/images/bert.png
diff --git a/docs/images/bloom.png b/mii/legacy/docs/images/bloom.png
similarity index 100%
rename from docs/images/bloom.png
rename to mii/legacy/docs/images/bloom.png
diff --git a/docs/images/gpt.png b/mii/legacy/docs/images/gpt.png
similarity index 100%
rename from docs/images/gpt.png
rename to mii/legacy/docs/images/gpt.png
diff --git a/docs/images/hero-dark.png b/mii/legacy/docs/images/hero-dark.png
similarity index 100%
rename from docs/images/hero-dark.png
rename to mii/legacy/docs/images/hero-dark.png
diff --git a/docs/images/hero-transparent.png b/mii/legacy/docs/images/hero-transparent.png
similarity index 100%
rename from docs/images/hero-transparent.png
rename to mii/legacy/docs/images/hero-transparent.png
diff --git a/docs/images/hero.png b/mii/legacy/docs/images/hero.png
similarity index 100%
rename from docs/images/hero.png
rename to mii/legacy/docs/images/hero.png
diff --git a/docs/images/llm-latency-sd-latency.png b/mii/legacy/docs/images/llm-latency-sd-latency.png
similarity index 100%
rename from docs/images/llm-latency-sd-latency.png
rename to mii/legacy/docs/images/llm-latency-sd-latency.png
diff --git a/docs/images/mii-arch.png b/mii/legacy/docs/images/mii-arch.png
similarity index 100%
rename from docs/images/mii-arch.png
rename to mii/legacy/docs/images/mii-arch.png
diff --git a/mii/legacy/docs/images/mii-dark.svg b/mii/legacy/docs/images/mii-dark.svg
new file mode 100644
index 00000000..d34cd6fb
--- /dev/null
+++ b/mii/legacy/docs/images/mii-dark.svg
@@ -0,0 +1,19 @@
+
diff --git a/mii/legacy/docs/images/mii-white.svg b/mii/legacy/docs/images/mii-white.svg
new file mode 100644
index 00000000..70b40f63
--- /dev/null
+++ b/mii/legacy/docs/images/mii-white.svg
@@ -0,0 +1,19 @@
+
diff --git a/docs/images/multi-gpu-latency.png b/mii/legacy/docs/images/multi-gpu-latency.png
similarity index 100%
rename from docs/images/multi-gpu-latency.png
rename to mii/legacy/docs/images/multi-gpu-latency.png
diff --git a/docs/images/opt-bloom.png b/mii/legacy/docs/images/opt-bloom.png
similarity index 100%
rename from docs/images/opt-bloom.png
rename to mii/legacy/docs/images/opt-bloom.png
diff --git a/docs/images/opt.png b/mii/legacy/docs/images/opt.png
similarity index 100%
rename from docs/images/opt.png
rename to mii/legacy/docs/images/opt.png
diff --git a/docs/images/roberta.png b/mii/legacy/docs/images/roberta.png
similarity index 100%
rename from docs/images/roberta.png
rename to mii/legacy/docs/images/roberta.png
diff --git a/docs/images/sd-hero-dark.png b/mii/legacy/docs/images/sd-hero-dark.png
similarity index 100%
rename from docs/images/sd-hero-dark.png
rename to mii/legacy/docs/images/sd-hero-dark.png
diff --git a/docs/images/sd-hero-light.png b/mii/legacy/docs/images/sd-hero-light.png
similarity index 100%
rename from docs/images/sd-hero-light.png
rename to mii/legacy/docs/images/sd-hero-light.png
diff --git a/docs/images/sd-latency.png b/mii/legacy/docs/images/sd-latency.png
similarity index 100%
rename from docs/images/sd-latency.png
rename to mii/legacy/docs/images/sd-latency.png
diff --git a/docs/images/tput-llms.png b/mii/legacy/docs/images/tput-llms.png
similarity index 100%
rename from docs/images/tput-llms.png
rename to mii/legacy/docs/images/tput-llms.png
diff --git a/examples/aml/fill-mask-example.py b/mii/legacy/examples/aml/fill-mask-example.py
similarity index 100%
rename from examples/aml/fill-mask-example.py
rename to mii/legacy/examples/aml/fill-mask-example.py
diff --git a/examples/aml/text-generation-bloom.py b/mii/legacy/examples/aml/text-generation-bloom.py
similarity index 100%
rename from examples/aml/text-generation-bloom.py
rename to mii/legacy/examples/aml/text-generation-bloom.py
diff --git a/examples/aml/text-generation-bloom560m-example.py b/mii/legacy/examples/aml/text-generation-bloom560m-example.py
similarity index 100%
rename from examples/aml/text-generation-bloom560m-example.py
rename to mii/legacy/examples/aml/text-generation-bloom560m-example.py
diff --git a/examples/benchmark/txt2img/README.md b/mii/legacy/examples/benchmark/txt2img/README.md
similarity index 100%
rename from examples/benchmark/txt2img/README.md
rename to mii/legacy/examples/benchmark/txt2img/README.md
diff --git a/examples/benchmark/txt2img/baseline-sd.py b/mii/legacy/examples/benchmark/txt2img/baseline-sd.py
similarity index 100%
rename from examples/benchmark/txt2img/baseline-sd.py
rename to mii/legacy/examples/benchmark/txt2img/baseline-sd.py
diff --git a/examples/benchmark/txt2img/mii-sd.py b/mii/legacy/examples/benchmark/txt2img/mii-sd.py
similarity index 100%
rename from examples/benchmark/txt2img/mii-sd.py
rename to mii/legacy/examples/benchmark/txt2img/mii-sd.py
diff --git a/examples/benchmark/txt2img/requirements.txt b/mii/legacy/examples/benchmark/txt2img/requirements.txt
similarity index 100%
rename from examples/benchmark/txt2img/requirements.txt
rename to mii/legacy/examples/benchmark/txt2img/requirements.txt
diff --git a/examples/benchmark/txt2img/utils.py b/mii/legacy/examples/benchmark/txt2img/utils.py
similarity index 100%
rename from examples/benchmark/txt2img/utils.py
rename to mii/legacy/examples/benchmark/txt2img/utils.py
diff --git a/examples/local/chat/README.md b/mii/legacy/examples/local/chat/README.md
similarity index 100%
rename from examples/local/chat/README.md
rename to mii/legacy/examples/local/chat/README.md
diff --git a/examples/local/chat/chat-client-example.py b/mii/legacy/examples/local/chat/chat-client-example.py
similarity index 100%
rename from examples/local/chat/chat-client-example.py
rename to mii/legacy/examples/local/chat/chat-client-example.py
diff --git a/examples/local/chat/chat-server-example.py b/mii/legacy/examples/local/chat/chat-server-example.py
similarity index 100%
rename from examples/local/chat/chat-server-example.py
rename to mii/legacy/examples/local/chat/chat-server-example.py
diff --git a/examples/local/conversational-example.py b/mii/legacy/examples/local/conversational-example.py
similarity index 100%
rename from examples/local/conversational-example.py
rename to mii/legacy/examples/local/conversational-example.py
diff --git a/examples/local/conversational-query-example.py b/mii/legacy/examples/local/conversational-query-example.py
similarity index 100%
rename from examples/local/conversational-query-example.py
rename to mii/legacy/examples/local/conversational-query-example.py
diff --git a/examples/local/fill-mask-example.py b/mii/legacy/examples/local/fill-mask-example.py
similarity index 100%
rename from examples/local/fill-mask-example.py
rename to mii/legacy/examples/local/fill-mask-example.py
diff --git a/examples/local/question-answering-example.py b/mii/legacy/examples/local/question-answering-example.py
similarity index 100%
rename from examples/local/question-answering-example.py
rename to mii/legacy/examples/local/question-answering-example.py
diff --git a/examples/local/question-answering-query-example.py b/mii/legacy/examples/local/question-answering-query-example.py
similarity index 100%
rename from examples/local/question-answering-query-example.py
rename to mii/legacy/examples/local/question-answering-query-example.py
diff --git a/examples/local/text-classification-example.py b/mii/legacy/examples/local/text-classification-example.py
similarity index 100%
rename from examples/local/text-classification-example.py
rename to mii/legacy/examples/local/text-classification-example.py
diff --git a/examples/local/text-classification-query-example.py b/mii/legacy/examples/local/text-classification-query-example.py
similarity index 100%
rename from examples/local/text-classification-query-example.py
rename to mii/legacy/examples/local/text-classification-query-example.py
diff --git a/examples/local/text-generation-bloom-example.py b/mii/legacy/examples/local/text-generation-bloom-example.py
similarity index 100%
rename from examples/local/text-generation-bloom-example.py
rename to mii/legacy/examples/local/text-generation-bloom-example.py
diff --git a/examples/local/text-generation-bloom560m-example.py b/mii/legacy/examples/local/text-generation-bloom560m-example.py
similarity index 100%
rename from examples/local/text-generation-bloom560m-example.py
rename to mii/legacy/examples/local/text-generation-bloom560m-example.py
diff --git a/examples/local/text-generation-fbopt-example.py b/mii/legacy/examples/local/text-generation-fbopt-example.py
similarity index 100%
rename from examples/local/text-generation-fbopt-example.py
rename to mii/legacy/examples/local/text-generation-fbopt-example.py
diff --git a/examples/local/text-generation-query-example.py b/mii/legacy/examples/local/text-generation-query-example.py
similarity index 100%
rename from examples/local/text-generation-query-example.py
rename to mii/legacy/examples/local/text-generation-query-example.py
diff --git a/examples/local/text-generation-zero-example.py b/mii/legacy/examples/local/text-generation-zero-example.py
similarity index 100%
rename from examples/local/text-generation-zero-example.py
rename to mii/legacy/examples/local/text-generation-zero-example.py
diff --git a/examples/local/token-classification-example.py b/mii/legacy/examples/local/token-classification-example.py
similarity index 100%
rename from examples/local/token-classification-example.py
rename to mii/legacy/examples/local/token-classification-example.py
diff --git a/examples/local/token-classification-query-example.py b/mii/legacy/examples/local/token-classification-query-example.py
similarity index 100%
rename from examples/local/token-classification-query-example.py
rename to mii/legacy/examples/local/token-classification-query-example.py
diff --git a/examples/local/txt2img-example.py b/mii/legacy/examples/local/txt2img-example.py
similarity index 100%
rename from examples/local/txt2img-example.py
rename to mii/legacy/examples/local/txt2img-example.py
diff --git a/examples/non_persistent/text-generation-bloom560-example.py b/mii/legacy/examples/non_persistent/text-generation-bloom560-example.py
similarity index 100%
rename from examples/non_persistent/text-generation-bloom560-example.py
rename to mii/legacy/examples/non_persistent/text-generation-bloom560-example.py
diff --git a/mii/models/providers/eleutherai.py b/mii/legacy/grpc_related/__init__.py
similarity index 100%
rename from mii/models/providers/eleutherai.py
rename to mii/legacy/grpc_related/__init__.py
diff --git a/mii/legacy/grpc_related/modelresponse_server.py b/mii/legacy/grpc_related/modelresponse_server.py
new file mode 100644
index 00000000..880bed0b
--- /dev/null
+++ b/mii/legacy/grpc_related/modelresponse_server.py
@@ -0,0 +1,276 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import asyncio
+from concurrent import futures
+import logging
+
+import grpc
+from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
+from .proto import legacymodelresponse_pb2_grpc as modelresponse_pb2_grpc
+import sys
+import threading
+import time
+
+from mii.legacy.constants import (
+ GRPC_MAX_MSG_SIZE,
+ CREATE_SESSION_METHOD,
+ DESTROY_SESSION_METHOD,
+ TERMINATE_METHOD,
+ LB_MAX_WORKER_THREADS,
+ SERVER_SHUTDOWN_TIMEOUT,
+ TaskType,
+)
+from mii.legacy.method_table import GRPC_METHOD_TABLE
+from mii.legacy.client import create_channel
+from mii.legacy.utils import unpack_proto_query_kwargs
+
+
+class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer):
+ """
+ Base class to provide common features of an inference server
+ """
+ def __init__(self):
+ self._stop_event = threading.Event()
+
+ def Terminate(self, request, context):
+ self._stop_event.set()
+ return google_dot_protobuf_dot_empty__pb2.Empty()
+
+ def get_stop_event(self):
+ return self._stop_event
+
+
+class ModelResponse(ServiceBase):
+ """
+ Implementation class of an MII inference server
+ """
+ def __init__(self, inference_pipeline):
+ super().__init__()
+ self.inference_pipeline = inference_pipeline
+ self.method_name_to_task = {m.method: t for t, m in GRPC_METHOD_TABLE.items()}
+ self.lock = threading.Lock()
+
+ def _get_model_time(self, model, sum_times=False):
+ model_times = []
+ # Only grab model times if profiling was enabled/exists
+ if getattr(model, "model_profile_enabled", False):
+ model_times = model.model_times()
+
+ if len(model_times) > 0:
+ if sum_times:
+ model_time = sum(model_times)
+ else:
+ # Unclear how to combine values, so just grab the most recent one
+ model_time = model_times[-1]
+ else:
+ # no model times were captured
+ model_time = -1
+ return model_time
+
+ def CreateSession(self, request, context):
+ task_methods = GRPC_METHOD_TABLE[TaskType.TEXT_GENERATION]
+ task_methods.create_session(request.session_id)
+ return google_dot_protobuf_dot_empty__pb2.Empty()
+
+ def DestroySession(self, request, context):
+ task_methods = GRPC_METHOD_TABLE[TaskType.TEXT_GENERATION]
+ task_methods.destroy_session(request.session_id)
+ return google_dot_protobuf_dot_empty__pb2.Empty()
+
+ def _run_inference(self, method_name, request_proto):
+ if method_name not in self.method_name_to_task:
+ raise ValueError(f"unknown method: {method_name}")
+
+ task = self.method_name_to_task[method_name]
+ if task not in GRPC_METHOD_TABLE:
+ raise ValueError(f"unknown task: {task}")
+
+ task_methods = GRPC_METHOD_TABLE[task]
+ args, kwargs = task_methods.unpack_request_from_proto(request_proto)
+
+ start = time.time()
+ with self.lock:
+ response = task_methods.run_inference(self.inference_pipeline, args, kwargs)
+ end = time.time()
+
+ model_time = (self._get_model_time(self.inference_pipeline.model,
+ sum_times=True) if hasattr(
+ self.inference_pipeline,
+ "model") else -1)
+
+ return task_methods.pack_response_to_proto(response, end - start, model_time)
+
+ def GeneratorReply(self, request, context):
+ return self._run_inference("GeneratorReply", request)
+
+ def Txt2ImgReply(self, request, context):
+ return self._run_inference("Txt2ImgReply", request)
+
+ def ClassificationReply(self, request, context):
+ return self._run_inference("ClassificationReply", request)
+
+ def QuestionAndAnswerReply(self, request, context):
+ return self._run_inference("QuestionAndAnswerReply", request)
+
+ def FillMaskReply(self, request, context):
+ return self._run_inference("FillMaskReply", request)
+
+ def TokenClassificationReply(self, request, context):
+ return self._run_inference("TokenClassificationReply", request)
+
+ def ConversationalReply(self, request, context):
+ return self._run_inference("ConversationalReply", request)
+
+
+class AtomicCounter:
+ def __init__(self, initial_value=0):
+ self.value = initial_value
+ self.lock = threading.Lock()
+
+ def get_and_increment(self):
+ with self.lock:
+ current_value = self.value
+ self.value += 1
+ return current_value
+
+
+def _get_grpc_method_name(method):
+ return method.split("/")[-1]
+
+
+class ParallelStubInvoker:
+ """
+ Invokes a gRPC method on multiple endpoints in parallel.
+ This class aims to call gRPC methods without conversions between proto and Python objects.
+ TensorParallelClient can be used for invocation with the conversions.
+ """
+ def __init__(self, host, ports):
+ # Assumption: target services are all on the same host
+ self.stubs = []
+ for port in ports:
+ channel = create_channel(host, port)
+ stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
+ self.stubs.append(stub)
+
+ self.asyncio_loop = asyncio.get_event_loop()
+
+ async def _invoke_async(self, method_name, proto_request):
+ responses = []
+ for stub in self.stubs:
+ method = getattr(stub, method_name)
+ responses.append(method(proto_request))
+ return await responses[0]
+
+ def invoke(self, method_name, proto_request):
+ # This is needed because gRPC calls from the interceptor are launched from
+ # a thread other than the one running the asyncio event loop.
+ return asyncio.run_coroutine_threadsafe(
+ self._invoke_async(method_name,
+ proto_request),
+ self.asyncio_loop).result()
+
+
+class LoadBalancingInterceptor(grpc.ServerInterceptor):
+ def __init__(self, model_config):
+ super().__init__()
+ self.asyncio_loop = asyncio.get_event_loop()
+
+ self.stubs = [
+ ParallelStubInvoker(replica.hostname,
+ replica.tensor_parallel_ports)
+ for replica in model_config.replica_configs
+ ]
+ self.counter = AtomicCounter()
+ self.task = model_config.task
+ self.replica_sessions = {}
+
+ # Start the asyncio loop in a separate thread
+ def run_asyncio_loop(loop):
+ asyncio.set_event_loop(loop)
+ loop.run_forever()
+
+ threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop, )).start()
+
+ def choose_stub(self, call_count):
+ return self.stubs[call_count % len(self.stubs)]
+
+ def intercept_service(self, continuation, handler_call_details):
+ next_handler = continuation(handler_call_details)
+ assert next_handler.unary_unary is not None
+
+ def invoke_intercept_method(request_proto, context):
+ method_name = _get_grpc_method_name(handler_call_details.method)
+
+ if method_name == TERMINATE_METHOD:
+ for stub in self.stubs:
+ stub.invoke(TERMINATE_METHOD,
+ google_dot_protobuf_dot_empty__pb2.Empty())
+ self.asyncio_loop.call_soon_threadsafe(self.asyncio_loop.stop)
+ return next_handler.unary_unary(request_proto, context)
+
+ call_count = self.counter.get_and_increment()
+ replica_index = call_count % len(self.stubs)
+
+ if method_name == CREATE_SESSION_METHOD:
+ if request_proto.session_id in self.replica_sessions:
+ raise ValueError(
+ f"session {request_proto.session_id} already exists")
+ self.replica_sessions[request_proto.session_id] = replica_index
+ self.stubs[replica_index].invoke(CREATE_SESSION_METHOD, request_proto)
+ return google_dot_protobuf_dot_empty__pb2.Empty()
+
+ if method_name == DESTROY_SESSION_METHOD:
+ replica_index = self.replica_sessions.pop(request_proto.session_id)
+ self.stubs[replica_index].invoke(DESTROY_SESSION_METHOD, request_proto)
+ return google_dot_protobuf_dot_empty__pb2.Empty()
+
+ kwargs = unpack_proto_query_kwargs(request_proto.query_kwargs)
+ if "session_id" in kwargs:
+ session_id = kwargs["session_id"]
+ if session_id not in self.replica_sessions:
+ raise ValueError(f"session not found")
+ replica_index = self.replica_sessions[session_id]
+
+ ret = self.stubs[replica_index].invoke(method_name, request_proto)
+ return ret
+
+ return grpc.unary_unary_rpc_method_handler(
+ invoke_intercept_method,
+ request_deserializer=next_handler.request_deserializer,
+ response_serializer=next_handler.response_serializer,
+ )
+
+
+def _do_serve(service_impl, port, interceptors=[]):
+ stop_event = service_impl.get_stop_event()
+ server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=LB_MAX_WORKER_THREADS),
+ interceptors=interceptors,
+ options=[
+ ("grpc.max_send_message_length",
+ GRPC_MAX_MSG_SIZE),
+ ("grpc.max_receive_message_length",
+ GRPC_MAX_MSG_SIZE),
+ ],
+ )
+ modelresponse_pb2_grpc.add_ModelResponseServicer_to_server(service_impl, server)
+ server.add_insecure_port(f"[::]:{port}")
+ print(f"About to start server")
+ server.start()
+ print(f"Started")
+ stop_event.wait()
+ server.stop(SERVER_SHUTDOWN_TIMEOUT)
+
+
+def serve_inference(inference_pipeline, port):
+ _do_serve(ModelResponse(inference_pipeline), port)
+
+
+def serve_load_balancing(model_config, lb_port):
+ _do_serve(ServiceBase(), lb_port, [LoadBalancingInterceptor(model_config)])
+
+
+if __name__ == "__main__":
+ logging.basicConfig()
+ serve_inference(None, sys.argv[1])
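
The routing policy implemented by LoadBalancingInterceptor above is round-robin via AtomicCounter, with per-session pinning for text-generation sessions. A standalone sketch of that selection logic, stripped of the gRPC plumbing (class and method names here are illustrative, not part of the MII API):

    class ReplicaSelector:
        # Round-robin replica choice with optional session pinning.
        def __init__(self, num_replicas):
            self.num_replicas = num_replicas
            self.call_count = 0
            self.sessions = {}  # session_id -> pinned replica index

        def create_session(self, session_id):
            index = self.call_count % self.num_replicas
            self.call_count += 1
            self.sessions[session_id] = index

        def select(self, session_id=None):
            if session_id is not None:
                return self.sessions[session_id]
            index = self.call_count % self.num_replicas
            self.call_count += 1
            return index
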
diff --git a/mii/legacy/grpc_related/proto/__init__.py b/mii/legacy/grpc_related/proto/__init__.py
new file mode 100644
index 00000000..208299fb
--- /dev/null
+++ b/mii/legacy/grpc_related/proto/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
diff --git a/mii/legacy/grpc_related/proto/build_script.sh b/mii/legacy/grpc_related/proto/build_script.sh
new file mode 100644
index 00000000..0a709f23
--- /dev/null
+++ b/mii/legacy/grpc_related/proto/build_script.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+python3 -m grpc_tools.protoc -I./ --python_out=. --grpc_python_out=. ./legacymodelresponse.proto
+
+# update import to be global wrt mii
+sed -i 's/legacymodelresponse_pb2/mii.legacy.grpc_related.proto.legacymodelresponse_pb2/g' legacymodelresponse_pb2_grpc.py
diff --git a/mii/legacy/grpc_related/proto/legacymodelresponse.proto b/mii/legacy/grpc_related/proto/legacymodelresponse.proto
new file mode 100644
index 00000000..8ad1611a
--- /dev/null
+++ b/mii/legacy/grpc_related/proto/legacymodelresponse.proto
@@ -0,0 +1,103 @@
+// Copyright 2015 gRPC authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+/*option java_multiple_files = true;
+option java_package = "io.grpc.examples.helloworld";
+option java_outer_classname = "HelloWorldProto";
+option objc_class_prefix = "HLW";*/
+
+import "google/protobuf/empty.proto";
+
+package legacymodelresponse;
+
+service ModelResponse {
+ rpc Terminate (google.protobuf.Empty) returns (google.protobuf.Empty) {}
+ rpc CreateSession (SessionID) returns (google.protobuf.Empty) {}
+ rpc DestroySession (SessionID) returns (google.protobuf.Empty) {}
+ rpc GeneratorReply (MultiStringRequest) returns (MultiStringReply) {}
+ rpc ClassificationReply (SingleStringRequest) returns (SingleStringReply) {}
+ rpc QuestionAndAnswerReply(QARequest) returns (SingleStringReply) {}
+ rpc FillMaskReply(SingleStringRequest) returns (SingleStringReply) {}
+ rpc TokenClassificationReply(SingleStringRequest) returns (SingleStringReply) {}
+ rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {}
+ rpc Txt2ImgReply(MultiStringRequest) returns (ImageReply) {}
+}
+
+message Value {
+ oneof oneof_values {
+ string svalue = 1;
+ int64 ivalue = 2;
+ float fvalue = 3;
+ bool bvalue = 4;
+ }
+}
+
+message SessionID {
+ string session_id = 1;
+}
+
+message SingleStringRequest {
+ string request = 1;
+ map<string, Value> query_kwargs = 2;
+}
+
+message MultiStringRequest {
+ repeated string request = 1;
+ map<string, Value> query_kwargs = 2;
+}
+
+message SingleStringReply {
+ string response = 1;
+ float time_taken = 2;
+ float model_time_taken = 3;
+}
+
+message MultiStringReply {
+ repeated string response = 1;
+ float time_taken = 2;
+ float model_time_taken = 3;
+}
+
+message QARequest {
+ string question = 1;
+ string context = 2;
+ map<string, Value> query_kwargs = 3;
+}
+
+message ConversationRequest {
+ string text = 1;
+ string conversation_id = 2;
+ repeated string past_user_inputs = 3;
+ repeated string generated_responses = 4;
+ map<string, Value> query_kwargs = 5;
+}
+
+message ConversationReply {
+ string conversation_id = 1;
+ repeated string past_user_inputs = 2;
+ repeated string generated_responses = 3;
+ float time_taken = 4;
+ float model_time_taken = 5;
+}
+
+message ImageReply {
+ repeated bytes images = 1;
+ repeated bool nsfw_content_detected = 2;
+ string mode = 3;
+ int64 size_w = 4;
+ int64 size_h = 5;
+ float time_taken = 6;
+}
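
Each request message above carries its keyword arguments in a map<string, Value>, where Value is a oneof over string/int/float/bool. Building such a message directly from Python with the generated legacymodelresponse_pb2 module (the keyword names below are illustrative) looks roughly like:

    from mii.legacy.grpc_related.proto import legacymodelresponse_pb2 as pb

    request = pb.MultiStringRequest(request=["DeepSpeed is"])
    # Each kwarg entry is a Value oneof: svalue / ivalue / fvalue / bvalue.
    request.query_kwargs["max_new_tokens"].ivalue = 30
    request.query_kwargs["do_sample"].bvalue = True
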
diff --git a/mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py b/mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py
new file mode 100644
index 00000000..21989cb2
--- /dev/null
+++ b/mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: legacymodelresponse.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x19legacymodelresponse.proto\x12\x13legacymodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xc7\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12O\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x39.legacymodelresponse.SingleStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xc5\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12N\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x38.legacymodelresponse.MultiStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xc5\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12\x45\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32/.legacymodelresponse.QARequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x94\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12O\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x39.legacymodelresponse.ConversationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xb4\x07\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12I\n\rCreateSession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12J\n\x0e\x44\x65stroySession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x62\n\x0eGeneratorReply\x12\'.legacymodelresponse.MultiStringRequest\x1a%.legacymodelresponse.MultiStringReply\"\x00\x12i\n\x13\x43lassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x62\n\x16QuestionAndAnswerReply\x12\x1e.legacymodelresponse.QARequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x63\n\rFillMaskReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12n\n\x18TokenClassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12i\n\x13\x43onversationalReply\x12(.legacymodelresponse.ConversationRequest\x1a&.legacymodelresponse.ConversationReply\"\x00\x12Z\n\x0cTxt2ImgReply\x12\'.legacymodelresponse.MultiStringRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x62\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'legacymodelresponse_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None
+ _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
+ _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None
+ _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
+ _QAREQUEST_QUERYKWARGSENTRY._options = None
+ _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
+ _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None
+ _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001'
+ _globals['_VALUE']._serialized_start = 79
+ _globals['_VALUE']._serialized_end = 174
+ _globals['_SESSIONID']._serialized_start = 176
+ _globals['_SESSIONID']._serialized_end = 207
+ _globals['_SINGLESTRINGREQUEST']._serialized_start = 210
+ _globals['_SINGLESTRINGREQUEST']._serialized_end = 409
+ _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
+ _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
+ _globals['_MULTISTRINGREQUEST']._serialized_start = 412
+ _globals['_MULTISTRINGREQUEST']._serialized_end = 609
+ _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
+ _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
+ _globals['_SINGLESTRINGREPLY']._serialized_start = 611
+ _globals['_SINGLESTRINGREPLY']._serialized_end = 694
+ _globals['_MULTISTRINGREPLY']._serialized_start = 696
+ _globals['_MULTISTRINGREPLY']._serialized_end = 778
+ _globals['_QAREQUEST']._serialized_start = 781
+ _globals['_QAREQUEST']._serialized_end = 978
+ _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
+ _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
+ _globals['_CONVERSATIONREQUEST']._serialized_start = 981
+ _globals['_CONVERSATIONREQUEST']._serialized_end = 1257
+ _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
+ _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
+ _globals['_CONVERSATIONREPLY']._serialized_start = 1260
+ _globals['_CONVERSATIONREPLY']._serialized_end = 1405
+ _globals['_IMAGEREPLY']._serialized_start = 1407
+ _globals['_IMAGEREPLY']._serialized_end = 1532
+ _globals['_MODELRESPONSE']._serialized_start = 1535
+ _globals['_MODELRESPONSE']._serialized_end = 2483
+# @@protoc_insertion_point(module_scope)
diff --git a/mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py b/mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py
new file mode 100644
index 00000000..2f364663
--- /dev/null
+++ b/mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py
@@ -0,0 +1,482 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
+import mii.legacy.grpc_related.proto.legacymodelresponse_pb2 as legacymodelresponse__pb2
+
+
+class ModelResponseStub(object):
+ """Missing associated documentation comment in .proto file."""
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.Terminate = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/Terminate',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.
+ SerializeToString,
+ response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ )
+ self.CreateSession = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/CreateSession',
+ request_serializer=legacymodelresponse__pb2.SessionID.SerializeToString,
+ response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ )
+ self.DestroySession = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/DestroySession',
+ request_serializer=legacymodelresponse__pb2.SessionID.SerializeToString,
+ response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ )
+ self.GeneratorReply = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/GeneratorReply',
+ request_serializer=legacymodelresponse__pb2.MultiStringRequest.
+ SerializeToString,
+ response_deserializer=legacymodelresponse__pb2.MultiStringReply.FromString,
+ )
+ self.ClassificationReply = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/ClassificationReply',
+ request_serializer=legacymodelresponse__pb2.SingleStringRequest.
+ SerializeToString,
+ response_deserializer=legacymodelresponse__pb2.SingleStringReply.FromString,
+ )
+ self.QuestionAndAnswerReply = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/QuestionAndAnswerReply',
+ request_serializer=legacymodelresponse__pb2.QARequest.SerializeToString,
+ response_deserializer=legacymodelresponse__pb2.SingleStringReply.FromString,
+ )
+ self.FillMaskReply = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/FillMaskReply',
+ request_serializer=legacymodelresponse__pb2.SingleStringRequest.
+ SerializeToString,
+ response_deserializer=legacymodelresponse__pb2.SingleStringReply.FromString,
+ )
+ self.TokenClassificationReply = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/TokenClassificationReply',
+ request_serializer=legacymodelresponse__pb2.SingleStringRequest.
+ SerializeToString,
+ response_deserializer=legacymodelresponse__pb2.SingleStringReply.FromString,
+ )
+ self.ConversationalReply = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/ConversationalReply',
+ request_serializer=legacymodelresponse__pb2.ConversationRequest.
+ SerializeToString,
+ response_deserializer=legacymodelresponse__pb2.ConversationReply.FromString,
+ )
+ self.Txt2ImgReply = channel.unary_unary(
+ '/legacymodelresponse.ModelResponse/Txt2ImgReply',
+ request_serializer=legacymodelresponse__pb2.MultiStringRequest.
+ SerializeToString,
+ response_deserializer=legacymodelresponse__pb2.ImageReply.FromString,
+ )
+
+
+class ModelResponseServicer(object):
+ """Missing associated documentation comment in .proto file."""
+ def Terminate(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def CreateSession(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def DestroySession(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def GeneratorReply(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def ClassificationReply(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def QuestionAndAnswerReply(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def FillMaskReply(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def TokenClassificationReply(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def ConversationalReply(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+ def Txt2ImgReply(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_ModelResponseServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'Terminate':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.Terminate,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.
+ SerializeToString,
+ ),
+ 'CreateSession':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.CreateSession,
+ request_deserializer=legacymodelresponse__pb2.SessionID.FromString,
+ response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.
+ SerializeToString,
+ ),
+ 'DestroySession':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.DestroySession,
+ request_deserializer=legacymodelresponse__pb2.SessionID.FromString,
+ response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.
+ SerializeToString,
+ ),
+ 'GeneratorReply':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.GeneratorReply,
+ request_deserializer=legacymodelresponse__pb2.MultiStringRequest.FromString,
+ response_serializer=legacymodelresponse__pb2.MultiStringReply.
+ SerializeToString,
+ ),
+ 'ClassificationReply':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.ClassificationReply,
+ request_deserializer=legacymodelresponse__pb2.SingleStringRequest.FromString,
+ response_serializer=legacymodelresponse__pb2.SingleStringReply.
+ SerializeToString,
+ ),
+ 'QuestionAndAnswerReply':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.QuestionAndAnswerReply,
+ request_deserializer=legacymodelresponse__pb2.QARequest.FromString,
+ response_serializer=legacymodelresponse__pb2.SingleStringReply.
+ SerializeToString,
+ ),
+ 'FillMaskReply':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.FillMaskReply,
+ request_deserializer=legacymodelresponse__pb2.SingleStringRequest.FromString,
+ response_serializer=legacymodelresponse__pb2.SingleStringReply.
+ SerializeToString,
+ ),
+ 'TokenClassificationReply':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.TokenClassificationReply,
+ request_deserializer=legacymodelresponse__pb2.SingleStringRequest.FromString,
+ response_serializer=legacymodelresponse__pb2.SingleStringReply.
+ SerializeToString,
+ ),
+ 'ConversationalReply':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.ConversationalReply,
+ request_deserializer=legacymodelresponse__pb2.ConversationRequest.FromString,
+ response_serializer=legacymodelresponse__pb2.ConversationReply.
+ SerializeToString,
+ ),
+ 'Txt2ImgReply':
+ grpc.unary_unary_rpc_method_handler(
+ servicer.Txt2ImgReply,
+ request_deserializer=legacymodelresponse__pb2.MultiStringRequest.FromString,
+ response_serializer=legacymodelresponse__pb2.ImageReply.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'legacymodelresponse.ModelResponse',
+ rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler, ))
+
+
+# This class is part of an EXPERIMENTAL API.
+class ModelResponse(object):
+ """Missing associated documentation comment in .proto file."""
+ @staticmethod
+ def Terminate(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/Terminate',
+ google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
+
+ @staticmethod
+ def CreateSession(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/CreateSession',
+ legacymodelresponse__pb2.SessionID.SerializeToString,
+ google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
+
+ @staticmethod
+ def DestroySession(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/DestroySession',
+ legacymodelresponse__pb2.SessionID.SerializeToString,
+ google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
+
+ @staticmethod
+ def GeneratorReply(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/GeneratorReply',
+ legacymodelresponse__pb2.MultiStringRequest.SerializeToString,
+ legacymodelresponse__pb2.MultiStringReply.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
+
+ @staticmethod
+ def ClassificationReply(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/ClassificationReply',
+ legacymodelresponse__pb2.SingleStringRequest.SerializeToString,
+ legacymodelresponse__pb2.SingleStringReply.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
+
+ @staticmethod
+ def QuestionAndAnswerReply(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/QuestionAndAnswerReply',
+ legacymodelresponse__pb2.QARequest.SerializeToString,
+ legacymodelresponse__pb2.SingleStringReply.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
+
+ @staticmethod
+ def FillMaskReply(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/FillMaskReply',
+ legacymodelresponse__pb2.SingleStringRequest.SerializeToString,
+ legacymodelresponse__pb2.SingleStringReply.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
+
+ @staticmethod
+ def TokenClassificationReply(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/TokenClassificationReply',
+ legacymodelresponse__pb2.SingleStringRequest.SerializeToString,
+ legacymodelresponse__pb2.SingleStringReply.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
+
+ @staticmethod
+ def ConversationalReply(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/ConversationalReply',
+ legacymodelresponse__pb2.ConversationRequest.SerializeToString,
+ legacymodelresponse__pb2.ConversationReply.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
+
+ @staticmethod
+ def Txt2ImgReply(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ '/legacymodelresponse.ModelResponse/Txt2ImgReply',
+ legacymodelresponse__pb2.MultiStringRequest.SerializeToString,
+ legacymodelresponse__pb2.ImageReply.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata)
diff --git a/mii/legacy/grpc_related/restful_gateway.py b/mii/legacy/grpc_related/restful_gateway.py
new file mode 100644
index 00000000..ca577e83
--- /dev/null
+++ b/mii/legacy/grpc_related/restful_gateway.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import time
+import threading
+import mii.legacy as mii
+from flask import Flask, request
+from flask_restful import Resource, Api
+from werkzeug.serving import make_server
+from mii.legacy.constants import RESTFUL_GATEWAY_SHUTDOWN_TIMEOUT, RESTFUL_API_PATH
+from google.protobuf.json_format import MessageToJson
+
+
+def shutdown(thread):
+ time.sleep(RESTFUL_GATEWAY_SHUTDOWN_TIMEOUT)
+ thread.server.shutdown()
+
+
+def createRestfulGatewayApp(deployment_name, task, lb_port, server_thread):
+ # client must be thread-safe
+ client = mii.MIIClient(task, "localhost", lb_port)
+
+ class RestfulGatewayService(Resource):
+ def __init__(self):
+ super().__init__()
+
+ def post(self):
+ data = request.get_json()
+ kwargs = data["kwargs"] if "kwargs" in data else {}
+ result = client.query(data["request"], **kwargs)
+ return MessageToJson(result)
+
+ app = Flask("RestfulGateway")
+
+ @app.route("/terminate", methods=["GET"])
+ def terminate():
+ # Need to shutdown *after* completing the request
+ threading.Thread(target=shutdown, args=(server_thread, )).start()
+ return "Shutting down RESTful API gateway server"
+
+ api = Api(app)
+ path = "/{}/{}".format(RESTFUL_API_PATH, deployment_name)
+ api.add_resource(RestfulGatewayService, path)
+
+ return app
+
+
+class RestfulGatewayThread(threading.Thread):
+ def __init__(self, deployment_name, task, lb_port, rest_port):
+ threading.Thread.__init__(self)
+
+ app = createRestfulGatewayApp(deployment_name, task, lb_port, self)
+ self.server = make_server("127.0.0.1", rest_port, app)
+ self.ctx = app.app_context()
+ self.ctx.push()
+
+ self._stop_event = threading.Event()
+
+ def run(self):
+ self.server.serve_forever()
+ self._stop_event.set()
+
+ def get_stop_event(self):
+ return self._stop_event
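For context, the gateway above exposes each deployment at "/{RESTFUL_API_PATH}/{deployment_name}" and its POST handler expects a JSON body with a "request" field plus optional "kwargs" that are forwarded to the gRPC client. A minimal client sketch is below; the deployment name, port, and the assumption that RESTFUL_API_PATH resolves to "mii" are illustrative, not taken from this diff.

import requests  # assumed to be available in the client environment

# hypothetical deployment name and RESTful gateway port, shown only to
# illustrate the request shape expected by RestfulGatewayService.post()
url = "http://localhost:28080/mii/my_deployment"
payload = {"request": {"query": "DeepSpeed is"}, "kwargs": {"max_new_tokens": 64}}
response = requests.post(url, json=payload)
print(response.json())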
diff --git a/mii/legacy/launch/__init__.py b/mii/legacy/launch/__init__.py
new file mode 100644
index 00000000..208299fb
--- /dev/null
+++ b/mii/legacy/launch/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
diff --git a/mii/legacy/launch/multi_gpu_server.py b/mii/legacy/launch/multi_gpu_server.py
new file mode 100644
index 00000000..a9ee7498
--- /dev/null
+++ b/mii/legacy/launch/multi_gpu_server.py
@@ -0,0 +1,97 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import os
+import argparse
+import base64
+import json
+
+from mii.legacy.config import ModelConfig
+from mii.legacy.models.load_models import load_models
+from mii.legacy.grpc_related.modelresponse_server import serve_inference, serve_load_balancing
+from mii.legacy.grpc_related.restful_gateway import RestfulGatewayThread
+
+
+def b64_encoded_config(config_str):
+ # str -> bytes
+ b64_bytes = config_str.encode()
+ # decode b64 bytes -> json bytes
+ config_bytes = base64.urlsafe_b64decode(b64_bytes)
+ # convert json bytes -> str -> dict
+ config_dict = json.loads(config_bytes.decode())
+ # return mii.ModelConfig object
+ return ModelConfig(**config_dict)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--deployment-name", type=str, help="Name of deployment")
+ parser.add_argument(
+ "--model-config",
+ type=b64_encoded_config,
+ help="base64 encoded model config",
+ )
+ parser.add_argument(
+ "--server-port",
+ type=int,
+ default=0,
+ help="Port to user for DeepSpeed inference server.",
+ )
+ parser.add_argument("--load-balancer",
+ action="store_true",
+ help="Launch load balancer process.")
+ parser.add_argument(
+ "--load-balancer-port",
+ type=int,
+ default=0,
+ help="Port to use for load balancer.",
+ )
+ parser.add_argument(
+ "--restful-gateway",
+ action="store_true",
+ help="Launches restful gateway process.",
+ )
+ parser.add_argument(
+ "--restful-gateway-port",
+ type=int,
+ default=0,
+ help="Port to use for restful gateway.",
+ )
+ args = parser.parse_args()
+ assert not (
+ args.load_balancer and args.restful_gateway
+ ), "Select only load-balancer OR restful-gateway."
+
+ if args.restful_gateway:
+ assert args.restful_gateway_port, "--restful-gateway-port must be provided."
+ print(f"Starting RESTful API gateway on port: {args.restful_gateway_port}")
+ gateway_thread = RestfulGatewayThread(
+ deployment_name=args.deployment_name,
+ task=args.model_config.task,
+ lb_port=args.load_balancer_port,
+ rest_port=args.restful_gateway_port,
+ )
+ stop_event = gateway_thread.get_stop_event()
+ gateway_thread.start()
+ stop_event.wait()
+
+ elif args.load_balancer:
+ assert args.load_balancer_port, "--load-balancer-port must be provided."
+ print(f"Starting load balancer on port: {args.load_balancer_port}")
+ serve_load_balancing(args.model_config, args.load_balancer_port)
+
+ else:
+ assert args.server_port, "--server-port must be provided."
+ local_rank = int(os.getenv("LOCAL_RANK", "0"))
+ port = args.server_port + local_rank
+
+ inference_pipeline = load_models(args.model_config)
+
+ print(f"Starting server on port: {port}")
+ serve_inference(inference_pipeline, port)
+
+
+if __name__ == "__main__":
+ # python -m mii.launch.multi_gpu_server
+ main()
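The --model-config argument is a URL-safe base64 string that b64_encoded_config turns back into a ModelConfig; the encoding side is config_to_b64_str in mii/legacy/server.py. A minimal, self-contained sketch of that round trip follows, using an illustrative payload rather than a full serialized ModelConfig.

import base64
import json

config_dict = {"task": "text-generation"}                # illustrative payload only
json_bytes = json.dumps(config_dict).encode()            # str -> bytes
b64_str = base64.urlsafe_b64encode(json_bytes).decode()  # what --model-config receives

# the launcher reverses the encoding before constructing ModelConfig(**decoded)
decoded = json.loads(base64.urlsafe_b64decode(b64_str.encode()).decode())
assert decoded == config_dict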
diff --git a/mii/legacy/logging.py b/mii/legacy/logging.py
new file mode 100644
index 00000000..1fcf2ac9
--- /dev/null
+++ b/mii/legacy/logging.py
@@ -0,0 +1,45 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import sys
+import logging
+
+log_levels = {
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+}
+
+
+class LoggerFactory:
+ @staticmethod
+ def create_logger(name=None, level=logging.INFO):
+ """create a logger
+ Args:
+ name (str): name of the logger
+ level: level of logger
+ Raises:
+ ValueError if name is None
+ """
+
+ if name is None:
+ raise ValueError("name for logger cannot be None")
+
+ formatter = logging.Formatter(
+ "[%(asctime)s] [%(levelname)s] "
+ "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")
+
+ logger_ = logging.getLogger(name)
+ logger_.setLevel(level)
+ logger_.propagate = False
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(level)
+ ch.setFormatter(formatter)
+ logger_.addHandler(ch)
+ return logger_
+
+
+logger = LoggerFactory.create_logger(name="MII", level=logging.INFO)
diff --git a/mii/method_table.py b/mii/legacy/method_table.py
similarity index 97%
rename from mii/method_table.py
rename to mii/legacy/method_table.py
index 6b95da91..e9bb9b4f 100644
--- a/mii/method_table.py
+++ b/mii/legacy/method_table.py
@@ -6,10 +6,10 @@
from abc import ABC, abstractmethod
from transformers import Conversation
-from mii.constants import TaskType
-from mii.grpc_related.proto import modelresponse_pb2
-from mii.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs
-from mii.models.utils import ImageResponse
+from mii.legacy.constants import TaskType
+from mii.legacy.grpc_related.proto import legacymodelresponse_pb2 as modelresponse_pb2
+from mii.legacy.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs
+from mii.legacy.models.utils import ImageResponse
def single_string_request_to_proto(self, request_dict, **query_kwargs):
diff --git a/mii/models/__init__.py b/mii/legacy/models/__init__.py
similarity index 100%
rename from mii/models/__init__.py
rename to mii/legacy/models/__init__.py
diff --git a/mii/models/load_models.py b/mii/legacy/models/load_models.py
similarity index 95%
rename from mii/models/load_models.py
rename to mii/legacy/models/load_models.py
index 97510bf1..cfbf455f 100644
--- a/mii/models/load_models.py
+++ b/mii/legacy/models/load_models.py
@@ -3,7 +3,7 @@
# DeepSpeed Team
import os
-import mii
+import mii.legacy as mii
import torch
import inspect
import deepspeed
@@ -33,7 +33,7 @@ def load_models(model_config):
provider = model_config.provider
if provider == mii.constants.ModelProvider.HUGGING_FACE:
- from mii.models.providers.huggingface import hf_provider
+ from mii.legacy.models.providers.huggingface import hf_provider
inference_pipeline = hf_provider(model_config)
if model_config.meta_tensor:
@@ -60,7 +60,7 @@ def load_models(model_config):
inf_config["config"] = inference_pipeline.neox_args
"""
elif provider == mii.constants.ModelProvider.DIFFUSERS:
- from mii.models.providers.diffusers import diffusers_provider
+ from mii.legacy.models.providers.diffusers import diffusers_provider
inference_pipeline = diffusers_provider(model_config)
else:
raise ValueError(f"Unknown model provider {provider}")
diff --git a/mii/legacy/models/providers/__init__.py b/mii/legacy/models/providers/__init__.py
new file mode 100644
index 00000000..208299fb
--- /dev/null
+++ b/mii/legacy/models/providers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
diff --git a/mii/models/providers/diffusers.py b/mii/legacy/models/providers/diffusers.py
similarity index 100%
rename from mii/models/providers/diffusers.py
rename to mii/legacy/models/providers/diffusers.py
diff --git a/mii/legacy/models/providers/eleutherai.py b/mii/legacy/models/providers/eleutherai.py
new file mode 100644
index 00000000..208299fb
--- /dev/null
+++ b/mii/legacy/models/providers/eleutherai.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
diff --git a/mii/models/providers/huggingface.py b/mii/legacy/models/providers/huggingface.py
similarity index 98%
rename from mii/models/providers/huggingface.py
rename to mii/legacy/models/providers/huggingface.py
index 9c798d48..291ddeaa 100644
--- a/mii/models/providers/huggingface.py
+++ b/mii/legacy/models/providers/huggingface.py
@@ -12,7 +12,7 @@
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig
from huggingface_hub import snapshot_download
-from mii.utils import mii_cache_path, is_aml
+from mii.legacy.utils import mii_cache_path, is_aml
class MetaTensorPipeline(object):
diff --git a/mii/models/score/__init__.py b/mii/legacy/models/score/__init__.py
similarity index 100%
rename from mii/models/score/__init__.py
rename to mii/legacy/models/score/__init__.py
diff --git a/mii/models/score/generate.py b/mii/legacy/models/score/generate.py
similarity index 92%
rename from mii/models/score/generate.py
rename to mii/legacy/models/score/generate.py
index b1787d65..05aabe57 100644
--- a/mii/models/score/generate.py
+++ b/mii/legacy/models/score/generate.py
@@ -3,10 +3,10 @@
# DeepSpeed Team
import os
-import mii
+import mii.legacy as mii
import pprint
-from mii.logging import logger
-from mii.constants import DeploymentType
+from mii.legacy.logging import logger
+from mii.legacy.constants import DeploymentType
def create_score_file(mii_config):
diff --git a/mii/models/score/score_template.py b/mii/legacy/models/score/score_template.py
similarity index 98%
rename from mii/models/score/score_template.py
rename to mii/legacy/models/score/score_template.py
index 6b10ef23..7d2ad2aa 100644
--- a/mii/models/score/score_template.py
+++ b/mii/legacy/models/score/score_template.py
@@ -9,7 +9,7 @@
import time
import torch
-import mii
+import mii.legacy as mii
model = None
diff --git a/mii/models/utils.py b/mii/legacy/models/utils.py
similarity index 97%
rename from mii/models/utils.py
rename to mii/legacy/models/utils.py
index d44b2871..5298745e 100644
--- a/mii/models/utils.py
+++ b/mii/legacy/models/utils.py
@@ -3,7 +3,7 @@
# DeepSpeed Team
import os
-from mii.utils import mii_cache_path
+from mii.legacy.utils import mii_cache_path
def supported_models_from_huggingface():
diff --git a/mii/legacy/pydantic_v1.py b/mii/legacy/pydantic_v1.py
new file mode 100644
index 00000000..6aba072a
--- /dev/null
+++ b/mii/legacy/pydantic_v1.py
@@ -0,0 +1,16 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""Pydantic v1 compatibility module.
+
+Pydantic v2 introduced breaking changes that hinder its adoption:
+https://docs.pydantic.dev/latest/migration/. To provide deepspeed users the option to
+migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module
+as a pydantic-version-agnostic alias for pydantic's v1 API.
+"""
+
+try:
+ from pydantic.v1 import * # noqa: F401
+except ImportError:
+ from pydantic import * # noqa: F401
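A sketch of how a config module can consume this alias so it works against either pydantic major version; ExampleConfig and its field are hypothetical and shown only to illustrate the import path.

from mii.legacy.pydantic_v1 import BaseModel, validator


class ExampleConfig(BaseModel):  # hypothetical config, for illustration only
    port_number: int = 50050

    @validator("port_number")
    def _check_port(cls, value):
        # validate using the pydantic v1 API regardless of the installed version
        assert value > 0, "port_number must be positive"
        return value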
diff --git a/mii/legacy/server.py b/mii/legacy/server.py
new file mode 100644
index 00000000..8a66f3ec
--- /dev/null
+++ b/mii/legacy/server.py
@@ -0,0 +1,169 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import base64
+import os
+import subprocess
+import sys
+import tempfile
+import time
+from collections import defaultdict
+from deepspeed.accelerator import get_accelerator
+
+from mii.legacy.utils import get_num_gpus
+from mii.legacy.logging import logger
+
+
+def config_to_b64_str(config):
+ # convert json str -> bytes
+ json_bytes = config.json().encode()
+ # base64 encoded bytes
+ b64_config_bytes = base64.urlsafe_b64encode(json_bytes)
+ # bytes -> str
+ return b64_config_bytes.decode()
+
+
+class MIIServer:
+ """Initialize the model, setup the server for the model under model_path"""
+ def __init__(self, mii_config):
+
+ self.task = mii_config.model_config.task
+ self.num_gpus = get_num_gpus(mii_config)
+ assert self.num_gpus > 0, "GPU count must be greater than 0"
+
+ self.port_number = mii_config.port_number
+
+ if not os.path.isfile(mii_config.hostfile):
+ logger.info(f"Hostfile {mii_config.hostfile} not found, creating hostfile.")
+ num_gpu = get_accelerator().device_count()
+ with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file:
+ temp_file.write(f"localhost slots={num_gpu}")
+ mii_config.hostfile = temp_file.name
+
+ mii_config.generate_replica_configs()
+
+ processes = self._initialize_service(mii_config)
+ self._wait_until_server_is_live(processes,
+ mii_config.model_config.replica_configs)
+
+ def _wait_until_server_is_live(self, processes, deployment):
+ for process, repl_config in zip(processes, deployment):
+ sockets_open = False
+ while not sockets_open:
+ sockets_open = all(
+ self._is_socket_open(repl_config.hostname,
+ port)
+ for port in repl_config.tensor_parallel_ports)
+ process_alive = self._is_server_process_alive(process)
+ if not process_alive:
+ raise RuntimeError(
+ "server crashed for some reason, unable to proceed")
+ time.sleep(4)
+ logger.info("waiting for server to start...")
+ logger.info(
+ f"server has started on ports {repl_config.tensor_parallel_ports}")
+
+ def _is_socket_open(self, host, port):
+ import socket
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ result = sock.connect_ex((host, port))
+ sock.close()
+ return result == 0
+
+ def _is_server_process_alive(self, process):
+ if process is None:
+ return True
+ try:
+ process.wait(1)
+ except subprocess.TimeoutExpired as err:
+ # timeout means we're still running and all (probably) okay
+ is_alive = True
+ else:
+ # no exception case
+ is_alive = False
+ return is_alive
+
+ def _launch_server_process(self,
+ model_config,
+ msg_server_type,
+ ds_launch_str="",
+ server_args=None):
+ launch_str = f"{sys.executable} -m mii.legacy.launch.multi_gpu_server"
+ b64_config_str = config_to_b64_str(model_config)
+ if server_args is None:
+ server_args = []
+ server_args.append(f"--model-config {b64_config_str}")
+ server_args_str = " ".join(server_args)
+ cmd = f"{ds_launch_str} {launch_str} {server_args_str}".strip().split(" ")
+
+ mii_env = os.environ.copy()
+ mii_env["TRANSFORMERS_CACHE"] = model_config.model_path
+ logger.info(f"{msg_server_type} server launch: {cmd}")
+ return subprocess.Popen(cmd, env=mii_env)
+
+ def _generate_ds_launch_str(self, replica_config, hostfile):
+ # use different hostfiles for replica instances
+ # pass /dev/null when no replica is used
+ worker_str = f"-H {hostfile} "
+ # pin deepspeed launch to specific gpu id(s)
+ included_gpus = f"{replica_config.hostname}:{','.join(map(str, replica_config.gpu_indices))}"
+ worker_str += f"-i {included_gpus} "
+
+ # adjust torch dist port depending on rank, otherwise multi-replica deployments will conflict
+ # assign different ports to replicas because they could be on the same host
+ worker_str += f"--master_port {replica_config.torch_dist_port}"
+
+ ds_launch_str = f"deepspeed {worker_str} --master_addr localhost --no_ssh_check --no_local_rank --no_python"
+
+ return ds_launch_str
+
+ def _initialize_service(self, mii_config):
+ processes = []
+ server_args = [
+ f"--deployment-name {mii_config.deployment_name}",
+ f"--load-balancer-port {mii_config.port_number}",
+ f"--restful-gateway-port {mii_config.restful_api_port}",
+ ]
+
+ host_gpus = defaultdict(list)
+ for repl_config in mii_config.model_config.replica_configs:
+ host_gpus[repl_config.hostname].extend(repl_config.gpu_indices)
+
+ # Start replica instances
+ for repl_config in mii_config.model_config.replica_configs:
+ hostfile = tempfile.NamedTemporaryFile(delete=False)
+ hostfile.write(
+ f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n"
+ .encode())
+ ds_launch_str = self._generate_ds_launch_str(repl_config, hostfile.name)
+ processes.append(
+ self._launch_server_process(
+ mii_config.model_config,
+ "MII server",
+ ds_launch_str=ds_launch_str,
+ server_args=server_args +
+ [f"--server-port {repl_config.tensor_parallel_ports[0]}"],
+ ))
+ # start load balancer here. We don't use deepspeed launcher for the
+ # load balancer because it does not need a GPU. The deepspeed
+ # launcher determines the number of processes to launch based on
+ # GPUs available on the host or CUDA_VISIBLE_DEVICES, and it is
+ # expected to assign one GPU to one process.
+ processes.append(
+ self._launch_server_process(
+ mii_config.model_config,
+ "load balancer",
+ server_args=server_args + ["--load-balancer"],
+ ))
+
+ if mii_config.enable_restful_api:
+ processes.append(
+ self._launch_server_process(
+ mii_config.model_config,
+ "restful api gateway",
+ server_args=server_args + ["--restful-gateway"],
+ ))
+
+ return processes
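For reference, a minimal reconstruction of the launch string assembled by _generate_ds_launch_str above, using hypothetical replica values rather than anything from a real deployment.

# hypothetical replica values, for illustration only
hostname, gpu_indices, torch_dist_port = "localhost", [0, 1], 29600
hostfile = "/tmp/mii_hostfile"

worker_str = f"-H {hostfile} "
worker_str += f"-i {hostname}:{','.join(map(str, gpu_indices))} "
worker_str += f"--master_port {torch_dist_port}"
ds_launch_str = (f"deepspeed {worker_str} --master_addr localhost "
                 "--no_ssh_check --no_local_rank --no_python")
print(ds_launch_str)
# deepspeed -H /tmp/mii_hostfile -i localhost:0,1 --master_port 29600 ...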
diff --git a/mii/terminate.py b/mii/legacy/terminate.py
similarity index 92%
rename from mii/terminate.py
rename to mii/legacy/terminate.py
index ae1e38f2..03d9680f 100644
--- a/mii/terminate.py
+++ b/mii/legacy/terminate.py
@@ -4,8 +4,8 @@
# DeepSpeed Team
import grpc
-import mii
-from mii.logging import logger
+import mii.legacy as mii
+from mii.legacy.logging import logger
def terminate(deployment_name):
diff --git a/mii/legacy/utils.py b/mii/legacy/utils.py
new file mode 100644
index 00000000..64b7f16c
--- /dev/null
+++ b/mii/legacy/utils.py
@@ -0,0 +1,197 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import os
+import pickle
+import time
+import importlib
+import torch
+import mii.legacy as mii
+from types import SimpleNamespace
+from huggingface_hub import HfApi
+
+from mii.legacy.models.score.generate import generated_score_path
+from mii.legacy.constants import (
+ MII_CACHE_PATH,
+ MII_CACHE_PATH_DEFAULT,
+ ModelProvider,
+ SUPPORTED_MODEL_TYPES,
+ REQUIRED_KEYS_PER_TASK,
+ MII_HF_CACHE_EXPIRATION,
+ MII_HF_CACHE_EXPIRATION_DEFAULT,
+)
+
+from mii.legacy.config import TaskType
+
+
+def _get_hf_models_by_type(model_type=None, task=None):
+ cache_file_path = os.path.join(mii_cache_path(), "HF_model_cache.pkl")
+ cache_expiration_seconds = os.getenv(MII_HF_CACHE_EXPIRATION,
+ MII_HF_CACHE_EXPIRATION_DEFAULT)
+
+ # Load or initialize the cache
+ model_data = {"cache_time": 0, "model_list": []}
+ if os.path.isfile(cache_file_path):
+ with open(cache_file_path, 'rb') as f:
+ model_data = pickle.load(f)
+
+ current_time = time.time()
+
+ # Update the cache if it has expired
+ if (model_data["cache_time"] + cache_expiration_seconds) < current_time:
+ api = HfApi()
+ model_data["model_list"] = [
+ SimpleNamespace(modelId=m.modelId,
+ pipeline_tag=m.pipeline_tag,
+ tags=m.tags) for m in api.list_models()
+ ]
+ model_data["cache_time"] = current_time
+
+ # Save the updated cache
+ with open(cache_file_path, 'wb') as f:
+ pickle.dump(model_data, f)
+
+ # Filter the model list
+ models = model_data["model_list"]
+ if model_type is not None:
+ models = [m for m in models if model_type in m.tags]
+ if task is not None:
+ models = [m for m in models if m.pipeline_tag == task]
+
+ # Extract model IDs
+ model_ids = [m.modelId for m in models]
+
+ if task == TaskType.TEXT_GENERATION:
+ # TODO: this is a temp solution to get around some HF models not having the correct tags
+ model_ids.extend([
+ "microsoft/bloom-deepspeed-inference-fp16",
+ "microsoft/bloom-deepspeed-inference-int8",
+ "EleutherAI/gpt-neox-20b"
+ ])
+
+ return model_ids
+
+
+def get_supported_models(task):
+ supported_models = []
+
+ for model_type, provider in SUPPORTED_MODEL_TYPES.items():
+ if provider == ModelProvider.HUGGING_FACE:
+ models = _get_hf_models_by_type(model_type, task)
+ elif provider == ModelProvider.ELEUTHER_AI:
+ if task == TaskType.TEXT_GENERATION:
+ models = [model_type]
+ elif provider == ModelProvider.DIFFUSERS:
+ models = _get_hf_models_by_type(model_type, task)
+ supported_models.extend(models)
+ if not supported_models:
+ raise ValueError(f"Task {task} not supported")
+
+ return supported_models
+
+
+def check_if_task_and_model_is_supported(task, model_name):
+ supported_models = get_supported_models(task)
+ assert (
+ model_name in supported_models
+ ), f"{task} is not supported by {model_name}. This task is supported by {len(supported_models)} other models. See which models with `mii.get_supported_models(mii.{task})`."
+
+
+def check_if_task_and_model_is_valid(task, model_name):
+ valid_task_models = _get_hf_models_by_type(None, task)
+ assert (
+ model_name in valid_task_models
+ ), f"{task} is not supported by {model_name}. This task is supported by {len(valid_task_models)} other models. See which models with `mii.get_supported_models(mii.{task})`."
+
+
+def full_model_path(model_path):
+ aml_model_dir = os.environ.get('AZUREML_MODEL_DIR', None)
+ if aml_model_dir:
+ # (potentially) append relative model_path w. aml path
+ assert os.path.isabs(aml_model_dir), f"AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path"
+ if model_path:
+ assert not os.path.isabs(model_path), f"model_path={model_path} must be relative to append w. AML path"
+ return os.path.join(aml_model_dir, model_path)
+ else:
+ return aml_model_dir
+ elif model_path:
+ return model_path
+ else:
+ return mii.constants.MII_MODEL_PATH_DEFAULT
+
+
+def is_aml():
+ return os.getenv("AZUREML_MODEL_DIR") is not None
+
+
+def mii_cache_path():
+ cache_path = os.environ.get(MII_CACHE_PATH, MII_CACHE_PATH_DEFAULT)
+ if not os.path.isdir(cache_path):
+ os.makedirs(cache_path)
+ return cache_path
+
+
+def import_score_file(deployment_name, deployment_type):
+ score_path = generated_score_path(deployment_name, deployment_type)
+ spec = importlib.util.spec_from_file_location("score", score_path)
+ score = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(score)
+ return score
+
+
+dtype_proto_field = {
+ str: "svalue",
+ int: "ivalue",
+ float: "fvalue",
+ bool: "bvalue",
+}
+
+
+def kwarg_dict_to_proto(kwarg_dict):
+ def get_proto_value(value):
+ proto_value = mii.grpc_related.proto.legacymodelresponse_pb2.Value()
+ setattr(proto_value, dtype_proto_field[type(value)], value)
+ return proto_value
+
+ return {k: get_proto_value(v) for k, v in kwarg_dict.items()}
+
+
+def unpack_proto_query_kwargs(query_kwargs):
+ query_kwargs = {
+ k: getattr(v,
+ v.WhichOneof("oneof_values"))
+ for k,
+ v in query_kwargs.items()
+ }
+ return query_kwargs
+
+
+def extract_query_dict(task, request_dict):
+ required_keys = REQUIRED_KEYS_PER_TASK[task]
+ query_dict = {}
+ for key in required_keys:
+ value = request_dict.pop(key, None)
+ if value is None:
+ raise ValueError("Request for task: {task} is missing required key: {key}.")
+ query_dict[key] = value
+ return query_dict
+
+
+def get_num_gpus(mii_config):
+ num_gpus = mii_config.model_config.tensor_parallel
+
+ assert (
+ torch.cuda.device_count() >= num_gpus
+ ), f"Available GPU count: {torch.cuda.device_count()} does not meet the required gpu count: {num_gpus}"
+ return num_gpus
+
+
+def get_provider(model_name, task):
+ if model_name == "gpt-neox":
+ provider = ModelProvider.ELEUTHER_AI
+ elif task == TaskType.TEXT2IMG:
+ provider = ModelProvider.DIFFUSERS
+ else:
+ provider = ModelProvider.HUGGING_FACE
+ return provider
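kwarg_dict_to_proto and unpack_proto_query_kwargs are intended to be inverses for the scalar types listed in dtype_proto_field. A small round-trip sketch, assuming MII is installed and the generated legacymodelresponse_pb2 protos are importable.

from mii.legacy.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs

# illustrative query kwargs; ints map to ivalue, bools to bvalue
kwargs = {"max_new_tokens": 128, "do_sample": True}
proto_kwargs = kwarg_dict_to_proto(kwargs)
assert unpack_proto_query_kwargs(proto_kwargs) == kwargs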
diff --git a/mii/logging.py b/mii/logging.py
index 1fcf2ac9..e46654a2 100644
--- a/mii/logging.py
+++ b/mii/logging.py
@@ -2,8 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
-import sys
import logging
+import sys
log_levels = {
"debug": logging.DEBUG,
diff --git a/mii/models.py b/mii/models.py
new file mode 100644
index 00000000..7dbb5d73
--- /dev/null
+++ b/mii/models.py
@@ -0,0 +1,23 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from deepspeed.inference import build_hf_engine, InferenceEngineV2
+
+from mii.config import ModelConfig
+from mii.constants import ModelProvider
+from mii.utils import init_distributed
+
+
+def load_model(model_config: ModelConfig) -> InferenceEngineV2:
+ init_distributed(model_config)
+ provider = model_config.provider
+ if provider == ModelProvider.HUGGING_FACE:
+ inference_engine = build_hf_engine(
+ path=model_config.model_name_or_path,
+ engine_config=model_config.inference_engine_config)
+ else:
+ raise ValueError(f"Unknown model provider {provider}")
+
+ return inference_engine
diff --git a/mii/pipeline.py b/mii/pipeline.py
new file mode 100644
index 00000000..3db0511f
--- /dev/null
+++ b/mii/pipeline.py
@@ -0,0 +1,46 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from typing import Optional, Any, Dict
+
+from mii.batching import MIIPipeline, MIIAsyncPipeline
+from mii.config import ModelConfig
+from mii.models import load_model
+from mii.tokenizers import load_tokenizer
+
+
+def pipeline(model_name_or_path: str = "",
+ model_config: Optional[Dict[str,
+ Any]] = None,
+ **kwargs) -> MIIPipeline:
+ if model_config is None:
+ model_config = {}
+ if model_name_or_path:
+ if "model_name_or_path" in model_config:
+ assert model_config.get("model_name_or_path") == model_name_or_path, "model_name_or_path in model_config must match model_name_or_path"
+ model_config["model_name_or_path"] = model_name_or_path
+ for key, val in kwargs.items():
+ if key in ModelConfig.__dict__["__fields__"]:
+ if key in model_config:
+ assert model_config.get(key) == val, f"{key} in model_config must match {key}"
+ model_config[key] = val
+ else:
+ raise ValueError(f"Invalid keyword argument {key}")
+ model_config = ModelConfig(**model_config)
+
+ inference_engine = load_model(model_config)
+ tokenizer = load_tokenizer(model_config)
+ inference_pipeline = MIIPipeline(inference_engine=inference_engine,
+ tokenizer=tokenizer,
+ model_config=model_config)
+ return inference_pipeline
+
+
+def async_pipeline(model_config: ModelConfig) -> MIIAsyncPipeline:
+ inference_engine = load_model(model_config)
+ tokenizer = load_tokenizer(model_config)
+ inference_pipeline = MIIAsyncPipeline(inference_engine=inference_engine,
+ tokenizer=tokenizer,
+ model_config=model_config)
+ return inference_pipeline
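A minimal usage sketch of the non-persistent pipeline; the model name is illustrative and max_new_tokens is assumed to be a valid generation kwarg accepted by MIIPipeline.

import mii

pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")  # any supported HF text-generation model
responses = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=64)
for r in responses:
    print(r)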
diff --git a/mii/score/__init__.py b/mii/score/__init__.py
new file mode 100644
index 00000000..0b533d8f
--- /dev/null
+++ b/mii/score/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from .generate import create_score_file, generated_score_path
diff --git a/mii/score/generate.py b/mii/score/generate.py
new file mode 100644
index 00000000..a34a96c6
--- /dev/null
+++ b/mii/score/generate.py
@@ -0,0 +1,42 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import os
+import mii
+import pprint
+from mii.logging import logger
+from mii.constants import DeploymentType
+
+
+def create_score_file(mii_config):
+ if len(mii.__path__) > 1:
+ logger.warning(
+ f"Detected mii path as multiple sources: {mii.__path__}, might cause unknown behavior"
+ )
+
+ with open(os.path.join(mii.__path__[0], "score/score_template.py"), "r") as fd:
+ score_src = fd.read()
+
+ # update score file w. global config dict
+ config_dict = mii_config.dict()
+ source_with_config = f"{score_src}\n"
+ source_with_config += f"mii_config = {pprint.pformat(config_dict, indent=4)}"
+
+ with open(
+ generated_score_path(mii_config.deployment_name,
+ mii_config.deployment_type),
+ "w") as fd:
+ fd.write(source_with_config)
+ fd.write("\n")
+
+
+def generated_score_path(deployment_name, deployment_type):
+ if deployment_type == DeploymentType.LOCAL:
+ score_path = os.path.join(mii.utils.mii_cache_path(), deployment_name)
+ elif deployment_type == DeploymentType.AML:
+ score_path = os.path.join(mii.aml_related.utils.aml_output_path(deployment_name),
+ "code")
+ if not os.path.isdir(score_path):
+ os.makedirs(score_path)
+ return os.path.join(score_path, "score.py")
diff --git a/mii/score/score_template.py b/mii/score/score_template.py
new file mode 100644
index 00000000..a68d20d8
--- /dev/null
+++ b/mii/score/score_template.py
@@ -0,0 +1,61 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# flake8: noqa
+import os
+import json
+import time
+import torch
+
+import mii
+
+model = None
+
+
+def init():
+ global mii_config
+ mii_config = mii.config.MIIConfig(**mii_config)
+
+ # For AML deployments, we stand up multiple nginx server workers, one for
+ # each replica. This is so that we can properly run multiple requests in
+ # parallel on the different replicas. However, each worker will run this
+ # generated score.py and try to stand up an entire MII deployment
+ # (load-balancer, replicas, etc.). We want only one worker to spawn the
+ # load-balancer and replicas. We take advantage of the nginx worker PIDs
+ # being consecutive to achieve that here.
+ start_server = True
+ if mii.utils.is_aml() and (int(os.getpid()) % mii_config.replica_num != 0):
+ start_server = False
+
+ if start_server:
+ mii.server.MIIServer(mii_config)
+
+ global model
+ model = None
+
+ # In AML deployments both the GRPC client and server are used in the same process
+ if mii.utils.is_aml():
+ model = mii.client.MIIClient(mii_config=mii_config)
+
+
+def run(request):
+ global mii_config, model
+ assert (
+ model is not None
+ ), "grpc client has not been setup when this model was created"
+
+ request_dict = json.loads(request)
+
+ query_dict = mii.utils.extract_query_dict(mii_config.task, request_dict)
+
+ response = model.query(query_dict, **request_dict)
+
+ time_taken = response.time_taken
+ if not isinstance(response.response, str):
+ response = [r for r in response.response]
+ return json.dumps({"responses": response, "time": time_taken})
+
+
+### Auto-generated config will be appended below at run-time
diff --git a/mii/server.py b/mii/server.py
index 2a4f7a8e..dfd376bf 100644
--- a/mii/server.py
+++ b/mii/server.py
@@ -9,13 +9,75 @@
import tempfile
import time
from collections import defaultdict
+from typing import Optional, Any, Dict, Union, List
+
from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
-from mii.utils import get_num_gpus
+import mii
+from mii.client import MIIClient
+from mii.config import ModelConfig, MIIConfig, ReplicaConfig
+from mii.constants import DeploymentType
from mii.logging import logger
-
-
-def config_to_b64_str(config):
+from mii.score import create_score_file
+from mii.utils import import_score_file
+
+
+def serve(model_name_or_path: str = "",
+ model_config: Optional[Dict[str,
+ Any]] = None,
+ mii_config: Optional[Dict[str,
+ Any]] = None,
+ **kwargs) -> Union[None,
+ MIIClient]:
+ if model_config is None:
+ model_config = {}
+ if mii_config is None:
+ mii_config = {}
+ if model_name_or_path:
+ if "model_name_or_path" in model_config:
+ assert model_config.get("model_name_or_path") == model_name_or_path, "model_name_or_path in model_config must match model_name_or_path"
+ model_config["model_name_or_path"] = model_name_or_path
+ for key, val in kwargs.items():
+ if key in ModelConfig.__dict__["__fields__"]:
+ if key in model_config:
+ assert model_config.get(key) == val, f"{key} in model_config must match {key}"
+ model_config[key] = val
+ elif key in MIIConfig.__dict__["__fields__"]:
+ if key in mii_config:
+ assert mii_config.get(key) == val, f"{key} in mii_config must match {key}"
+ mii_config[key] = val
+ else:
+ raise ValueError(f"Invalid keyword argument {key}")
+ if "model_config" in mii_config:
+ assert mii_config.get("model_config") == model_config, "model_config in mii_config must match model_config"
+ mii_config["model_config"] = model_config
+ mii_config = MIIConfig(**mii_config)
+
+ #MIIServer(mii_config)
+ create_score_file(mii_config)
+
+ if mii_config.deployment_type == DeploymentType.LOCAL:
+ import_score_file(mii_config.deployment_name, DeploymentType.LOCAL).init()
+ return MIIClient(mii_config=mii_config)
+ if mii_config.deployment_type == DeploymentType.AML:
+ acr_name = mii.aml_related.utils.get_acr_name()
+ mii.aml_related.utils.generate_aml_scripts(
+ acr_name=acr_name,
+ deployment_name=mii_config.deployment_name,
+ model_name=mii_config.model_config.model,
+ task_name=mii_config.model_config.task,
+ replica_num=mii_config.model_config.replica_num,
+ instance_type=mii_config.instance_type,
+ version=mii_config.version,
+ )
+ print(
+ f"AML deployment assets at {mii.aml_related.utils.aml_output_path(mii_config.deployment_name)}"
+ )
+ print("Please run 'deploy.sh' to bring your deployment online")
+
+
+def config_to_b64_str(config: DeepSpeedConfigModel) -> str:
# convert json str -> bytes
json_bytes = config.json().encode()
# base64 encoded bytes
@@ -25,13 +87,10 @@ def config_to_b64_str(config):
class MIIServer:
- """Initialize the model, setup the server for the model under model_path"""
- def __init__(self, mii_config):
+ """Initialize the model, setup the server for the model"""
+ def __init__(self, mii_config: MIIConfig) -> None:
self.task = mii_config.model_config.task
- self.num_gpus = get_num_gpus(mii_config)
- assert self.num_gpus > 0, "GPU count must be greater than 0"
-
self.port_number = mii_config.port_number
if not os.path.isfile(mii_config.hostfile):
@@ -47,7 +106,9 @@ def __init__(self, mii_config):
self._wait_until_server_is_live(processes,
mii_config.model_config.replica_configs)
- def _wait_until_server_is_live(self, processes, deployment):
+ def _wait_until_server_is_live(self,
+ processes: List[subprocess.Popen],
+ deployment: List[ReplicaConfig]):
for process, repl_config in zip(processes, deployment):
sockets_open = False
while not sockets_open:
@@ -61,10 +122,14 @@ def _wait_until_server_is_live(self, processes, deployment):
"server crashed for some reason, unable to proceed")
time.sleep(4)
logger.info("waiting for server to start...")
+ # TODO: Fix displaying outputs from logger
+ # When we launch processes on multiple nodes using " --force_multi",
+ # all the output from the logger to stdout is displayed when the process is stopped.
+ # This is confusing because you see the message "server has started ..." when you stop the process.
logger.info(
f"server has started on ports {repl_config.tensor_parallel_ports}")
- def _is_socket_open(self, host, port):
+ def _is_socket_open(self, host: str, port: int) -> bool:
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -72,7 +137,7 @@ def _is_socket_open(self, host, port):
sock.close()
return result == 0
- def _is_server_process_alive(self, process):
+ def _is_server_process_alive(self, process: subprocess.Popen) -> bool:
if process is None:
return True
try:
@@ -86,10 +151,10 @@ def _is_server_process_alive(self, process):
return is_alive
def _launch_server_process(self,
- model_config,
- msg_server_type,
- ds_launch_str="",
- server_args=None):
+ model_config: ModelConfig,
+ msg_server_type: str,
+ ds_launch_str: str = "",
+ server_args: List[str] = None) -> subprocess.Popen:
launch_str = f"{sys.executable} -m mii.launch.multi_gpu_server"
b64_config_str = config_to_b64_str(model_config)
if server_args is None:
@@ -98,15 +163,17 @@ def _launch_server_process(self,
server_args_str = " ".join(server_args)
cmd = f"{ds_launch_str} {launch_str} {server_args_str}".strip().split(" ")
- mii_env = os.environ.copy()
- mii_env["TRANSFORMERS_CACHE"] = model_config.model_path
- logger.info(f"{msg_server_type} server launch: {cmd}")
- return subprocess.Popen(cmd, env=mii_env)
+ logger.info(f"msg_server launch: {cmd}")
+ return subprocess.Popen(cmd)
- def _generate_ds_launch_str(self, replica_config, hostfile):
+ def _generate_ds_launch_str(self,
+ replica_config: ReplicaConfig,
+ hostfile: str,
+ use_multiple_hosts: bool) -> str:
# use different hostfiles for replica instances
# pass /dev/null when no replica is used
- worker_str = f"-H {hostfile} "
+ #worker_str = f"-H {hostfile} "
+ worker_str = ""
# pin deepspeed launch to specific gpu id(s)
included_gpus = f"{replica_config.hostname}:{','.join(map(str, replica_config.gpu_indices))}"
worker_str += f"-i {included_gpus} "
@@ -116,10 +183,12 @@ def _generate_ds_launch_str(self, replica_config, hostfile):
worker_str += f"--master_port {replica_config.torch_dist_port}"
ds_launch_str = f"deepspeed {worker_str} --master_addr localhost --no_ssh_check --no_local_rank --no_python"
+ if use_multiple_hosts:
+ ds_launch_str += f" --force_multi"
return ds_launch_str
- def _initialize_service(self, mii_config):
+ def _initialize_service(self, mii_config: MIIConfig) -> List[subprocess.Popen]:
processes = []
server_args = [
f"--deployment-name {mii_config.deployment_name}",
@@ -131,26 +200,33 @@ def _initialize_service(self, mii_config):
for repl_config in mii_config.model_config.replica_configs:
host_gpus[repl_config.hostname].extend(repl_config.gpu_indices)
+ use_multiple_hosts = len(
+ set(repl_config.hostname
+ for repl_config in mii_config.model_config.replica_configs)) > 1
+
# Start replica instances
for repl_config in mii_config.model_config.replica_configs:
hostfile = tempfile.NamedTemporaryFile(delete=False)
hostfile.write(
f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n"
.encode())
- ds_launch_str = self._generate_ds_launch_str(repl_config, hostfile.name)
+ ds_launch_str = self._generate_ds_launch_str(repl_config,
+ hostfile.name,
+ use_multiple_hosts)
processes.append(
self._launch_server_process(
mii_config.model_config,
"MII server",
ds_launch_str=ds_launch_str,
- server_args=server_args +
- [f"--server-port {repl_config.tensor_parallel_ports[0]}"],
+ server_args=server_args + [
+ f"--server-port {repl_config.tensor_parallel_ports[0]} --zmq-port {repl_config.zmq_port}"
+ ],
))
- # start load balancer here. We don't use deepspeed launcher for the
- # load balancer because it does not need a GPU. The deepspeed
- # launcher determines the number of processes to launch based on
- # GPUs available on the host or CUDA_VISIBLE_DEVICES, and it is
- # expected to assign one GPU to one process.
+ # start load balancer here. We don't use deepspeed launcher for the
+ # load balancer because it does not need a GPU. The deepspeed
+ # launcher determines the number of processes to launch based on
+ # GPUs available on the host or CUDA_VISIBLE_DEVICES, and it is
+ # expected to assign one GPU to one process.
processes.append(
self._launch_server_process(
mii_config.model_config,
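The serve entry point above routes extra keyword arguments into ModelConfig or MIIConfig by field name. A usage sketch for a local persistent deployment; the model name is illustrative, while deployment_name, tensor_parallel, and replica_num are config fields referenced elsewhere in this diff. For DeploymentType.LOCAL the call returns an MIIClient handle connected to the deployment.

import mii

# routed into MIIConfig (deployment_name) and ModelConfig (tensor_parallel, replica_num)
client = mii.serve("mistralai/Mistral-7B-v0.1",
                   deployment_name="mistral-deployment",
                   tensor_parallel=2,
                   replica_num=1)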
diff --git a/mii/task_methods.py b/mii/task_methods.py
new file mode 100644
index 00000000..718284b9
--- /dev/null
+++ b/mii/task_methods.py
@@ -0,0 +1,114 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from abc import ABC, abstractmethod
+
+from mii.constants import TaskType
+from mii.grpc_related.proto import modelresponse_pb2
+from mii.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs
+
+
+def single_string_request_to_proto(self, request_dict, **query_kwargs):
+ return modelresponse_pb2.SingleStringRequest(
+ request=request_dict["query"],
+ query_kwargs=kwarg_dict_to_proto(query_kwargs))
+
+
+def single_string_response_to_proto(self, response, time_taken, model_time_taken):
+ return modelresponse_pb2.SingleStringReply(response=f"{response}",
+ time_taken=time_taken,
+ model_time_taken=model_time_taken)
+
+
+def multi_string_request_to_proto(self, request_dict, **query_kwargs):
+ return modelresponse_pb2.MultiStringRequest(
+ request=request_dict["query"] if isinstance(request_dict["query"],
+ list) else [request_dict["query"]],
+ query_kwargs=kwarg_dict_to_proto(query_kwargs),
+ )
+
+
+def proto_request_to_single_input(self, request):
+ args = (request.request, )
+ kwargs = unpack_proto_query_kwargs(request.query_kwargs)
+ return args, kwargs
+
+
+def proto_request_to_list(self, request):
+ args = ([r for r in request.request], )
+ kwargs = unpack_proto_query_kwargs(request.query_kwargs)
+ return args, kwargs
+
+
+class TaskMethods(ABC):
+ @property
+ @abstractmethod
+ def method(self):
+ ...
+
+ def pack_request_to_proto(self, request_dict, **query_kwargs):
+ return request_dict, query_kwargs
+
+ def unpack_request_from_proto(self, request):
+ return request
+
+ def pack_response_to_proto(self, response, time_taken, model_time_taken):
+ return response, time_taken, model_time_taken
+
+ def unpack_response_from_proto(self, response):
+ return response
+
+
+class TextGenerationMethods(TaskMethods):
+ session_context = {}
+
+ @property
+ def method(self):
+ return "GeneratorReply"
+
+ @property
+ def method_stream_out(self):
+ return "GeneratorReplyStream"
+
+ pack_request_to_proto = multi_string_request_to_proto
+ unpack_request_from_proto = proto_request_to_list
+
+ def create_session(self, session_id):
+ if session_id in self.session_context:
+ raise ValueError(f"session {session_id} already exists")
+ self.session_context[session_id] = None
+
+ def destroy_session(self, session_id):
+ if session_id not in self.session_context:
+ raise ValueError(f"session {session_id} does not exist")
+ del self.session_context[session_id]
+
+ def pack_response_to_proto(self, responses, time_taken, model_time_taken):
+ text_responses = []
+ details = []
+
+ # Responses are a nested list of dicts
+ # [Sample, 1, Dict]
+ for response in responses:
+ text = response.generated_text
+ text_responses.append(text)
+ details.append(
+ modelresponse_pb2.GenerationDetails(
+ finish_reason=str(response.finish_reason),
+ prompt_tokens=response.prompt_length,
+ generated_tokens=response.generated_length))
+
+ return modelresponse_pb2.GenerationReply(
+ response=text_responses,
+ indices=[0],
+ details=details,
+ time_taken=time_taken,
+ model_time_taken=model_time_taken,
+ )
+
+
+TASK_METHODS_DICT = {
+ TaskType.TEXT_GENERATION: TextGenerationMethods(),
+}
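A sketch of how the server side can use the table above to translate between protos and plain Python arguments, assuming the generated modelresponse_pb2 protos are available; the generation kwarg is illustrative.

from mii.constants import TaskType
from mii.task_methods import TASK_METHODS_DICT

methods = TASK_METHODS_DICT[TaskType.TEXT_GENERATION]
# pack a request dict plus query kwargs into a MultiStringRequest proto
proto_request = methods.pack_request_to_proto({"query": ["DeepSpeed is"]},
                                              max_new_tokens=64)
# unpack it back into positional args and kwargs on the server side
(prompts, ), kwargs = methods.unpack_request_from_proto(proto_request)
print(prompts, kwargs)  # ['DeepSpeed is'] {'max_new_tokens': 64}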
diff --git a/mii/tokenizers.py b/mii/tokenizers.py
new file mode 100644
index 00000000..527caec2
--- /dev/null
+++ b/mii/tokenizers.py
@@ -0,0 +1,70 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Union
+
+import torch
+from transformers import AutoTokenizer
+
+from mii.constants import ModelProvider
+
+if TYPE_CHECKING:
+ from mii.config import ModelConfig
+
+
+class MIITokenizerWrapper(ABC):
+ def __init__(self, tokenizer: object) -> None:
+ self.tokenizer = tokenizer
+
+ @property
+ @abstractmethod
+ def vocab_size(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def eos_token_id(self) -> int:
+ ...
+
+ @abstractmethod
+ def encode(self, input: str) -> torch.Tensor:
+ ...
+
+ @abstractmethod
+ def decode(self, tokens: torch.Tensor) -> str:
+ ...
+
+
+class HFTokenizer(MIITokenizerWrapper):
+ def __init__(self, tokenizer: Union[str, object]) -> None:
+ if isinstance(tokenizer, str):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+ tokenizer.pad_token = tokenizer.eos_token
+ super().__init__(tokenizer)
+
+ @property
+ def vocab_size(self) -> int:
+ return self.tokenizer.vocab_size
+
+ @property
+ def eos_token_id(self) -> int:
+ return self.tokenizer.eos_token_id
+
+ def encode(self, input: str) -> torch.Tensor:
+ return self.tokenizer.encode(input, return_tensors="pt").flatten()
+
+ def decode(self, tokens: torch.Tensor) -> str:
+ return self.tokenizer.decode(tokens)
+
+
+def load_tokenizer(model_config: "ModelConfig") -> MIITokenizerWrapper:
+ provider = model_config.provider
+ if provider == ModelProvider.HUGGING_FACE:
+ tokenizer = HFTokenizer(model_config.tokenizer)
+ else:
+ raise ValueError(f"Unknown model provider {provider}")
+
+ return tokenizer
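A small round-trip sketch of the Hugging Face wrapper defined above; the tokenizer name is illustrative and assumes transformers can resolve it.

from mii.tokenizers import HFTokenizer

tokenizer = HFTokenizer("gpt2")            # any HF tokenizer name or local path
tokens = tokenizer.encode("DeepSpeed is")  # 1-D torch.Tensor of token ids
text = tokenizer.decode(tokens)
print(tokens.shape, text)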
diff --git a/mii/utils.py b/mii/utils.py
index cc6db265..3958b09b 100644
--- a/mii/utils.py
+++ b/mii/utils.py
@@ -2,31 +2,42 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
-import os
+import importlib
import pickle
+import os
import time
-import importlib
-import torch
-import mii
-from types import SimpleNamespace
+import deepspeed
+
+from dataclasses import dataclass
+from typing import List, TYPE_CHECKING
+from datetime import timedelta
from huggingface_hub import HfApi
+from transformers import AutoConfig
-from mii.models.score.generate import generated_score_path
+import mii
from mii.constants import (
MII_CACHE_PATH,
MII_CACHE_PATH_DEFAULT,
- ModelProvider,
- SUPPORTED_MODEL_TYPES,
REQUIRED_KEYS_PER_TASK,
MII_HF_CACHE_EXPIRATION,
MII_HF_CACHE_EXPIRATION_DEFAULT,
)
+from mii.logging import logger
+from mii.score.generate import generated_score_path
+
+if TYPE_CHECKING:
+ from mii.config import ModelConfig
+
-from mii.config import TaskType
+@dataclass
+class ModelInfo:
+ modelId: str
+ pipeline_tag: str
+ tags: List[str]
-def _get_hf_models_by_type(model_type=None, task=None):
- cache_file_path = os.path.join(mii_cache_path(), "HF_model_cache.pkl")
+def _hf_model_list() -> List[ModelInfo]:
+ cache_file_path = os.path.join(mii_cache_path(), "MII_model_cache.pkl")
cache_expiration_seconds = os.getenv(MII_HF_CACHE_EXPIRATION,
MII_HF_CACHE_EXPIRATION_DEFAULT)
@@ -42,9 +53,9 @@ def _get_hf_models_by_type(model_type=None, task=None):
if (model_data["cache_time"] + cache_expiration_seconds) < current_time:
api = HfApi()
model_data["model_list"] = [
- SimpleNamespace(modelId=m.modelId,
- pipeline_tag=m.pipeline_tag,
- tags=m.tags) for m in api.list_models()
+ ModelInfo(modelId=m.modelId,
+ pipeline_tag=m.pipeline_tag,
+ tags=m.tags) for m in api.list_models()
]
model_data["cache_time"] = current_time
@@ -52,76 +63,37 @@ def _get_hf_models_by_type(model_type=None, task=None):
with open(cache_file_path, 'wb') as f:
pickle.dump(model_data, f)
- # Filter the model list
- models = model_data["model_list"]
- if model_type is not None:
- models = [m for m in models if model_type in m.tags]
- if task is not None:
- models = [m for m in models if m.pipeline_tag == task]
-
- # Extract model IDs
- model_ids = [m.modelId for m in models]
-
- if task == TaskType.TEXT_GENERATION:
- # TODO: this is a temp solution to get around some HF models not having the correct tags
- model_ids.extend([
- "microsoft/bloom-deepspeed-inference-fp16",
- "microsoft/bloom-deepspeed-inference-int8",
- "EleutherAI/gpt-neox-20b"
- ])
-
- return model_ids
-
-
-def get_supported_models(task):
- supported_models = []
-
- for model_type, provider in SUPPORTED_MODEL_TYPES.items():
- if provider == ModelProvider.HUGGING_FACE:
- models = _get_hf_models_by_type(model_type, task)
- elif provider == ModelProvider.ELEUTHER_AI:
- if task == TaskType.TEXT_GENERATION:
- models = [model_type]
- elif provider == ModelProvider.DIFFUSERS:
- models = _get_hf_models_by_type(model_type, task)
- supported_models.extend(models)
- if not supported_models:
- raise ValueError(f"Task {task} not supported")
-
- return supported_models
-
-
-def check_if_task_and_model_is_supported(task, model_name):
- supported_models = get_supported_models(task)
- assert (
- model_name in supported_models
- ), f"{task} is not supported by {model_name}. This task is supported by {len(supported_models)} other models. See which models with `mii.get_supported_models(mii.{task})`."
-
-
-def check_if_task_and_model_is_valid(task, model_name):
- valid_task_models = _get_hf_models_by_type(None, task)
- assert (
- model_name in valid_task_models
- ), f"{task} is not supported by {model_name}. This task is supported by {len(valid_task_models)} other models. See which models with `mii.get_supported_models(mii.{task})`."
-
-
-def full_model_path(model_path):
- aml_model_dir = os.environ.get('AZUREML_MODEL_DIR', None)
- if aml_model_dir:
- # (potentially) append relative model_path w. aml path
- assert os.path.isabs(aml_model_dir), f"AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path"
- if model_path:
- assert not os.path.isabs(model_path), f"model_path={model_path} must be relative to append w. AML path"
- return os.path.join(aml_model_dir, model_path)
- else:
- return aml_model_dir
- elif model_path:
- return model_path
+ return model_data["model_list"]
+
+
+def get_default_task(model_name_or_path: str) -> str:
+ model_name = get_model_name(model_name_or_path)
+ models = _hf_model_list()
+ for m in models:
+ if m.modelId == model_name:
+ task = m.pipeline_tag
+ logger.info(f"Detected default task as '{task}' for model '{model_name}'")
+ return task
+ else:
+ raise ValueError(f"Model {model_name} not found")
+
+
+def get_model_name(model_name_or_path: str) -> str:
+ model_name = None
+ if os.path.exists(model_name_or_path):
+ try:
+ model_name = AutoConfig.from_pretrained(model_name_or_path)._name_or_path
+ except Exception:
+ model_name = os.path.basename(model_name_or_path)
+ logger.warning(
+ f"Could not deduce model name from {model_name_or_path}. Trying with {model_name=} instead."
+ )
else:
- return mii.constants.MII_MODEL_PATH_DEFAULT
+ model_name = model_name_or_path
+ return model_name
-def is_aml():
+def is_aml() -> bool:
return os.getenv("AZUREML_MODEL_DIR") is not None
@@ -145,26 +117,41 @@ def import_score_file(deployment_name, deployment_type):
int: "ivalue",
float: "fvalue",
bool: "bvalue",
+ dict: "mvalue",
}
def kwarg_dict_to_proto(kwarg_dict):
def get_proto_value(value):
proto_value = mii.grpc_related.proto.modelresponse_pb2.Value()
- setattr(proto_value, dtype_proto_field[type(value)], value)
+
+ if isinstance(value, dict):
+ nested_dict = mii.grpc_related.proto.modelresponse_pb2.Dictionary()
+ for k, v in value.items():
+ nested_dict.values[k].CopyFrom(get_proto_value(v))
+ proto_value.mvalue.CopyFrom(nested_dict)
+ else:
+ setattr(proto_value, dtype_proto_field[type(value)], value)
+
return proto_value
return {k: get_proto_value(v) for k, v in kwarg_dict.items()}
def unpack_proto_query_kwargs(query_kwargs):
- query_kwargs = {
- k: getattr(v,
- v.WhichOneof("oneof_values"))
- for k,
- v in query_kwargs.items()
- }
- return query_kwargs
+ def extract_proto_value(proto_value):
+ field_name = proto_value.WhichOneof("oneof_values")
+
+ if field_name == "mvalue":
+ return {
+ k: extract_proto_value(v)
+ for k,
+ v in proto_value.mvalue.values.items()
+ }
+ else:
+ return getattr(proto_value, field_name)
+
+ return {k: extract_proto_value(v) for k, v in query_kwargs.items()}
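+
+# Hedged round-trip sketch (not part of this patch; assumes the generated
+# modelresponse_pb2 module with Value/Dictionary messages is importable, and the
+# "generate_kwargs" key is only an example):
+#
+#   proto_kwargs = kwarg_dict_to_proto({"max_length": 128,
+#                                       "do_sample": True,
+#                                       "generate_kwargs": {"early_stopping": True}})
+#   unpack_proto_query_kwargs(proto_kwargs)
+#   # -> {"max_length": 128, "do_sample": True, "generate_kwargs": {"early_stopping": True}}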
def extract_query_dict(task, request_dict):
@@ -178,20 +165,23 @@ def extract_query_dict(task, request_dict):
return query_dict
-def get_num_gpus(mii_config):
- num_gpus = mii_config.model_config.tensor_parallel
-
- assert (
- torch.cuda.device_count() >= num_gpus
- ), f"Available GPU count: {torch.cuda.device_count()} does not meet the required gpu count: {num_gpus}"
- return num_gpus
-
-
-def get_provider(model_name, task):
- if model_name == "gpt-neox":
- provider = ModelProvider.ELEUTHER_AI
- elif task == TaskType.TEXT2IMG:
- provider = ModelProvider.DIFFUSERS
+def generate_deployment_name(model_name_or_path: str):
+ if os.path.exists(model_name_or_path):
+ model_name = os.path.basename(model_name_or_path)
else:
- provider = ModelProvider.HUGGING_FACE
- return provider
+ model_name = model_name_or_path
+ return f"{model_name}-mii-deployment"
+
+
+def init_distributed(model_config: "ModelConfig"):
+ # If not running with a distributed launcher (e.g., deepspeed or torch.distributed), set default environment variables for a single-process run
+ required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ if not all([e in os.environ for e in required_env]):
+ assert model_config.tensor_parallel == 1, "Attempting to run with TP > 1 and not using a distributed launcher like deepspeed or torch.distributed"
+ os.environ["RANK"] = "0"
+ os.environ["LOCAL_RANK"] = "0"
+ os.environ["WORLD_SIZE"] = "1"
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(model_config.torch_dist_port)
+
+ deepspeed.init_distributed(dist_backend="nccl", timeout=timedelta(seconds=1e9))
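+
+# Hedged usage note (sketch, not part of this patch): when none of the launcher
+# environment variables are present, init_distributed() falls back to a
+# single-process setup, e.g.
+#
+#   init_distributed(model_config)   # requires model_config.tensor_parallel == 1
+#   # afterwards WORLD_SIZE == "1" and the NCCL rendezvous address is
+#   # localhost:<model_config.torch_dist_port>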
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 243d5aee..81473546 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,9 +1,12 @@
asyncio
-deepspeed>=0.7.6
+deepspeed>=0.12.0
+deepspeed-kernels
Flask-RESTful
grpcio
grpcio-tools
pydantic
torch
transformers
+ujson
Werkzeug
+zmq
diff --git a/setup.py b/setup.py
index cf8af9ab..415e6df7 100644
--- a/setup.py
+++ b/setup.py
@@ -75,7 +75,7 @@ def command_exists(cmd):
thisdir = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(thisdir, 'README.md'), encoding='utf-8') as fin:
readme_text = fin.read()
setup(name="deepspeed-mii",
version=version_str,
long_description=readme_text,
@@ -90,7 +90,8 @@ def command_exists(cmd):
},
install_requires=install_requires,
extras_require=extras_require,
- packages=find_packages(),
+ packages=find_packages(exclude=("tests",
+ )),
classifiers=[
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
diff --git a/tests/conftest.py b/tests/conftest.py
index 5af93d9d..a332f638 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,7 +5,7 @@
import pytest
import os
-import mii
+import mii.legacy as mii
from types import SimpleNamespace
diff --git a/tests/test_config.py b/tests/test_config.py
index 09d2b58e..bc2ca1fd 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -5,8 +5,8 @@
import pytest
-import mii
-from mii import pydantic_v1
+import mii.legacy as mii
+from mii.legacy import pydantic_v1
@pytest.mark.parametrize("port_number", [12345])
diff --git a/tests/test_deployment_options.py b/tests/test_deployment_options.py
index f0566ed7..e60ebcd7 100644
--- a/tests/test_deployment_options.py
+++ b/tests/test_deployment_options.py
@@ -6,8 +6,8 @@
import pytest
import json
import requests
-import mii
-from mii import pydantic_v1
+import mii.legacy as mii
+from mii.legacy import pydantic_v1
@pytest.mark.deepspeed
diff --git a/tests/test_local_deployment.py b/tests/test_local_deployment.py
index 5a575680..0dea3209 100644
--- a/tests/test_local_deployment.py
+++ b/tests/test_local_deployment.py
@@ -3,7 +3,7 @@
# DeepSpeed Team
import pytest
-import mii
+import mii.legacy as mii
@pytest.mark.parametrize(
diff --git a/tests/test_non_persistent_deployment.py b/tests/test_non_persistent_deployment.py
index 71234201..b2e1f1e3 100644
--- a/tests/test_non_persistent_deployment.py
+++ b/tests/test_non_persistent_deployment.py
@@ -4,7 +4,7 @@
# DeepSpeed Team
import pytest
-import mii
+import mii.legacy as mii
@pytest.mark.parametrize("deployment_type", [mii.DeploymentType.NON_PERSISTENT])
diff --git a/version.txt b/version.txt
index c5d54ec3..6e8bf73a 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.0.9
+0.1.0