diff --git a/README.md b/README.md
index cee9c34..885c834 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# Py-TXI (previously Py-TGI)
+# Py-TXI
 
 [![PyPI version](https://badge.fury.io/py/py-txi.svg)](https://badge.fury.io/py/py-txi)
 [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/py-txi)](https://pypi.org/project/py-txi/)
diff --git a/py_txi/inference_server.py b/py_txi/inference_server.py
index 0b3ab4a..c954654 100644
--- a/py_txi/inference_server.py
+++ b/py_txi/inference_server.py
@@ -12,6 +12,7 @@
 import docker.errors
 import docker.types
 from huggingface_hub import AsyncInferenceClient
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 
 from .utils import get_free_port, styled_logs
 
@@ -23,7 +24,7 @@
 @dataclass
 class InferenceServerConfig:
     # Common options
-    model_id: str
+    model_id: Optional[str] = None
     revision: Optional[str] = "main"
     # Image to use for the container
     image: Optional[str] = None
@@ -39,7 +40,7 @@ class InferenceServerConfig:
         metadata={"help": "Dictionary of ports to expose from the container."},
     )
     volumes: Dict[str, Any] = field(
-        default_factory=lambda: {os.path.expanduser("~/.cache/huggingface/hub"): {"bind": "/data", "mode": "rw"}},
+        default_factory=lambda: {HUGGINGFACE_HUB_CACHE: {"bind": "/data", "mode": "rw"}},
         metadata={"help": "Dictionary of volumes to mount inside the container."},
     )
     environment: List[str] = field(
@@ -94,8 +95,10 @@ def __init__(self, config: InferenceServerConfig) -> None:
             self.device_requests = None
 
         LOGGER.info(f"\t+ Building {self.NAME} command")
-        self.command = ["--model-id", self.config.model_id]
+        self.command = []
+        if self.config.model_id is not None:
+            self.command = ["--model-id", self.config.model_id]
 
         if self.config.revision is not None:
             self.command.extend(["--revision", self.config.revision])
 
@@ -103,7 +106,7 @@ def __init__(self, config: InferenceServerConfig) -> None:
             if k in InferenceServerConfig.__annotations__:
                 continue
             elif v is not None:
-                if isinstance(v, bool):
+                if isinstance(v, bool) and not k == "sharded":
                     self.command.append(f"--{k.replace('_', '-')}")
                 else:
                     self.command.append(f"--{k.replace('_', '-')}={str(v).lower()}")
@@ -179,16 +182,19 @@ async def batch_client_call(self, *args, **kwargs) -> Any:
     def close(self) -> None:
         if hasattr(self, "container"):
             LOGGER.info("\t+ Stoping Docker container")
-            self.container.stop()
-            self.container.wait()
+            if self.container.status == "running":
+                self.container.stop()
+                self.container.wait()
             LOGGER.info("\t+ Docker container stopped")
             del self.container
 
         if hasattr(self, "semaphore"):
-            self.semaphore
+            if self.semaphore.locked():
+                self.semaphore.release()
             del self.semaphore
 
         if hasattr(self, "client"):
+            self.client
             del self.client
 
     def __del__(self) -> None:
diff --git a/py_txi/text_generation_inference.py b/py_txi/text_generation_inference.py
index 498e0e2..c05da6a 100644
--- a/py_txi/text_generation_inference.py
+++ b/py_txi/text_generation_inference.py
@@ -38,10 +38,8 @@ def __post_init__(self) -> None:
             LOGGER.info("\t+ Using the latest ROCm AMD GPU image for Text-Generation-Inference")
             self.image = "ghcr.io/huggingface/text-generation-inference:latest-rocm"
         else:
-            raise ValueError(
-                "Unsupported system. Please either provide the image to use explicitly "
-                "or use a supported system (NVIDIA/ROCm) while specifying gpus/devices."
-            )
+            LOGGER.info("\t+ Using the version 1.4 since it's the last image supporting CPU")
+            self.image = "ghcr.io/huggingface/text-generation-inference:1.4"
 
         if is_rocm_system() and "rocm" not in self.image:
             LOGGER.warning("\t+ You are running on a ROCm AMD GPU system but using a non-ROCM image.")
diff --git a/py_txi/utils.py b/py_txi/utils.py
index 96f2f18..0dcbaeb 100644
--- a/py_txi/utils.py
+++ b/py_txi/utils.py
@@ -44,7 +44,10 @@ def color_text(text: str, color: str) -> str:
 
 
 def styled_logs(log: str) -> str:
-    dict_log = loads(log)
+    try:
+        dict_log = loads(log)
+    except Exception:
+        return log
 
     fields = dict_log.get("fields", {})
     level = dict_log.get("level", "could not parse level")
diff --git a/tests/test_txi.py b/tests/test_txi.py
index a7ebba6..b94ef48 100644
--- a/tests/test_txi.py
+++ b/tests/test_txi.py
@@ -12,9 +12,8 @@ def test_cpu_tei():
     embed.close()
 
 
-# tested locally with gpu
-def test_gpu_tgi():
-    llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m", gpus="0"))
+def test_cpu_tgi():
+    llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m"))
     output = llm.generate("Hi, I'm a sanity test")
     assert isinstance(output, str)
     output = llm.generate(["Hi, I'm a sanity test", "I'm a second sentence"])
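For context, here is a minimal sketch of the CPU path these changes enable, mirroring the updated test. It assumes `TGI` and `TGIConfig` are importable from the package root, as the tests suggest; this is an illustration, not part of the diff.

```python
# Sketch of the new CPU fallback path (assumed import path: package root).
from py_txi import TGI, TGIConfig

# With no image and no gpus/devices configured, __post_init__ now falls back
# to the 1.4 CPU image instead of raising ValueError.
llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m"))

output = llm.generate("Hi, I'm a sanity test")
print(output)

# close() now checks the container status and semaphore state before tearing down.
llm.close()
```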