
Commit

simplify py-tgi
IlyasMoutawwakil committed Feb 16, 2024
1 parent c8b999e commit d75eac9
Showing 9 changed files with 272 additions and 280 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/cpu_workflow.yaml
@@ -15,17 +15,17 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python 3.8
uses: actions/setup-python@v2
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: 3.8
python-version: '3.10'

- name: Install requirements
run: |
pip install --upgrade pip
pip install -e .
- name: Run sanity check
run: python sanity_test.py
run: python tests/sanity.py
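The sanity step now runs `tests/sanity.py`, whose contents are not part of this diff. A minimal, CPU-only smoke test along these lines (hypothetical sketch, import-and-check only, no Docker daemon required) would be enough to satisfy the workflow step:

```python
# tests/sanity.py -- hypothetical sketch; the real file is not shown in this diff.
# Verifies that the package installs and exposes the TGI wrapper without
# starting any container.
import py_tgi


def main() -> None:
    assert hasattr(py_tgi, "TGI"), "py_tgi should export the TGI class"
    print("sanity check passed")


if __name__ == "__main__":
    main()
```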
47 changes: 8 additions & 39 deletions README.md
@@ -1,11 +1,6 @@
# py-tgi

This repo contains two wrappers:

- `TGIServer`: a python wrapper around HuggingFace's TGI (text-generation-inference) using `docker-py`.
- `BatchedInferenceClient`: a python wrapper around HuggingFace's `InferenceClient` using threading to simulate batched inference.

Practical for running/managing TGI servers and benchmarking against other inference servers.
This repo provides a Python wrapper that uses `docker-py` to manage a TGI server, along with a client for single and batched inference.

## Installation

@@ -18,49 +13,23 @@ python -m pip install git+https://github.com/IlyasMoutawwakil/py-tgi.git
Running a TGI server with a batched inference client:

```python
from py_tgi import TGIServer, BatchedInferenceClient
# from logging import basicConfig, INFO
# basicConfig(level=INFO)
from py_tgi import TGI

tgi_server = TGIServer(model="gpt2", sharded=False)
llm = TGI(model="TheBloke/Mistral-7B-Instruct-v0.1-AWQ", quantization="awq")

try:
client = BatchedInferenceClient(url=tgi_server.url)
output = client.generate(["Hi, I'm an example 1", "Hi, I'm an example 2"])
output = llm.generate(["Hi, I'm an example 1", "Hi, I'm an example 2"])
print("Output:", output)
except Exception as e:
print(e)
finally:
tgi_server.close()
llm.close()
```

Output:

```bash
INFO:tgi-server: + Starting Docker client
INFO:tgi-server: + Checking if TGI image exists
INFO:tgi-server: + Building TGI command
INFO:tgi-server: + Checking if GPU is available
INFO:tgi-server: + Using GPU(s) from nvidia-smi
INFO:tgi-server: + Using GPU(s): 0,1,2,4
INFO:tgi-server: + Waiting for TGI server to be ready
INFO:tgi-server: 2024-01-15T08:47:15.882960Z INFO text_generation_launcher: Args { model_id: "gpt2", revision: Some("main"), validation_workers: 2, sharded: Some(false), num_shard: None, quantize: None, speculate: None, dtype: None, trust_remote_code: false, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_top_n_tokens: 5, max_input_length: 1024, max_total_tokens: 2048, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4096, max_batch_total_tokens: None, max_waiting_tokens: 20, hostname: "ec83247f21ab", port: 80, shard_uds_path: "/tmp/text-generation-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, rope_scaling: None, rope_factor: None, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, env: false }
INFO:tgi-server: 2024-01-15T08:47:15.883089Z INFO download: text_generation_launcher: Starting download process.
INFO:tgi-server: 2024-01-15T08:47:19.764449Z INFO text_generation_launcher: Files are already present on the host. Skipping download.
INFO:tgi-server: 2024-01-15T08:47:20.387759Z INFO download: text_generation_launcher: Successfully downloaded weights.
INFO:tgi-server: 2024-01-15T08:47:20.388064Z INFO shard-manager: text_generation_launcher: Starting shard rank=0
INFO:tgi-server: 2024-01-15T08:47:26.062519Z INFO text_generation_launcher: Server started at unix:///tmp/text-generation-server-0
INFO:tgi-server: 2024-01-15T08:47:26.095249Z INFO shard-manager: text_generation_launcher: Shard ready in 5.70626412s rank=0
INFO:tgi-server: 2024-01-15T08:47:26.193466Z INFO text_generation_launcher: Starting Webserver
INFO:tgi-server: 2024-01-15T08:47:26.204835Z INFO hf_hub: /usr/local/cargo/registry/src/index.crates.io-6f17d22bba15001f/hf-hub-0.3.2/src/lib.rs:55: Token file not found "/root/.cache/huggingface/token"
INFO:tgi-server: 2024-01-15T08:47:26.536395Z INFO text_generation_router: router/src/main.rs:368: Serving revision 11c5a3d5811f50298f278a704980280950aedb10 of model gpt2
INFO:tgi-server: 2024-01-15T08:47:26.593914Z INFO text_generation_router: router/src/main.rs:230: Warming up model
INFO:tgi-server: 2024-01-15T08:47:27.545238Z WARN text_generation_router: router/src/main.rs:244: Model does not support automatic max batch total tokens
INFO:tgi-server: 2024-01-15T08:47:27.545255Z INFO text_generation_router: router/src/main.rs:266: Setting max batch total tokens to 16000
INFO:tgi-server: + Couldn't connect to TGI server
INFO:tgi-server: + Retrying in 0.1s
INFO:tgi-server: + TGI server ready at http://127.0.0.1:1111
INFO:tgi-llm-client: + Creating InferenceClient
Output: [".0.0.0. I'm a programmer, I'm a programmer, I'm a", ".0.0. I'm a programmer, I'm a programmer, I'm a programmer,"]
INFO:tgi-server: + Stoping TGI container
INFO:tgi-server: + Waiting for TGI container to stop
INFO:tgi-server: + Closing docker client
Output: [".\n\nHi, I'm an example 2.\n\nHi, I'm", ".\n\nI'm a simple example of a class that has a method that returns a value"]
```
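The `TGI.generate` method introduced in this commit accepts either a single prompt or a list of prompts, and `__call__` simply forwards to `generate`, so the wrapper can also be used as sketched below (based on the `py_tgi/__init__.py` code in this commit; extra keyword arguments such as `max_new_tokens` are passed through to `InferenceClient.text_generation`):

```python
from py_tgi import TGI

llm = TGI(model="gpt2")  # small model, no quantization; runs on CPU if no GPU is detected

try:
    # a single string returns a single generation
    single = llm.generate("Hi, I'm an example", max_new_tokens=20)
    # calling the instance is equivalent to calling generate(); a list is fanned
    # out across a thread pool and the results are returned in order
    batch = llm(["Example 1", "Example 2"], max_new_tokens=20)
    print(single, batch)
finally:
    llm.close()
```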
12 changes: 7 additions & 5 deletions example.py
@@ -1,12 +1,14 @@
from py_tgi import TGIServer, BatchedInferenceClient
from logging import basicConfig, INFO
from py_tgi import TGI

tgi_server = TGIServer(model="gpt2", sharded=False)
basicConfig(level=INFO) # to stream tgi container logs

llm = TGI(model="TheBloke/Mistral-7B-Instruct-v0.1-AWQ", quantization="awq")

try:
client = BatchedInferenceClient(url=tgi_server.url)
output = client.generate(["Hi, I'm an example 1", "Hi, I'm an example 2"])
output = llm.generate(["Hi, I'm an example 1", "Hi, I'm an example 2"])
print("Output:", output)
except Exception as e:
print(e)
finally:
tgi_server.close()
llm.close()
233 changes: 231 additions & 2 deletions py_tgi/__init__.py
@@ -1,2 +1,231 @@
from .server import TGIServer
from .client import BatchedInferenceClient
import os
import time
import subprocess
from logging import getLogger
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Literal, List, Union
from contextlib import contextmanager
import signal


import docker
import docker.types
import docker.errors
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationResponse

LOGGER = getLogger("tgi")
HF_CACHE_DIR = f"{os.path.expanduser('~')}/.cache/huggingface/hub"

Quantization_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq"]
Torch_Dtype_Literal = Literal["float32", "float16", "bfloat16"]


class TGI:
def __init__(
self,
# model options
model: str,
revision: str = "main",
# image options
image: str = "ghcr.io/huggingface/text-generation-inference",
version: str = "latest",
# docker options
volume: str = HF_CACHE_DIR,
shm_size: str = "1g",
address: str = "127.0.0.1",
port: int = 1111,
# tgi launcher options
sharded: Optional[bool] = None,
num_shard: Optional[int] = None,
torch_dtype: Optional[Torch_Dtype_Literal] = None,
quantization: Optional[Quantization_Literal] = None,
trust_remote_code: Optional[bool] = False,
disable_custom_kernels: Optional[bool] = False,
) -> None:
# model options
self.model = model
self.revision = revision
# image options
self.image = image
self.version = version
# docker options
self.port = port
self.volume = volume
self.address = address
self.shm_size = shm_size
# tgi launcher options
self.sharded = sharded
self.num_shard = num_shard
self.torch_dtype = torch_dtype
self.quantization = quantization
self.trust_remote_code = trust_remote_code
self.disable_custom_kernels = disable_custom_kernels

LOGGER.info("\t+ Starting Docker client")
self.docker_client = docker.from_env()

try:
LOGGER.info("\t+ Checking if TGI image exists")
self.docker_client.images.get(f"{self.image}:{self.version}")
except docker.errors.ImageNotFound:
LOGGER.info(
"\t+ TGI image not found, downloading it (this may take a while)"
)
self.docker_client.images.pull(f"{self.image}:{self.version}")

LOGGER.info("\t+ Building TGI command")
self.command = ["--model-id", self.model, "--revision", self.revision]

if self.torch_dtype is not None:
self.command.extend(["--torch-dtype", self.torch_dtype])
if self.quantization is not None:
self.command.extend(["--quantize", self.quantization])
if self.sharded is not None:
self.command.extend(["--sharded", str(self.sharded).lower()])
if self.num_shard is not None:
self.command.extend(["--num-shard", str(self.num_shard)])
if self.trust_remote_code:
self.command.append("--trust-remote-code")
if self.disable_custom_kernels:
self.command.append("--disable-custom-kernels")

try:
LOGGER.info("\t+ Checking if GPU is available")
if os.environ.get("CUDA_VISIBLE_DEVICES") is not None:
LOGGER.info(
"\t+ `CUDA_VISIBLE_DEVICES` is set, using the specified GPU(s)"
)
device_ids = os.environ.get("CUDA_VISIBLE_DEVICES")
else:
LOGGER.info(
"\t+ `CUDA_VISIBLE_DEVICES` is not set, using nvidia-smi to detect GPU(s)"
)
device_ids = ",".join([str(device) for device in get_gpu_devices()])
LOGGER.info("\t+ Using GPU(s) from nvidia-smi")

LOGGER.info(f"\t+ Using GPU(s): {device_ids}")
self.device_requests = [
docker.types.DeviceRequest(
driver="nvidia",
device_ids=[str(device_ids)],
capabilities=[["gpu"]],
)
]
except Exception:
LOGGER.info("\t+ No GPU detected")
self.device_requests = None

        # Launch the TGI container, mounting the local HF cache at /data and
        # publishing the container's port 80 on (self.address, self.port).
        self.tgi_container = self.docker_client.containers.run(
image=f"{self.image}:{self.version}",
command=self.command,
shm_size=self.shm_size,
volumes={self.volume: {"bind": "/data", "mode": "rw"}},
ports={"80/tcp": (self.address, self.port)},
device_requests=self.device_requests,
detach=True,
)

LOGGER.info("\t+ Waiting for TGI server to be ready")
with timeout(60):
for line in self.tgi_container.logs(stream=True):
tgi_log = line.decode("utf-8").strip()
if "Connected" in tgi_log:
break
elif "Error" in tgi_log:
raise Exception(f"\t {tgi_log}")

LOGGER.info(f"\t {tgi_log}")

        LOGGER.info("\t+ Connecting to TGI server")
self.url = f"http://{self.address}:{self.port}"
with timeout(60):
while True:
try:
self.tgi_client = InferenceClient(model=self.url)
self.tgi_client.text_generation("Hello world!")
LOGGER.info(f"\t+ Connected to TGI server at {self.url}")
break
except Exception:
LOGGER.info("\t+ TGI server not ready, retrying in 1 second")
time.sleep(1)

def close(self) -> None:
if hasattr(self, "tgi_container"):
            LOGGER.info("\t+ Stopping TGI container")
self.tgi_container.stop()
LOGGER.info("\t+ Waiting for TGI container to stop")
self.tgi_container.wait()

if hasattr(self, "docker_client"):
LOGGER.info("\t+ Closing docker client")
self.docker_client.close()

def __call__(
self, prompt: Union[str, List[str]], **kwargs
) -> Union[TextGenerationResponse, List[TextGenerationResponse]]:
return self.generate(prompt, **kwargs)

def generate(
self, prompt: Union[str, List[str]], **kwargs
) -> Union[TextGenerationResponse, List[TextGenerationResponse]]:
if isinstance(prompt, str):
return self.tgi_client.text_generation(prompt=prompt, **kwargs)

elif isinstance(prompt, list):
with ThreadPoolExecutor(max_workers=len(prompt)) as executor:
futures = [
executor.submit(
self.tgi_client.text_generation, prompt=prompt[i], **kwargs
)
for i in range(len(prompt))
]

output = []
for i in range(len(prompt)):
output.append(futures[i].result())
return output


# Query nvidia-smi for the installed GPUs and return the integer ids of those
# whose name does not contain "Display" (i.e. skip display-only adapters).
def get_gpu_devices():
nvidia_smi = (
subprocess.check_output(
[
"nvidia-smi",
"--query-gpu=index,gpu_name,compute_cap",
"--format=csv",
],
)
.decode("utf-8")
.strip()
.split("\n")[1:]
)
device = [
{
"id": int(gpu.split(", ")[0]),
"name": gpu.split(", ")[1],
"compute_cap": gpu.split(", ")[2],
}
for gpu in nvidia_smi
]
device_ids = [gpu["id"] for gpu in device if "Display" not in gpu["name"]]

return device_ids


@contextmanager
def timeout(time: int):
"""
Timeout context manager. Raises TimeoutError if the code inside the context manager takes longer than `time` seconds to execute.
"""

def signal_handler(signum, frame):
raise TimeoutError("Timed out")

signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(time)

try:
yield
finally:
signal.alarm(0)
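One caveat about the `timeout` helper defined above: it is built on `signal.SIGALRM`, so it only works on Unix-like systems and only in the main thread of the main interpreter; the pending alarm is cleared in the `finally` block either way. A minimal usage sketch (importing the helper from `py_tgi`, where this commit defines it):

```python
import time

from py_tgi import timeout  # defined in py_tgi/__init__.py in this commit

try:
    with timeout(2):
        time.sleep(5)  # exceeds the 2-second budget, so SIGALRM fires
except TimeoutError:
    print("operation timed out as expected")
```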
36 changes: 0 additions & 36 deletions py_tgi/client.py

This file was deleted.
