Support rocm benchmarking with text generation inference backend (#132)
IlyasMoutawwakil authored Feb 22, 2024
1 parent 580251a commit fc5412b
Showing 7 changed files with 157 additions and 198 deletions.
204 changes: 59 additions & 145 deletions optimum_benchmark/backends/text_generation_inference/backend.py
@@ -1,20 +1,16 @@
import gc
import os
import time
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from tempfile import TemporaryDirectory
from typing import Any, Dict, List

import torch
from huggingface_hub import InferenceClient, snapshot_download
from huggingface_hub.inference._text_generation import TextGenerationResponse
from huggingface_hub import snapshot_download
from py_tgi import TGI
from safetensors.torch import save_model
from transformers import logging as transformers_logging

import docker
import docker.errors
import docker.types

from ...system_utils import is_nvidia_system, is_rocm_system
from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
from ..transformers_utils import randomize_weights
@@ -23,6 +19,9 @@
# backend logger
LOGGER = getLogger("text-generation-inference")

# disable other loggers
transformers_logging.set_verbosity_error()


class TGIBackend(Backend[TGIConfig]):
NAME: str = "text-generation-inference"
@@ -31,6 +30,21 @@ def __init__(self, config: TGIConfig) -> None:
super().__init__(config)
self.validate_task()

if self.config.device == "cuda" and is_nvidia_system():
self.devices = None
self.gpus = self.config.device_ids
LOGGER.info(f"\t+ CUDA devices: {self.gpus}")
elif self.config.device == "cuda" and is_rocm_system():
self.gpus = None
device_ids = list(map(int, self.config.device_ids.split(",")))
renderDs = [file for file in os.listdir("/dev/dri") if file.startswith("renderD")]
self.devices = ["/dev/kfd"] + [f"/dev/dri/{renderDs[i]}" for i in device_ids]
LOGGER.info(f"\t+ ROCm devices: {self.devices}")
else:
self.gpus = None
self.devices = None
LOGGER.info("\t+ CPU device")

LOGGER.info("\t+ Creating backend temporary directory")
self.tmp_dir = TemporaryDirectory()

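For reference, a hedged, standalone sketch of the ROCm device mapping introduced above, assuming a ROCm host that exposes /dev/kfd and /dev/dri/renderD* nodes (the sample device_ids string is hypothetical):

import os

device_ids = [int(i) for i in "0,1".split(",")]  # hypothetical config.device_ids == "0,1"
renderDs = [file for file in os.listdir("/dev/dri") if file.startswith("renderD")]
devices = ["/dev/kfd"] + [f"/dev/dri/{renderDs[i]}" for i in device_ids]
print(devices)  # e.g. ['/dev/kfd', '/dev/dri/renderD128', '/dev/dri/renderD129']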
@@ -44,33 +58,32 @@ def validate_task(self) -> None:
if self.config.task not in TEXT_GENERATION_TASKS:
raise NotImplementedError(f"TGI does not support task {self.config.task}")

LOGGER.info(f"Using AutoModel class {self.automodel_class.__name__}")

def download_pretrained_model(self) -> None:
LOGGER.info("\t+ Downloading pretrained model")
snapshot_download(self.config.model, **self.config.hub_kwargs)

def load_model_from_pretrained(self) -> None:
def prepare_pretrained_model(self) -> None:
LOGGER.info("\t+ Modifying pretrained generation config")
self.pretrained_generation_config.eos_token_id = -100
self.pretrained_generation_config.pad_token_id = -101
self.generation_config.eos_token_id = -100
self.generation_config.pad_token_id = -101

LOGGER.info("\t+ Saving new pretrained generation config")
model_cache_folder = f"models/{self.config.model}".replace("/", "--")
model_cache_path = f"{self.config.volume}/{model_cache_folder}"

snapshot_ref = (
open(f"{model_cache_path}/refs/{self.config.hub_kwargs.get('revision', 'main')}", "r").read().strip()
)
snapshot_file = f"{model_cache_path}/refs/{self.config.hub_kwargs.get('revision', 'main')}"
snapshot_ref = open(snapshot_file, "r").read().strip()

model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
self.pretrained_generation_config.save_pretrained(save_directory=model_snapshot_path)
self.generation_config.save_pretrained(save_directory=model_snapshot_path)

def load_model_from_pretrained(self) -> None:
self.prepare_pretrained_model()
self.start_tgi_server()

def create_no_weights_model(self) -> None:
LOGGER.info("\t+ Creating no weights model directory")
self.no_weights_model = os.path.join(self.tmp_dir.name, "no_weights")
self.no_weights_model = os.path.join(self.config.volume, "no_weights_model")
os.makedirs(self.no_weights_model, exist_ok=True)

LOGGER.info("\t+ Saving pretrained config")
@@ -102,99 +115,36 @@ def create_no_weights_model(self) -> None:
self.delete_pretrained_model()

LOGGER.info("\t+ Saving generation config")
self.pretrained_generation_config.eos_token_id = -100
self.pretrained_generation_config.pad_token_id = -101
self.pretrained_generation_config.save_pretrained(save_directory=self.no_weights_model)
self.generation_config.eos_token_id = -100
self.generation_config.pad_token_id = -101
self.generation_config.save_pretrained(save_directory=self.no_weights_model)

def load_model_with_no_weights(self) -> None:
self.create_no_weights_model()
original_model = self.config.model
self.config.model = self.no_weights_model
self.config.model = "data/no_weights_model"
self.start_tgi_server()
self.config.model = original_model

def start_tgi_server(self) -> None:
LOGGER.info("\t+ Starting Python Docker client")
self.docker_client = docker.from_env()

try:
LOGGER.info("\t+ Checking if TGI image exists")
self.docker_client.images.get(self.config.image)
except docker.errors.ImageNotFound:
LOGGER.info("\t+ TGI image not found, pulling it")
self.docker_client.images.pull(self.config.image)

env = {}
if os.environ.get("HUGGING_FACE_HUB_TOKEN", None) is not None:
env["HUGGING_FACE_HUB_TOKEN"] = os.environ["HUGGING_FACE_HUB_TOKEN"]

LOGGER.info("\t+ Building TGI command")
self.command = ["--model-id", self.config.model, "--revision", self.config.hub_kwargs.get("revision", "main")]

if self.config.sharded is not None:
self.command.extend(["--sharded", str(self.config.sharded).lower()])
if self.config.num_shard is not None:
self.command.extend(["--num-shard", str(self.config.num_shard)])
if self.config.quantization_scheme is not None:
self.command.extend(["--quantize", self.config.quantization_scheme])
if self.config.torch_dtype is not None:
self.command.extend(["--dtype", self.config.torch_dtype])

if self.config.hub_kwargs.get("trust_remote_code", False):
self.command.append("--trust-remote-code")
if self.config.disable_custom_kernels:
self.command.append("--disable-custom-kernels")

if self.config.device == "cuda":
device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
LOGGER.info(f"\t+ Starting TGI container on CUDA device(s): {device_ids}")
device_requests = [docker.types.DeviceRequest(device_ids=[device_ids], capabilities=[["gpu"]])]
else:
LOGGER.info("\t+ Starting TGI container on CPU device")
device_requests = None

if self.config.no_weights:
self.volumes = {self.tmp_dir.name: {"bind": self.tmp_dir.name, "mode": "rw"}}
else:
self.volumes = {self.config.volume: {"bind": "/data", "mode": "rw"}}

ports = {"80/tcp": (self.config.address, self.config.port)}

self.tgi_container = self.docker_client.containers.run(
device_requests=device_requests,
command=self.command,
volumes=self.volumes,
shm_size=self.config.shm_size,
self.pretrained_model = TGI(
model=self.config.model,
dtype=self.config.dtype,
image=self.config.image,
environment=env,
ports=ports,
detach=True,
quantize=self.config.quantize,
port=self.config.port,
volume=self.config.volume,
address=self.config.address,
shm_size=self.config.shm_size,
gpus=self.gpus,
devices=self.devices,
sharded=self.config.sharded,
num_shard=self.config.num_shard,
disable_custom_kernels=self.config.disable_custom_kernels,
revision=self.config.hub_kwargs.get("revision", "main"),
trust_remote_code=self.config.hub_kwargs.get("trust_remote_code", False),
)

LOGGER.info("\t+ Waiting for TGI server to be ready")
for line in self.tgi_container.logs(stream=True):
tgi_log = line.decode("utf-8").strip()
if not tgi_log:
continue
elif "Connected" in tgi_log:
LOGGER.info("\t+ TGI server is ready")
break
else:
LOGGER.info(f"\t {tgi_log}")

LOGGER.info("\t+ Creating InferenceClient")
self.client = InferenceClient(model=f"http://{self.config.address}:{self.config.port}")

while True:
try:
LOGGER.info("\t+ Checking if TGI client is ready")
self.client.text_generation(prompt="test", max_new_tokens=1)
LOGGER.info("\t+ TGI client is ready")
break
except Exception as e:
LOGGER.info(f"\t+ TGI client is not ready yet: {e}")
time.sleep(0.5)

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if "input_ids" in inputs:
return {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}
@@ -203,55 +153,19 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
else:
raise ValueError("inputs must contain either input_ids or inputs")

def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[TextGenerationResponse]:
output = []
with ThreadPoolExecutor(max_workers=len(inputs["prompt"])) as executor:
futures = [
executor.submit(
self.client.text_generation,
decoder_input_details=True,
prompt=inputs["prompt"][i],
max_new_tokens=1,
details=True,
)
for i in range(len(inputs["prompt"]))
]
for future in futures:
output.append(future.result())

return output

def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[TextGenerationResponse]:
output = []
with ThreadPoolExecutor(max_workers=len(inputs["prompt"])) as executor:
futures = [
executor.submit(
self.client.text_generation,
max_new_tokens=kwargs["max_new_tokens"],
do_sample=kwargs["do_sample"],
prompt=inputs["prompt"][i],
details=True,
)
for i in range(len(inputs["prompt"]))
]
for i in range(len(inputs["prompt"])):
output.append(futures[i].result())

return output
def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
return self.pretrained_model.generate(**inputs, **kwargs, max_new_tokens=1)

def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
return self.pretrained_model.generate(
**inputs,
do_sample=kwargs.get("do_sample", False),
max_new_tokens=kwargs.get("max_new_tokens", 1),
)

def clean(self) -> None:
super().clean()

if hasattr(self, "tgi_container"):
LOGGER.info("\t+ Stopping TGI container")
self.tgi_container.stop()
LOGGER.info("\t+ Waiting for TGI container to stop")
self.tgi_container.wait()

if hasattr(self, "docker_client"):
LOGGER.info("\t+ Closing docker client")
self.docker_client.close()

if hasattr(self, "tmp_dir"):
LOGGER.info("\t+ Cleaning temporary directory")
self.tmp_dir.cleanup()
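Taken together, a minimal usage sketch of the reworked backend (hedged: the model, task, device and device_ids fields are inherited from BackendConfig and not shown in this diff, and calling load_model_from_pretrained directly is an assumption about the benchmark flow):

from optimum_benchmark.backends.text_generation_inference.backend import TGIBackend
from optimum_benchmark.backends.text_generation_inference.config import TGIConfig

# On a ROCm host the device stays "cuda"; __init__ maps device_ids to /dev/dri renderD* nodes.
config = TGIConfig(model="gpt2", task="text-generation", device="cuda", device_ids="0", dtype="float16")
backend = TGIBackend(config)

backend.download_pretrained_model()   # snapshot_download of the model
backend.load_model_from_pretrained()  # patches the generation config and starts the TGI server
outputs = backend.generate({"prompt": ["Hello"]}, {"max_new_tokens": 16, "do_sample": False})
backend.clean()                       # removes the backend's temporary directory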
24 changes: 14 additions & 10 deletions optimum_benchmark/backends/text_generation_inference/config.py
@@ -2,13 +2,14 @@
from dataclasses import dataclass
from typing import Optional

from ...import_utils import py_tgi_version
from ..config import BackendConfig


@dataclass
class TGIConfig(BackendConfig):
name: str = "text-generation-inference"
version: Optional[str] = "0.0.1"
version: Optional[str] = py_tgi_version()
_target_: str = "optimum_benchmark.backends.text_generation_inference.backend.TGIBackend"

# optimum benchmark specific
@@ -21,19 +22,22 @@ class TGIConfig(BackendConfig):
shm_size: str = "1g"
port: int = 1111

# torch options
torch_dtype: Optional[str] = None # None, float32, float16, bfloat16
# optimization options
disable_custom_kernels: bool = False # True, False
# quantization options
quantization_scheme: Optional[str] = None # None, bitsandbytes-nf4, bitsandbytes-fp4
# sharding options
sharded: Optional[bool] = None # None, True, False
num_shard: Optional[int] = None # None, 1, 2, 4, 8, 16, 32, 64
# torch options
dtype: Optional[str] = None # None, float32, float16, bfloat16
quantize: Optional[str] = None # None, bitsandbytes-nf4, bitsandbytes-fp4
# optimization options
disable_custom_kernels: bool = False # True, False

def __post_init__(self):
super().__post_init__()

if self.torch_dtype is not None:
if self.torch_dtype not in ["float32", "float16", "bfloat16"]:
raise ValueError(f"Invalid value for dtype: {self.torch_dtype}")
if self.dtype is not None:
if self.dtype not in ["float32", "float16", "bfloat16"]:
raise ValueError(f"Invalid value for dtype: {self.dtype}")

if self.quantize is not None:
if self.quantize not in ["bitsandbytes-nf4", "bitsandbytes-fp4", "awq", "gptq"]:
raise ValueError(f"Invalid value for quantize: {self.quantize}")
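As a quick, hedged illustration of the new validation (inherited BackendConfig fields are omitted here and may be required in practice):

from optimum_benchmark.backends.text_generation_inference.config import TGIConfig

TGIConfig(dtype="bfloat16", quantize="awq")   # accepted values
# TGIConfig(dtype="int8")    -> ValueError: Invalid value for dtype: int8
# TGIConfig(quantize="eetq") -> ValueError: Invalid value for quantize: eetq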
1 change: 0 additions & 1 deletion optimum_benchmark/benchmarks/inference/benchmark.py
@@ -97,7 +97,6 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
LOGGER.info("\t+ Generating Image Diffusion inputs")
self.call_inputs = self.input_generator()
self.call_inputs = backend.prepare_inputs(self.call_inputs)
self.call_inputs = {"prompt": self.call_inputs["prompt"]}
LOGGER.info("\t+ Updating Image Diffusion kwargs with default values")
self.config.call_kwargs = {**IMAGE_DIFFUSION_KWARGS, **self.config.call_kwargs}
LOGGER.info("\t+ Initializing Image Diffusion report")
5 changes: 4 additions & 1 deletion optimum_benchmark/benchmarks/inference/inputs_utils.py
@@ -1,5 +1,8 @@
def extract_text_generation_inputs(inputs):
if "pixel_values" in inputs:
if "prompt" in inputs:
# text input
text_generation_inputs = {"prompt": inputs["prompt"]}
elif "pixel_values" in inputs:
# image input
text_generation_inputs = {"pixel_values": inputs["pixel_values"]}
elif "input_values" in inputs:
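A hedged usage example of the updated extraction, showing the new text branch (the sample inputs dict is hypothetical):

from optimum_benchmark.benchmarks.inference.inputs_utils import extract_text_generation_inputs

inputs = {"prompt": ["Hello world"], "attention_mask": [[1, 1]]}
extract_text_generation_inputs(inputs)  # -> {"prompt": ["Hello world"]}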
15 changes: 15 additions & 0 deletions optimum_benchmark/import_utils.py
@@ -27,6 +27,16 @@
_tensorrt_llm_available = importlib.util.find_spec("tensorrt_llm") is not None
_psutil_available = importlib.util.find_spec("psutil") is not None
_optimum_benchmark_available = importlib.util.find_spec("optimum_benchmark") is not None
_py_tgi_available = importlib.util.find_spec("py_tgi") is not None
_pyrsmi_available = importlib.util.find_spec("pyrsmi") is not None


def is_pyrsmi_available():
return _pyrsmi_available


def is_py_tgi_available():
return _py_tgi_available


def is_psutil_available():
@@ -183,6 +193,11 @@ def optimum_benchmark_version():
return importlib.metadata.version("optimum_benchmark")


def py_tgi_version():
if _py_tgi_available:
return importlib.metadata.version("py_tgi")


def get_git_revision_hash(package_name: str) -> Optional[str]:
"""
Returns the git commit SHA of a package installed from a git repository.
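A hedged example of using the new helpers together (py_tgi_version() implicitly returns None when py-tgi is not installed):

from optimum_benchmark.import_utils import is_py_tgi_available, py_tgi_version

if is_py_tgi_available():
    print(f"py-tgi version: {py_tgi_version()}")
else:
    print("py-tgi is not installed; the text-generation-inference backend cannot be used.")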