From d75eac90b05483bf5e9ed30ee3262874ed66d189 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Fri, 16 Feb 2024 14:37:50 +0100
Subject: [PATCH] simplify py-tgi

---
 .github/workflows/cpu_workflow.yaml |  10 +-
 README.md                           |  47 +-----
 example.py                          |  12 +-
 py_tgi/__init__.py                  | 233 +++++++++++++++++++++++++++-
 py_tgi/client.py                    |  36 -----
 py_tgi/server.py                    | 176 ---------------------
 sanity_test.py                      |  15 --
 setup.py                            |   3 +-
 tests/sanity.py                     |  20 +++
 9 files changed, 272 insertions(+), 280 deletions(-)
 delete mode 100644 py_tgi/client.py
 delete mode 100644 py_tgi/server.py
 delete mode 100644 sanity_test.py
 create mode 100644 tests/sanity.py

diff --git a/.github/workflows/cpu_workflow.yaml b/.github/workflows/cpu_workflow.yaml
index 21a3e98..4c77339 100644
--- a/.github/workflows/cpu_workflow.yaml
+++ b/.github/workflows/cpu_workflow.yaml
@@ -15,12 +15,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout code
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
 
-      - name: Set up Python 3.8
-        uses: actions/setup-python@v2
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v3
         with:
-          python-version: 3.8
+          python-version: '3.10'
 
       - name: Install requirements
         run: |
@@ -28,4 +28,4 @@
         pip install -e .
 
       - name: Run sanity check
-        run: python sanity_test.py
+        run: python tests/sanity.py
diff --git a/README.md b/README.md
index 47e6717..163585d 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,6 @@
 # py-tgi
 
-This repo constains two wrappers:
-
-- `TGIServer`: a python wrapper around HuggingFace's TGI (text-generation-inference) using `docker-py`.
-- `BatchedInferenceClient`: a python wrapper around HuggingFace's `InferenceClient` using threading to simulate batched inference.
-
-Practical for running/managing TGI servers and benchmarking against other inference servers.
+A Python wrapper that uses `docker-py` to manage a TGI (text-generation-inference) server and provides a client for single and batched inference.
 
 ## Installation
 
@@ -18,49 +13,23 @@ python -m pip install git+https://github.com/IlyasMoutawwakil/py-tgi.git
 Running a TGI server with a batched inference client:
 
 ```python
-from py_tgi import TGIServer, BatchedInferenceClient
+# from logging import basicConfig, INFO
+# basicConfig(level=INFO)
+from py_tgi import TGI
 
-tgi_server = TGIServer(model="gpt2", sharded=False)
+llm = TGI(model="TheBloke/Mistral-7B-Instruct-v0.1-AWQ", quantization="awq")
 
 try:
-    client = BatchedInferenceClient(url=tgi_server.url)
-    output = client.generate(["Hi, I'm an example 1", "Hi, I'm an example 2"])
+    output = llm.generate(["Hi, I'm an example 1", "Hi, I'm an example 2"])
     print("Output:", output)
 except Exception as e:
     print(e)
 finally:
-    tgi_server.close()
+    llm.close()
 ```
 
 Output:
 
 ```bash
-INFO:tgi-server: + Starting Docker client
-INFO:tgi-server: + Checking if TGI image exists
-INFO:tgi-server: + Building TGI command
-INFO:tgi-server: + Checking if GPU is available
-INFO:tgi-server: + Using GPU(s) from nvidia-smi
-INFO:tgi-server: + Using GPU(s): 0,1,2,4
-INFO:tgi-server: + Waiting for TGI server to be ready
-INFO:tgi-server: 2024-01-15T08:47:15.882960Z INFO text_generation_launcher: Args { model_id: "gpt2", revision: Some("main"), validation_workers: 2, sharded: Some(false), num_shard: None, quantize: None, speculate: None, dtype: None, trust_remote_code: false, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_top_n_tokens: 5, max_input_length: 1024, max_total_tokens: 2048, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4096, max_batch_total_tokens: None, max_waiting_tokens: 20, hostname: "ec83247f21ab", port: 80, shard_uds_path: "/tmp/text-generation-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, rope_scaling: None, rope_factor: None, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, env: false }
-INFO:tgi-server: 2024-01-15T08:47:15.883089Z INFO download: text_generation_launcher: Starting download process.
-INFO:tgi-server: 2024-01-15T08:47:19.764449Z INFO text_generation_launcher: Files are already present on the host. Skipping download.
-INFO:tgi-server: 2024-01-15T08:47:20.387759Z INFO download: text_generation_launcher: Successfully downloaded weights.
-INFO:tgi-server: 2024-01-15T08:47:20.388064Z INFO shard-manager: text_generation_launcher: Starting shard rank=0
-INFO:tgi-server: 2024-01-15T08:47:26.062519Z INFO text_generation_launcher: Server started at unix:///tmp/text-generation-server-0
-INFO:tgi-server: 2024-01-15T08:47:26.095249Z INFO shard-manager: text_generation_launcher: Shard ready in 5.70626412s rank=0
-INFO:tgi-server: 2024-01-15T08:47:26.193466Z INFO text_generation_launcher: Starting Webserver
-INFO:tgi-server: 2024-01-15T08:47:26.204835Z INFO hf_hub: /usr/local/cargo/registry/src/index.crates.io-6f17d22bba15001f/hf-hub-0.3.2/src/lib.rs:55: Token file not found "/root/.cache/huggingface/token"
-INFO:tgi-server: 2024-01-15T08:47:26.536395Z INFO text_generation_router: router/src/main.rs:368: Serving revision 11c5a3d5811f50298f278a704980280950aedb10 of model gpt2
-INFO:tgi-server: 2024-01-15T08:47:26.593914Z INFO text_generation_router: router/src/main.rs:230: Warming up model
-INFO:tgi-server: 2024-01-15T08:47:27.545238Z WARN text_generation_router: router/src/main.rs:244: Model does not support automatic max batch total tokens
-INFO:tgi-server: 2024-01-15T08:47:27.545255Z INFO text_generation_router: router/src/main.rs:266: Setting max batch total tokens to 16000
-INFO:tgi-server: + Couldn't connect to TGI server
-INFO:tgi-server: + Retrying in 0.1s
-INFO:tgi-server: + TGI server ready at http://127.0.0.1:1111
-INFO:tgi-llm-client: + Creating InferenceClient
-Output: [".0.0.0. I'm a programmer, I'm a programmer, I'm a", ".0.0. I'm a programmer, I'm a programmer, I'm a programmer,"]
-INFO:tgi-server: + Stoping TGI container
-INFO:tgi-server: + Waiting for TGI container to stop
-INFO:tgi-server: + Closing docker client
+Output: [".\n\nHi, I'm an example 2.\n\nHi, I'm", ".\n\nI'm a simple example of a class that has a method that returns a value"]
 ```
diff --git a/example.py b/example.py
index c5f9d23..38a5019 100644
--- a/example.py
+++ b/example.py
@@ -1,12 +1,14 @@
-from py_tgi import TGIServer, BatchedInferenceClient
+from logging import basicConfig, INFO
+from py_tgi import TGI
 
-tgi_server = TGIServer(model="gpt2", sharded=False)
+basicConfig(level=INFO)  # to stream tgi container logs
+
+llm = TGI(model="TheBloke/Mistral-7B-Instruct-v0.1-AWQ", quantization="awq")
 
 try:
-    client = BatchedInferenceClient(url=tgi_server.url)
-    output = client.generate(["Hi, I'm an example 1", "Hi, I'm an example 2"])
+    output = llm.generate(["Hi, I'm an example 1", "Hi, I'm an example 2"])
     print("Output:", output)
 except Exception as e:
     print(e)
 finally:
-    tgi_server.close()
\ No newline at end of file
+    llm.close()
diff --git a/py_tgi/__init__.py b/py_tgi/__init__.py
index f349135..500a958 100644
--- a/py_tgi/__init__.py
+++ b/py_tgi/__init__.py
@@ -1,2 +1,231 @@
-from .server import TGIServer
-from .client import BatchedInferenceClient
\ No newline at end of file
+import os
+import time
+import subprocess
+from logging import getLogger
+from concurrent.futures import ThreadPoolExecutor
+from typing import Optional, Literal, List, Union
+from contextlib import contextmanager
+import signal
+
+
+import docker
+import docker.types
+import docker.errors
+from huggingface_hub import InferenceClient
+from huggingface_hub.inference._text_generation import TextGenerationResponse
+
+LOGGER = getLogger("tgi")
+HF_CACHE_DIR = f"{os.path.expanduser('~')}/.cache/huggingface/hub"
+
+Quantization_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq"]
+Torch_Dtype_Literal = Literal["float32", "float16", "bfloat16"]
+
+
+class TGI:
+    def __init__(
+        self,
+        # model options
+        model: str,
+        revision: str = "main",
+        # image options
+        image: str = "ghcr.io/huggingface/text-generation-inference",
+        version: str = "latest",
+        # docker options
+        volume: str = HF_CACHE_DIR,
+        shm_size: str = "1g",
+        address: str = "127.0.0.1",
+        port: int = 1111,
+        # tgi launcher options
+        sharded: Optional[bool] = None,
+        num_shard: Optional[int] = None,
+        torch_dtype: Optional[Torch_Dtype_Literal] = None,
+        quantization: Optional[Quantization_Literal] = None,
+        trust_remote_code: Optional[bool] = False,
+        disable_custom_kernels: Optional[bool] = False,
+    ) -> None:
+        # model options
+        self.model = model
+        self.revision = revision
+        # image options
+        self.image = image
+        self.version = version
+        # docker options
+        self.port = port
+        self.volume = volume
+        self.address = address
+        self.shm_size = shm_size
+        # tgi launcher options
+        self.sharded = sharded
+        self.num_shard = num_shard
+        self.torch_dtype = torch_dtype
+        self.quantization = quantization
+        self.trust_remote_code = trust_remote_code
+        self.disable_custom_kernels = disable_custom_kernels
+
+        LOGGER.info("\t+ Starting Docker client")
+        self.docker_client = docker.from_env()
+
+        try:
+            LOGGER.info("\t+ Checking if TGI image exists")
+            self.docker_client.images.get(f"{self.image}:{self.version}")
+        except docker.errors.ImageNotFound:
+            LOGGER.info(
+                "\t+ TGI image not found, downloading it (this may take a while)"
+            )
+            self.docker_client.images.pull(f"{self.image}:{self.version}")
+
+        LOGGER.info("\t+ Building TGI command")
+        self.command = ["--model-id", self.model, "--revision", self.revision]
+
+        if self.torch_dtype is not None:
+            self.command.extend(["--torch-dtype", self.torch_dtype])
+        if self.quantization is not None:
+            self.command.extend(["--quantize", self.quantization])
+        if self.sharded is not None:
+            self.command.extend(["--sharded", str(self.sharded).lower()])
+        if self.num_shard is not None:
+            self.command.extend(["--num-shard", str(self.num_shard)])
+        if self.trust_remote_code:
+            self.command.append("--trust-remote-code")
+        if self.disable_custom_kernels:
+            self.command.append("--disable-custom-kernels")
+
+        try:
+            LOGGER.info("\t+ Checking if GPU is available")
+            if os.environ.get("CUDA_VISIBLE_DEVICES") is not None:
+                LOGGER.info(
+                    "\t+ `CUDA_VISIBLE_DEVICES` is set, using the specified GPU(s)"
+                )
+                device_ids = os.environ.get("CUDA_VISIBLE_DEVICES")
+            else:
+                LOGGER.info(
+                    "\t+ `CUDA_VISIBLE_DEVICES` is not set, using nvidia-smi to detect GPU(s)"
+                )
+                device_ids = ",".join([str(device) for device in get_gpu_devices()])
+                LOGGER.info("\t+ Using GPU(s) from nvidia-smi")
+
+            LOGGER.info(f"\t+ Using GPU(s): {device_ids}")
+            self.device_requests = [
+                docker.types.DeviceRequest(
+                    driver="nvidia",
+                    device_ids=[str(device_ids)],
+                    capabilities=[["gpu"]],
+                )
+            ]
+        except Exception:
+            LOGGER.info("\t+ No GPU detected")
+            self.device_requests = None
+
+        self.tgi_container = self.docker_client.containers.run(
+            image=f"{self.image}:{self.version}",
+            command=self.command,
+            shm_size=self.shm_size,
+            volumes={self.volume: {"bind": "/data", "mode": "rw"}},
+            ports={"80/tcp": (self.address, self.port)},
+            device_requests=self.device_requests,
+            detach=True,
+        )
+
+        LOGGER.info("\t+ Waiting for TGI server to be ready")
+        with timeout(60):
+            for line in self.tgi_container.logs(stream=True):
+                tgi_log = line.decode("utf-8").strip()
+                if "Connected" in tgi_log:
+                    break
+                elif "Error" in tgi_log:
+                    raise Exception(f"\t {tgi_log}")
+
+                LOGGER.info(f"\t {tgi_log}")
+
+        LOGGER.info("\t+ Connecting to TGI server")
+        self.url = f"http://{self.address}:{self.port}"
+        with timeout(60):
+            while True:
+                try:
+                    self.tgi_client = InferenceClient(model=self.url)
+                    self.tgi_client.text_generation("Hello world!")
+                    LOGGER.info(f"\t+ Connected to TGI server at {self.url}")
+                    break
+                except Exception:
+                    LOGGER.info("\t+ TGI server not ready, retrying in 1 second")
+                    time.sleep(1)
+
+    def close(self) -> None:
+        if hasattr(self, "tgi_container"):
+            LOGGER.info("\t+ Stopping TGI container")
+            self.tgi_container.stop()
+            LOGGER.info("\t+ Waiting for TGI container to stop")
+            self.tgi_container.wait()
+
+        if hasattr(self, "docker_client"):
+            LOGGER.info("\t+ Closing docker client")
+            self.docker_client.close()
+
+    def __call__(
+        self, prompt: Union[str, List[str]], **kwargs
+    ) -> Union[TextGenerationResponse, List[TextGenerationResponse]]:
+        return self.generate(prompt, **kwargs)
+
+    def generate(
+        self, prompt: Union[str, List[str]], **kwargs
+    ) -> Union[TextGenerationResponse, List[TextGenerationResponse]]:
+        if isinstance(prompt, str):
+            return self.tgi_client.text_generation(prompt=prompt, **kwargs)
+
+        elif isinstance(prompt, list):
+            with ThreadPoolExecutor(max_workers=len(prompt)) as executor:
+                futures = [
+                    executor.submit(
+                        self.tgi_client.text_generation, prompt=prompt[i], **kwargs
+                    )
+                    for i in range(len(prompt))
+                ]
+
+            output = []
+            for i in range(len(prompt)):
+                output.append(futures[i].result())
+            return output
+
+
+def get_gpu_devices():
+    nvidia_smi = (
+        subprocess.check_output(
+            [
+                "nvidia-smi",
+                "--query-gpu=index,gpu_name,compute_cap",
+                "--format=csv",
+            ],
+        )
+        .decode("utf-8")
+        .strip()
+        .split("\n")[1:]
+    )
+    device = [
+        {
+            "id": int(gpu.split(", ")[0]),
+            "name": gpu.split(", ")[1],
+            "compute_cap": gpu.split(", ")[2],
+        }
+        for gpu in nvidia_smi
+    ]
+    device_ids = [gpu["id"] for gpu in device if "Display" not in gpu["name"]]
+
+    return device_ids
+
+
+@contextmanager
+def timeout(time: int):
+    """
+    Timeout context manager. Raises TimeoutError if the code inside the context manager takes longer than `time` seconds to execute.
+ """ + + def signal_handler(signum, frame): + raise TimeoutError("Timed out") + + signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(time) + + try: + yield + finally: + signal.alarm(0) diff --git a/py_tgi/client.py b/py_tgi/client.py deleted file mode 100644 index b0e560d..0000000 --- a/py_tgi/client.py +++ /dev/null @@ -1,36 +0,0 @@ -from huggingface_hub.inference._text_generation import TextGenerationResponse -from concurrent.futures import ThreadPoolExecutor -from huggingface_hub import InferenceClient -from typing import Any, Dict, List, Union -from logging import getLogger - - -LOGGER = getLogger("tgi-llm-client") - -ClientOutput = Union[TextGenerationResponse, List[TextGenerationResponse]] - - -class BatchedInferenceClient: - def __init__(self, url: str) -> None: - LOGGER.info("\t+ Creating InferenceClient") - self.tgi_client = InferenceClient(model=url) - - def generate( - self, prompt: Union[str, List[str]], **kwargs: Dict[str, Any] - ) -> ClientOutput: - if isinstance(prompt, str): - return self.tgi_client.text_generation(prompt=prompt, **kwargs) - - elif isinstance(prompt, list): - with ThreadPoolExecutor(max_workers=len(prompt)) as executor: - futures = [ - executor.submit( - self.tgi_client.text_generation, prompt=prompt[i], **kwargs - ) - for i in range(len(prompt)) - ] - - output = [] - for i in range(len(prompt)): - output.append(futures[i].result()) - return output diff --git a/py_tgi/server.py b/py_tgi/server.py deleted file mode 100644 index 0a989d7..0000000 --- a/py_tgi/server.py +++ /dev/null @@ -1,176 +0,0 @@ -import os -import time -import subprocess -from logging import getLogger, basicConfig, INFO -from typing import Optional - -import docker -import docker.errors -import docker.types - -from huggingface_hub import InferenceClient - -basicConfig(level=INFO) -LOGGER = getLogger("tgi-server") -HF_CACHE_DIR = f"{os.path.expanduser('~')}/.cache/huggingface/hub" - - -class TGIServer: - def __init__( - self, - model: str, - revision: str = "main", - version: str = "latest", - image: str = "ghcr.io/huggingface/text-generation-inference", - volume: str = HF_CACHE_DIR, - shm_size: str = "1g", - address: str = "127.0.0.1", - port: int = 1111, - trust_remote_code: bool = False, - disable_custom_kernels: bool = False, - sharded: Optional[bool] = None, - num_shard: Optional[int] = None, - torch_dtype: Optional[str] = None, # float32, float16, bfloat16 - quantization: Optional[str] = None, # bitsandbytes-nf4, bitsandbytes-fp4, gptq - ) -> None: - # model options - self.model = model - self.revision = revision - # image options - self.image = image - self.version = version - # docker options - self.port = port - self.volume = volume - self.address = address - self.shm_size = shm_size - # tgi launcher options - self.sharded = sharded - self.num_shard = num_shard - self.torch_dtype = torch_dtype - self.quantization = quantization - self.trust_remote_code = trust_remote_code - self.disable_custom_kernels = disable_custom_kernels - self.url = f"http://{self.address}:{self.port}" - - LOGGER.info("\t+ Starting Docker client") - self.docker_client = docker.from_env() - - try: - LOGGER.info("\t+ Checking if TGI image exists") - self.docker_client.images.get(f"{self.image}:{self.version}") - except docker.errors.ImageNotFound: - LOGGER.info("\t+ TGI image not found, pulling it") - self.docker_client.images.pull(f"{self.image}:{self.version}") - - LOGGER.info("\t+ Building TGI command") - self.command = [ - "--model-id", - self.model, - "--revision", - self.revision, - ] - - if 
self.torch_dtype is not None: - self.command.extend(["--torch-dtype", self.torch_dtype]) - if self.quantization is not None: - self.command.extend(["--quantize", self.quantization]) - if self.sharded is not None: - self.command.extend(["--sharded", str(self.sharded).lower()]) - if self.num_shard is not None: - self.command.extend(["--num-shard", str(self.num_shard)]) - if self.trust_remote_code: - self.command.append("--trust-remote-code") - if self.disable_custom_kernels: - self.command.append("--disable-custom-kernels") - - try: - LOGGER.info("\t+ Checking if GPU is available") - if os.environ.get("CUDA_VISIBLE_DEVICES") is not None: - LOGGER.info("\t+ Using GPU(s) from CUDA_VISIBLE_DEVICES") - device_ids = os.environ.get("CUDA_VISIBLE_DEVICES") - else: - device_ids = ",".join([str(device) for device in get_gpu_devices()]) - LOGGER.info("\t+ Using GPU(s) from nvidia-smi") - - LOGGER.info(f"\t+ Using GPU(s): {device_ids}") - self.device_requests = [ - docker.types.DeviceRequest( - driver="nvidia", - device_ids=[str(device_ids)], - capabilities=[["gpu"]], - ) - ] - except Exception: - LOGGER.info("\t+ No GPU detected") - self.device_requests = None - - self.tgi_container = self.docker_client.containers.run( - image=f"{self.image}:{self.version}", - command=self.command, - shm_size=self.shm_size, - volumes={self.volume: {"bind": "/data", "mode": "rw"}}, - ports={"80/tcp": (self.address, self.port)}, - device_requests=self.device_requests, - detach=True, - ) - - LOGGER.info("\t+ Waiting for TGI server to be ready") - for line in self.tgi_container.logs(stream=True): - tgi_log = line.decode("utf-8").strip() - if not tgi_log: - continue - elif "Connected" in tgi_log: - break - else: - LOGGER.info(f"\t {tgi_log}") - - while True: - try: - dummy_client = InferenceClient(model=self.url) - dummy_client.text_generation("Hello world!") - del dummy_client - break - except Exception: - LOGGER.info("\t+ Couldn't connect to TGI server") - LOGGER.info("\t+ Retrying in 0.1s") - time.sleep(0.1) - - LOGGER.info(f"\t+ TGI server ready at {self.url}") - - def close(self) -> None: - if hasattr(self, "tgi_container"): - LOGGER.info("\t+ Stoping TGI container") - self.tgi_container.stop() - LOGGER.info("\t+ Waiting for TGI container to stop") - self.tgi_container.wait() - - if hasattr(self, "docker_client"): - LOGGER.info("\t+ Closing docker client") - self.docker_client.close() - - -def get_gpu_devices(): - nvidia_smi = ( - subprocess.check_output( - [ - "nvidia-smi", - "--query-gpu=index,gpu_name,compute_cap", - "--format=csv", - ], - ) - .decode("utf-8") - .strip() - .split("\n")[1:] - ) - device = [ - { - "id": int(gpu.split(", ")[0]), - "name": gpu.split(", ")[1], - "compute_cap": gpu.split(", ")[2], - } - for gpu in nvidia_smi - ] - device_ids = [gpu["id"] for gpu in device if "Display" not in gpu["name"]] - - return device_ids diff --git a/sanity_test.py b/sanity_test.py deleted file mode 100644 index 8efdd6d..0000000 --- a/sanity_test.py +++ /dev/null @@ -1,15 +0,0 @@ -from py_tgi import TGIServer, BatchedInferenceClient - -tgi_server = TGIServer("gpt2", sharded=False) - -try: - client = BatchedInferenceClient(url=tgi_server.url) - output = client.generate("Hi, I'm a sanity test") - print("Output:", output) - tgi_server.close() - assert isinstance(output, str) - -# catch Exception and InterruptedError -except (Exception, InterruptedError, KeyboardInterrupt) as e: - tgi_server.close() - raise e diff --git a/setup.py b/setup.py index 839345e..4cd328c 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ 
version="0.1", packages=find_packages(), install_requires=[ - "docker", - "huggingface-hub", + "docker" "huggingface-hub", ], ) diff --git a/tests/sanity.py b/tests/sanity.py new file mode 100644 index 0000000..8d5be12 --- /dev/null +++ b/tests/sanity.py @@ -0,0 +1,20 @@ +from logging import basicConfig, INFO +from py_tgi import TGI + +basicConfig(level=INFO) + +llm = TGI("gpt2", sharded=False) + +try: + output = llm.generate("Hi, I'm a sanity test") + assert isinstance(output, str) + + output = llm.generate(["Hi, I'm a sanity test", "I'm a second sentence"]) + assert isinstance(output, list) + + llm.close() + +# catch Exception and InterruptedError +except (Exception, InterruptedError, KeyboardInterrupt) as e: + llm.close() + raise e