From d44fb9d4418573a77001a10dfe1206218eff53f3 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Wed, 6 Mar 2024 01:11:11 +0100
Subject: [PATCH] added semaphores and event loop management

---
 example.py                          | 10 +++++---
 py_txi/docker_inference_server.py   | 12 +++++++++
 py_txi/text_embedding_inference.py  | 10 +++++---
 py_txi/text_generation_inference.py | 39 +++++++++----------------------
 4 files changed, 33 insertions(+), 38 deletions(-)

diff --git a/example.py b/example.py
index 20b6c43..0073fc3 100644
--- a/example.py
+++ b/example.py
@@ -2,11 +2,13 @@
 from py_txi.text_generation_inference import TGI, TGIConfig
 
 embed = TEI(config=TEIConfig(pooling="cls"))
-output = embed.encode(["Hi, I'm an embedding model", "I'm fine, how are you?"])
-print("Embed:", output)
+output = embed.encode(["Hi, I'm an embedding model", "I'm fine, how are you?"] * 100)
+print(len(output))
+print("Embed:", output[0])
 embed.close()
 
 llm = TGI(config=TGIConfig(sharded="false"))
-output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"])
-print("LLM:", output)
+output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"] * 50)
+print(len(output))
+print("LLM:", output[0])
 llm.close()
diff --git a/py_txi/docker_inference_server.py b/py_txi/docker_inference_server.py
index a233b4e..9587c12 100644
--- a/py_txi/docker_inference_server.py
+++ b/py_txi/docker_inference_server.py
@@ -45,6 +45,7 @@ class DockerInferenceServerConfig:
     )
 
     timeout: int = 60
+    max_concurrent_requests: int = 128
 
     def __post_init__(self) -> None:
         if self.ports["80/tcp"][1] == 0:
@@ -125,6 +126,13 @@ def __init__(self, config: DockerInferenceServerConfig) -> None:
             else:
                 LOGGER.info(f"\t {log}")
 
+        try:
+            asyncio.set_event_loop(asyncio.get_event_loop())
+        except RuntimeError:
+            asyncio.set_event_loop(asyncio.new_event_loop())
+
+        self.semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)
+
         LOGGER.info(f"\t+ Waiting for {self.NAME} server to be ready")
         start_time = time.time()
         while time.time() - start_time < self.config.timeout:
@@ -153,6 +161,10 @@ def close(self) -> None:
             LOGGER.info("\t+ Docker container stopped")
             del self.container
 
+        if hasattr(self, "semaphore"):
+            self.semaphore
+            del self.semaphore
+
         if hasattr(self, "client"):
             del self.client
 
diff --git a/py_txi/text_embedding_inference.py b/py_txi/text_embedding_inference.py
index 86ba5c0..075cc09 100644
--- a/py_txi/text_embedding_inference.py
+++ b/py_txi/text_embedding_inference.py
@@ -15,7 +15,7 @@
 DType_Literal = Literal["float32", "float16"]
 
 
-@dataclass(order=False)
+@dataclass
 class TEIConfig(DockerInferenceServerConfig):
     # Docker options
     image: str = "ghcr.io/huggingface/text-embeddings-inference:cpu-latest"
@@ -24,7 +24,8 @@ class TEIConfig(DockerInferenceServerConfig):
     revision: str = "main"
     dtype: Optional[DType_Literal] = None
    pooling: Optional[Pooling_Literal] = None
-    tokenization_workers: Optional[int] = None
+    # Concurrency options
+    max_concurrent_requests: int = 512
 
     def __post_init__(self) -> None:
         super().__post_init__()
@@ -45,8 +46,9 @@ def __init__(self, config: TEIConfig) -> None:
         super().__init__(config)
 
     async def single_client_call(self, text: str, **kwargs) -> np.ndarray:
-        output = await self.client.feature_extraction(text=text, **kwargs)
-        return output
+        async with self.semaphore:
+            output = await self.client.feature_extraction(text=text, **kwargs)
+            return output
 
     async def batch_client_call(self, text: List[str], **kwargs) -> List[np.ndarray]:
         output = await asyncio.gather(*[self.single_client_call(t, **kwargs) for t in text])
diff --git a/py_txi/text_generation_inference.py b/py_txi/text_generation_inference.py
index f297bd0..b3990e1 100644
--- a/py_txi/text_generation_inference.py
+++ b/py_txi/text_generation_inference.py
@@ -20,37 +20,15 @@ class TGIConfig(DockerInferenceServerConfig):
     # Launcher options
     model_id: str = "gpt2"
     revision: str = "main"
+    num_shard: Optional[int] = None
     dtype: Optional[DType_Literal] = None
-    quantize: Optional[Quantize_Literal] = None
+    enable_cuda_graphs: Optional[bool] = None
     sharded: Optional[Shareded_Literal] = None
-    num_shard: Optional[int] = None
-    trust_remote_code: Optional[bool] = None
+    quantize: Optional[Quantize_Literal] = None
     disable_custom_kernels: Optional[bool] = None
-    # Inference options
-    max_best_of: Optional[int] = None
-    max_concurrent_requests: Optional[int] = None
-    max_stop_sequences: Optional[int] = None
-    max_top_n_tokens: Optional[int] = None
-    max_input_length: Optional[int] = None
-    max_total_tokens: Optional[int] = None
-    waiting_served_ratio: Optional[float] = None
-    max_batch_prefill_tokens: Optional[int] = None
-    max_batch_total_tokens: Optional[int] = None
-    max_waiting_tokens: Optional[int] = None
-    max_batch_size: Optional[int] = None
-    enable_cuda_graphs: Optional[bool] = None
-    huggingface_hub_cache: Optional[str] = None
-    weights_cache_override: Optional[str] = None
-    cuda_memory_fraction: Optional[float] = None
-    rope_scaling: Optional[str] = None
-    rope_factor: Optional[str] = None
-    json_output: Optional[bool] = None
-    otlp_endpoint: Optional[str] = None
-    cors_allow_origin: Optional[list] = None
-    watermark_gamma: Optional[str] = None
-    watermark_delta: Optional[str] = None
-    tokenizer_config_path: Optional[str] = None
-    disable_grammar_support: Optional[bool] = None
+    trust_remote_code: Optional[bool] = None
+    # Concurrency options
+    max_concurrent_requests: int = 128
 
     def __post_init__(self) -> None:
         super().__post_init__()
@@ -72,8 +50,9 @@ def __init__(self, config: TGIConfig) -> None:
         super().__init__(config)
 
     async def single_client_call(self, prompt: str, **kwargs) -> str:
-        output = await self.client.text_generation(prompt=prompt, **kwargs)
-        return output
+        async with self.semaphore:
+            output = await self.client.text_generation(prompt=prompt, **kwargs)
+            return output
 
     async def batch_client_call(self, prompt: list, **kwargs) -> list:
         output = await asyncio.gather(*[self.single_client_call(prompt=p, **kwargs) for p in prompt])
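
For readers skimming the diff, here is a minimal, self-contained sketch of the concurrency pattern the patch introduces: every per-item coroutine acquires a shared asyncio.Semaphore before touching the client, and asyncio.gather fans the whole batch out, so at most max_concurrent_requests calls are in flight at once. The names fake_request and batch_call below are illustrative stand-ins for py_txi's single_client_call/batch_client_call, and the sleep stands in for the InferenceClient network call; this is not code from the patch.

    import asyncio
    import random


    async def fake_request(semaphore: asyncio.Semaphore, text: str) -> str:
        # The semaphore lets at most `limit` coroutines past this point at a time.
        async with semaphore:
            await asyncio.sleep(random.random() / 100)  # stand-in for a network call
            return text.upper()


    async def batch_call(texts: list, limit: int = 128) -> list:
        semaphore = asyncio.Semaphore(limit)
        # gather schedules every coroutine; the semaphore throttles actual execution
        return await asyncio.gather(*[fake_request(semaphore, t) for t in texts])


    outputs = asyncio.run(batch_call(["Hi, I'm a language model"] * 50, limit=8))
    print(len(outputs))  # 50 results, but never more than 8 requests in flight

The try/except around asyncio.get_event_loop() in the patch appears to serve a related purpose: since the semaphore is created in __init__ rather than inside a running coroutine, the constructor first makes sure the current thread has a usable event loop for it to bind to.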