cpu fixes and updates
IlyasMoutawwakil committed May 10, 2024
1 parent 65ae0d7 commit 6f496cf
Showing 5 changed files with 22 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
```diff
@@ -1,4 +1,4 @@
-# Py-TXI (previously Py-TGI)
+# Py-TXI
 
 [![PyPI version](https://badge.fury.io/py/py-txi.svg)](https://badge.fury.io/py/py-txi)
 [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/py-txi)](https://pypi.org/project/py-txi/)
```
20 changes: 13 additions & 7 deletions py_txi/inference_server.py
```diff
@@ -12,6 +12,7 @@
 import docker.errors
 import docker.types
 from huggingface_hub import AsyncInferenceClient
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 
 from .utils import get_free_port, styled_logs
@@ -23,7 +24,7 @@
 @dataclass
 class InferenceServerConfig:
     # Common options
-    model_id: str
+    model_id: Optional[str] = None
     revision: Optional[str] = "main"
     # Image to use for the container
     image: Optional[str] = None
@@ -39,7 +40,7 @@ class InferenceServerConfig:
         metadata={"help": "Dictionary of ports to expose from the container."},
     )
     volumes: Dict[str, Any] = field(
-        default_factory=lambda: {os.path.expanduser("~/.cache/huggingface/hub"): {"bind": "/data", "mode": "rw"}},
+        default_factory=lambda: {HUGGINGFACE_HUB_CACHE: {"bind": "/data", "mode": "rw"}},
         metadata={"help": "Dictionary of volumes to mount inside the container."},
    )
     environment: List[str] = field(
```
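The new default resolves the host-side cache path from `huggingface_hub` instead of hard-coding `~/.cache/huggingface/hub`, so a relocated cache (e.g. via `HF_HOME`) still gets mounted into the container. A minimal sketch of what the default volume mapping evaluates to, in docker-py volume syntax:

```python
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

# HUGGINGFACE_HUB_CACHE defaults to ~/.cache/huggingface/hub but follows
# HF_HOME (and related cache overrides) when set, unlike the old hard-coded path.
volumes = {HUGGINGFACE_HUB_CACHE: {"bind": "/data", "mode": "rw"}}
print(volumes)
# e.g. {'/home/user/.cache/huggingface/hub': {'bind': '/data', 'mode': 'rw'}}
```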
```diff
@@ -94,16 +95,18 @@ def __init__(self, config: InferenceServerConfig) -> None:
             self.device_requests = None
 
         LOGGER.info(f"\t+ Building {self.NAME} command")
-        self.command = ["--model-id", self.config.model_id]
+        self.command = []
 
+        if self.config.model_id is not None:
+            self.command = ["--model-id", self.config.model_id]
         if self.config.revision is not None:
             self.command.extend(["--revision", self.config.revision])
 
         for k, v in asdict(self.config).items():
             if k in InferenceServerConfig.__annotations__:
                 continue
             elif v is not None:
-                if isinstance(v, bool):
+                if isinstance(v, bool) and not k == "sharded":
                     self.command.append(f"--{k.replace('_', '-')}")
                 else:
                     self.command.append(f"--{k.replace('_', '-')}={str(v).lower()}")
```
```diff
@@ -179,16 +182,19 @@ async def batch_client_call(self, *args, **kwargs) -> Any:
     def close(self) -> None:
         if hasattr(self, "container"):
             LOGGER.info("\t+ Stoping Docker container")
-            self.container.stop()
-            self.container.wait()
+            if self.container.status == "running":
+                self.container.stop()
+                self.container.wait()
             LOGGER.info("\t+ Docker container stopped")
             del self.container
 
         if hasattr(self, "semaphore"):
-            self.semaphore
+            if self.semaphore.locked():
+                self.semaphore.release()
             del self.semaphore
 
         if hasattr(self, "client"):
-            self.client
             del self.client
 
     def __del__(self) -> None:
```
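Guarding on `container.status` means `close()` no longer tries to stop a container that has already exited. A standalone sketch of the same pattern with the Docker SDK (the container name is hypothetical):

```python
import docker

client = docker.from_env()
container = client.containers.get("my-inference-container")  # hypothetical name

# `status` is a snapshot from the last API call, so refresh it before checking
container.reload()
if container.status == "running":
    container.stop()
    container.wait()  # block until the container has actually exited
```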
6 changes: 2 additions & 4 deletions py_txi/text_generation_inference.py
```diff
@@ -38,10 +38,8 @@ def __post_init__(self) -> None:
             LOGGER.info("\t+ Using the latest ROCm AMD GPU image for Text-Generation-Inference")
             self.image = "ghcr.io/huggingface/text-generation-inference:latest-rocm"
         else:
-            raise ValueError(
-                "Unsupported system. Please either provide the image to use explicitly "
-                "or use a supported system (NVIDIA/ROCm) while specifying gpus/devices."
-            )
+            LOGGER.info("\t+ Using the version 1.4 since it's the last image supporting CPU")
+            self.image = "ghcr.io/huggingface/text-generation-inference:1.4"
 
         if is_rocm_system() and "rocm" not in self.image:
             LOGGER.warning("\t+ You are running on a ROCm AMD GPU system but using a non-ROCM image.")
```
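On a host with neither NVIDIA nor ROCm GPUs, configuration now falls back to the last TGI image that shipped CPU support instead of raising `ValueError`. A sketch of the expected behavior, assuming the package's public `TGIConfig` import:

```python
from py_txi import TGIConfig

# On a CPU-only machine, __post_init__ should now select the 1.4 image
# rather than raising ValueError.
config = TGIConfig(model_id="bigscience/bloom-560m")
print(config.image)
# expected on CPU-only systems: ghcr.io/huggingface/text-generation-inference:1.4
```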
5 changes: 4 additions & 1 deletion py_txi/utils.py
```diff
@@ -44,7 +44,10 @@ def color_text(text: str, color: str) -> str:
 
 
 def styled_logs(log: str) -> str:
-    dict_log = loads(log)
+    try:
+        dict_log = loads(log)
+    except Exception:
+        return log
 
     fields = dict_log.get("fields", {})
     level = dict_log.get("level", "could not parse level")
```
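The `try`/`except` makes `styled_logs` tolerant of container output that is not JSON: such lines are now returned unchanged instead of crashing the log reader. A small usage sketch (the exact styling of the parsed line is implementation-defined):

```python
from py_txi.utils import styled_logs

# A structured JSON log line from the server is parsed and styled ...
print(styled_logs('{"level": "INFO", "fields": {"message": "Connected"}}'))

# ... while a plain, non-JSON line now passes through as-is instead of raising
print(styled_logs("some plain non-JSON log line"))
```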
5 changes: 2 additions & 3 deletions tests/test_txi.py
```diff
@@ -12,9 +12,8 @@ def test_cpu_tei():
     embed.close()
 
 
-# tested locally with gpu
-def test_gpu_tgi():
-    llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m", gpus="0"))
+def test_cpu_tgi():
+    llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m"))
     output = llm.generate("Hi, I'm a sanity test")
     assert isinstance(output, str)
     output = llm.generate(["Hi, I'm a sanity test", "I'm a second sentence"])
```
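The GPU-only test becomes a CPU test, which is what lets the suite run on a plain CI machine now that a CPU image exists as a fallback. The same flow as a standalone script, assuming the package's public imports:

```python
from py_txi import TGI, TGIConfig

# No `gpus` argument: on a CPU-only host this uses the 1.4 CPU image
llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m"))
try:
    print(llm.generate("Hi, I'm a sanity test"))
    print(llm.generate(["Hi, I'm a sanity test", "I'm a second sentence"]))
finally:
    llm.close()  # stop and clean up the Docker container
```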
