diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index 7392e666..096aae00 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -143,6 +143,13 @@ def make_model(self):
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
             )
+        elif self.backend == "vllm":
+            return VLLMChatModel(
+                model_name=self.model_name,
+                temperature=self.temperature,
+                max_tokens=self.max_new_tokens,
+                n_retry_server=self.n_retry_server,
+            )
         else:
             raise ValueError(f"Backend {self.backend} is not supported")
 
@@ -423,3 +430,27 @@ def __init__(
 
         client = InferenceClient(model=model_url, token=token)
         self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+
+
+class VLLMChatModel(ChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+        n_retry_server=4,
+        min_retry_wait_time=60,
+    ):
+        super().__init__(
+            model_name=model_name,
+            api_key=api_key,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            max_retry=n_retry_server,
+            min_retry_wait_time=min_retry_wait_time,
+            api_key_env_var="VLLM_API_KEY",
+            client_class=OpenAI,
+            client_args={"base_url": "http://0.0.0.0:8000/v1"},
+            pricing_func=None,
+        )
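
Usage sketch for reviewers (not part of the diff): a minimal way the new backend might be exercised directly, assuming a vLLM server is already serving a model on the OpenAI-compatible endpoint hardcoded above (http://0.0.0.0:8000/v1), e.g. via `vllm serve <model>`. The model name and sampling values below are placeholders, and the API key is presumably resolved from the VLLM_API_KEY environment variable via the api_key_env_var wired into the constructor.

import os

from agentlab.llm.chat_api import VLLMChatModel

# Assumption: when api_key is not passed, ChatModel falls back to the
# VLLM_API_KEY environment variable; local vLLM servers typically accept
# any placeholder value such as "EMPTY".
os.environ.setdefault("VLLM_API_KEY", "EMPTY")

llm = VLLMChatModel(
    model_name="meta-llama/Llama-3.1-8B-Instruct",  # placeholder; must match the served model
    temperature=0.2,
    max_tokens=256,
    n_retry_server=4,
)
# Requests then go through the shared ChatModel interface, same as the other backends.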