Skip to content

Commit

Permalink
change logging
Browse files (browse the repository at this point in the history)
  • Loading branch information
IlyasMoutawwakil committed Mar 6, 2024
1 parent bfe39c7 commit 647e89a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 18 deletions.
16 changes: 8 additions & 8 deletions py_txi/docker_inference_server.py → py_txi/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
basicConfig(level=INFO)

DOCKER = docker.from_env()
LOGGER = getLogger("docker-inference-server")
LOGGER = getLogger("Inference-Server")


@dataclass
class DockerInferenceServerConfig:
class InferenceServerConfig:
# Image to use for the container
image: str
# Shared memory size for the container
Expand All @@ -44,21 +44,21 @@ class DockerInferenceServerConfig:
metadata={"help": "Dictionary of environment variables to forward to the container."},
)

max_concurrent_requests: Optional[int] = None
timeout: int = 60
max_concurrent_requests: int = 128

def __post_init__(self) -> None:
if self.ports["80/tcp"][1] == 0:
LOGGER.info("\t+ Getting a free port for the server")
self.ports["80/tcp"] = (self.ports["80/tcp"][0], get_free_port())


class DockerInferenceServer(ABC):
NAME: str = "Docker-Inference-Server"
class InferenceServer(ABC):
NAME: str = "Inference-Server"
SUCCESS_SENTINEL: str = "Success"
FAILURE_SENTINEL: str = "Failure"

def __init__(self, config: DockerInferenceServerConfig) -> None:
def __init__(self, config: InferenceServerConfig) -> None:
self.config = config

try:
Expand Down Expand Up @@ -89,13 +89,13 @@ def __init__(self, config: DockerInferenceServerConfig) -> None:
LOGGER.info(f"\t+ Building {self.NAME} command")
self.command = []
for k, v in asdict(self.config).items():
if k in DockerInferenceServerConfig.__annotations__:
if k in InferenceServerConfig.__annotations__:
continue
elif v is not None:
if isinstance(v, bool):
self.command.append(f"--{k.replace('_', '-')}")
else:
self.command.append(f"--{k.replace('_', '-')}={v}")
self.command.append(f"--{k.replace('_', '-')}={str(v).lower()}")

address, port = self.config.ports["80/tcp"]
self.url = f"http://{address}:{port}"
Expand Down
8 changes: 4 additions & 4 deletions py_txi/text_embedding_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@

import numpy as np

from .docker_inference_server import DockerInferenceServer, DockerInferenceServerConfig
from .inference_server import InferenceServer, InferenceServerConfig
from .utils import is_nvidia_system

LOGGER = getLogger("TEI")
LOGGER = getLogger("Text-Embedding-Inference")


Pooling_Literal = Literal["cls", "mean"]
DType_Literal = Literal["float32", "float16"]


@dataclass
class TEIConfig(DockerInferenceServerConfig):
class TEIConfig(InferenceServerConfig):
# Docker options
image: str = "ghcr.io/huggingface/text-embeddings-inference:cpu-latest"
# Launcher options
Expand All @@ -37,7 +37,7 @@ def __post_init__(self) -> None:
)


class TEI(DockerInferenceServer):
class TEI(InferenceServer):
NAME: str = "Text-Embedding-Inference"
SUCCESS_SENTINEL: str = "Ready"
FAILURE_SENTINEL: str = "Error"
Expand Down
8 changes: 4 additions & 4 deletions py_txi/text_generation_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from logging import getLogger
from typing import Literal, Optional, Union

from .docker_inference_server import DockerInferenceServer, DockerInferenceServerConfig
from .inference_server import InferenceServer, InferenceServerConfig
from .utils import is_rocm_system

LOGGER = getLogger("TGI")
LOGGER = getLogger("Text-Generation-Inference")

Shareded_Literal = Literal["true", "false"]
DType_Literal = Literal["float32", "float16", "bfloat16"]
Quantize_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq"]


@dataclass
class TGIConfig(DockerInferenceServerConfig):
class TGIConfig(InferenceServerConfig):
# Docker options
image: str = "ghcr.io/huggingface/text-generation-inference:latest"
# Launcher options
Expand All @@ -41,7 +41,7 @@ def __post_init__(self) -> None:
self.image += "-rocm"


class TGI(DockerInferenceServer):
class TGI(InferenceServer):
NAME: str = "Text-Generation-Inference"
SUCCESS_SENTINEL: str = "Connected"
FAILURE_SENTINEL: str = "Error"
Expand Down
3 changes: 1 addition & 2 deletions tests/test_txi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np

from py_txi.text_embedding_inference import TEI, TEIConfig
from py_txi.text_generation_inference import TGI, TGIConfig
from py_txi import TEI, TGI, TEIConfig, TGIConfig


def test_tei():
Expand Down

0 comments on commit 647e89a

Please sign in to comment.