From 46d117d22420dc35bbcc6968f6027f060e4195af Mon Sep 17 00:00:00 2001
From: Chivier Humber
Date: Tue, 12 Nov 2024 04:49:17 +0000
Subject: [PATCH] feat: support vllm & sllm inference

---
 nerif/agent/agent.py |  23 ++++---
 nerif/agent/utils.py | 144 ++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 149 insertions(+), 18 deletions(-)

diff --git a/nerif/agent/agent.py b/nerif/agent/agent.py
index 2b3162a..146bada 100644
--- a/nerif/agent/agent.py
+++ b/nerif/agent/agent.py
@@ -10,6 +10,8 @@
     get_litellm_embedding,
     get_litellm_response,
     get_ollama_response,
+    get_sllm_response,
+    get_vllm_response,
 )


@@ -78,18 +80,19 @@ def chat(self, message: str, append: bool = False, max_tokens: None | int = None
         if self.counter is not None:
             kwargs["counter"] = self.counter

-        if self.model.startswith("ollama"):
-            LOGGER.debug("requested with message: %s", self.messages)
-            LOGGER.debug("arguments of request: %s", kwargs)
-            result = get_ollama_response(self.messages, **kwargs)
-        elif self.model.startswith("openrouter"):
-            LOGGER.debug("requested with message: %s", self.messages)
-            LOGGER.debug("arguments of request: %s", kwargs)
+        LOGGER.debug("requested with message: %s", self.messages)
+        LOGGER.debug("arguments of request: %s", kwargs)
+
+        if self.model in OPENAI_MODEL:
             result = get_litellm_response(self.messages, **kwargs)
-        elif self.model in OPENAI_MODEL:
-            LOGGER.debug("requested with message: %s", self.messages)
-            LOGGER.debug("arguments of request: %s", kwargs)
+        elif self.model.startswith("openrouter"):
             result = get_litellm_response(self.messages, **kwargs)
+        elif self.model.startswith("ollama"):
+            result = get_ollama_response(self.messages, **kwargs)
+        elif self.model.startswith("vllm"):
+            result = get_vllm_response(self.messages, **kwargs)
+        elif self.model.startswith("sllm"):
+            result = get_sllm_response(self.messages, **kwargs)
         else:
             raise ValueError(f"Model {self.model} not supported")

diff --git a/nerif/agent/utils.py b/nerif/agent/utils.py
index d788909..3498104 100644
--- a/nerif/agent/utils.py
+++ b/nerif/agent/utils.py
@@ -4,6 +4,7 @@
 from typing import Any, List, Optional, Union

 import litellm
+from openai import OpenAI

 from .token_counter import NerifTokenCounter

@@ -105,7 +106,7 @@ def get_litellm_response(
     messages: List[Any],
     model: str = NERIF_DEFAULT_LLM_MODEL,
     temperature: float = 0,
-    max_tokens: int = 300,
+    max_tokens: int | None = None,
     stream: bool = False,
     api_key: Optional[str] = None,
     base_url: Optional[str] = None,
@@ -128,7 +129,7 @@
     Returns:
     - list: A list of generated text responses if messages is a list, otherwise a single text response.
""" - print(model) + if model in OPENAI_MODEL: if api_key is None or api_key == "": api_key = OPENAI_API_KEY @@ -150,6 +151,16 @@ def get_litellm_response( "model": model, "messages": messages, } + elif model.startswith("vllm"): + kargs = { + "model": model, + "messages": messages, + } + elif model.startswith("sllm"): + kargs = { + "model": model, + "messages": messages, + } else: raise ValueError(f"Model {model} not supported") @@ -171,20 +182,20 @@ def get_litellm_response( def get_ollama_response( - prompt: Union[str, List[str]], - url: str = "http://localhost:11434/v1/", + messages: List[Any], + url: str = OLLAMA_URL, model: str = "llama3.1", - max_tokens: int = 300, + max_tokens: int | None = None, temperature: float = 0, stream: bool = False, - api_key: Optional[str] = "ollama", + api_key: Optional[str] = None, counter: Optional[NerifTokenCounter] = None, ) -> Union[str, List[str]]: """ Get a text response from an Ollama model. Parameters: - - prompt (str or list): The input prompt(s) for the model. + - messages (str or list): The input messages for the model. - url (str): The URL of the Ollama API. Default is "http://localhost:11434/api/generate". - model (str): The name of the Ollama model. Default is "llama3.1". - max_tokens (int): The maximum number of tokens to generate in the response. Default is 300. @@ -198,9 +209,12 @@ def get_ollama_response( """ # todo: support batch ollama inference + if url is None or url == "": + # default ollama url + url = "http://localhost:11434/v1/" response = get_litellm_response( - prompt, + messages, model=model, temperature=temperature, max_tokens=max_tokens, @@ -214,3 +228,117 @@ def get_ollama_response( counter.count_from_response(response) return response + + +def get_vllm_response( + messages: List[Any], + url: str = VLLM_URL, + model: str = "llama3.1", + max_tokens: int | None = None, + temperature: float = 0, + stream: bool = False, + api_key: Optional[str] = None, + counter: Optional[NerifTokenCounter] = None, +) -> Union[str, List[str]]: + """ + Get a text response from a vLLM model. + + Parameters: + - messages (str or list): The input messages for the model. + - url (str): The URL of the vLLM API. Default is "http://localhost:8000/v1". + - model (str): The name of the vLLM model. Default is "llama3.1". + - max_tokens (int): The maximum number of tokens to generate in the response. Default is 300. + - temperature (float): The temperature setting for response generation. Default is 0. + - stream (bool): Whether to stream the response. Default is False. + - api_key (str): The API key for accessing the vLLM API. Default is None. + - batch_size (int): The number of predictions to make in a single request. Default is 1. + + Returns: + - str or list: The generated text response(s). 
+ """ + + # todo: support batch ollama inference + if url is None or url == "": + # default vllm url + url = "http://localhost:8000/v1" + if api_key is None or api_key == "": + # default vllm api key from vllm document example + api_key = "token-abc123" + + model = "/".join(model.split("/")[1:]) + + client = OpenAI( + base_url=url, + api_key=api_key, + ) + + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + stream=stream, + ) + + if counter is not None: + counter.set_parser_based_on_model(model) + counter.count_from_response(response) + + return response + + +def get_sllm_response( + messages: List[Any], + url: str = SLLM_URL, + model: str = "llama3.1", + max_tokens: int | None = None, + temperature: float = 0, + stream: bool = False, + api_key: Optional[str] = None, + counter: Optional[NerifTokenCounter] = None, +) -> Union[str, List[str]]: + """ + Get a text response from an Ollama model. + + Parameters: + - messages (str or list): The input messages for the model. + - url (str): The URL of the SLLM API. Default is "http://localhost:8000/v1". + - model (str): The name of the SLLM model. Default is "llama3.1". + - max_tokens (int): The maximum number of tokens to generate in the response. Default is 300. + - temperature (float): The temperature setting for response generation. Default is 0. + - stream (bool): Whether to stream the response. Default is False. + - api_key (str): The API key for accessing the SLLM API. Default is None. + - batch_size (int): The number of predictions to make in a single request. Default is 1. + + Returns: + - str or list: The generated text response(s). + """ + + # todo: support batch ollama inference + if url is None or url == "": + # default vllm url + url = "http://localhost:8343/v1" + if api_key is None or api_key == "": + # default vllm api key from vllm document example + api_key = "token-abc123" + + model = "/".join(model.split("/")[1:]) + + client = OpenAI( + base_url=url, + api_key=api_key, + ) + + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + stream=stream, + ) + + if counter is not None: + counter.set_parser_based_on_model(model) + counter.count_from_response(response) + + return response
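
Usage sketch (assumptions: the helpers are importable as nerif.agent.utils, an OpenAI-compatible vLLM server is listening on http://localhost:8000/v1 with the placeholder key "token-abc123", and the model name below is illustrative only). The "vllm/" prefix is what routes the request in agent.py and is stripped before the request reaches the server.

    from nerif.agent.utils import get_vllm_response

    # OpenAI-style chat messages, as used throughout nerif/agent/utils.py.
    messages = [{"role": "user", "content": "Say hello in one sentence."}]

    # Hypothetical model name; substitute whatever the vLLM server actually serves.
    response = get_vllm_response(
        messages,
        model="vllm/meta-llama/Llama-3.1-8B-Instruct",
        url="http://localhost:8000/v1",
    )

    # The helper returns the raw chat completion object from the OpenAI client.
    print(response.choices[0].message.content)

A model with the "sllm/" prefix goes through get_sllm_response the same way, with the default endpoint http://localhost:8343/v1.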