feat: support vllm & sllm inference
Chivier committed Nov 12, 2024
1 parent 1500b06 commit 46d117d
Showing 2 changed files with 149 additions and 18 deletions.
23 changes: 13 additions & 10 deletions nerif/agent/agent.py
@@ -10,6 +10,8 @@
     get_litellm_embedding,
     get_litellm_response,
     get_ollama_response,
+    get_sllm_response,
+    get_vllm_response,
 )


@@ -78,18 +80,19 @@ def chat(self, message: str, append: bool = False, max_tokens: None | int = None
         if self.counter is not None:
             kwargs["counter"] = self.counter
 
-        if self.model.startswith("ollama"):
-            LOGGER.debug("requested with message: %s", self.messages)
-            LOGGER.debug("arguments of request: %s", kwargs)
-            result = get_ollama_response(self.messages, **kwargs)
-        elif self.model.startswith("openrouter"):
-            LOGGER.debug("requested with message: %s", self.messages)
-            LOGGER.debug("arguments of request: %s", kwargs)
+        LOGGER.debug("requested with message: %s", self.messages)
+        LOGGER.debug("arguments of request: %s", kwargs)
+
+        if self.model in OPENAI_MODEL:
             result = get_litellm_response(self.messages, **kwargs)
-        elif self.model in OPENAI_MODEL:
-            LOGGER.debug("requested with message: %s", self.messages)
-            LOGGER.debug("arguments of request: %s", kwargs)
+        elif self.model.startswith("openrouter"):
             result = get_litellm_response(self.messages, **kwargs)
+        elif self.model.startswith("ollama"):
+            result = get_ollama_response(self.messages, **kwargs)
+        elif self.model.startswith("vllm"):
+            result = get_vllm_response(self.messages, **kwargs)
+        elif self.model.startswith("sllm"):
+            result = get_sllm_response(self.messages, **kwargs)
         else:
             raise ValueError(f"Model {self.model} not supported")
 
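For context, a minimal usage sketch of the new prefix-based routing in chat(). The agent class name SimpleChatAgent and the concrete model identifiers are assumptions for illustration and are not part of this diff; the diff only guarantees that model strings beginning with "vllm/" or "sllm/" are dispatched to get_vllm_response and get_sllm_response.

# Hypothetical usage sketch -- class name and model identifiers are assumed, not shown in this diff.
from nerif.agent import SimpleChatAgent

# OpenAI-hosted model: listed in OPENAI_MODEL, served through get_litellm_response
agent = SimpleChatAgent(model="gpt-4o")
print(agent.chat("Hello!"))

# Local vLLM server: the "vllm/" prefix routes the request to get_vllm_response
vllm_agent = SimpleChatAgent(model="vllm/meta-llama/Llama-3.1-8B-Instruct")
print(vllm_agent.chat("Hello!"))

# sllm endpoint: the "sllm/" prefix routes the request to get_sllm_response
sllm_agent = SimpleChatAgent(model="sllm/llama3.1")
print(sllm_agent.chat("Hello!"))
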
144 changes: 136 additions & 8 deletions nerif/agent/utils.py
@@ -4,6 +4,7 @@
 from typing import Any, List, Optional, Union
 
 import litellm
+from openai import OpenAI
 
 from .token_counter import NerifTokenCounter
 
@@ -105,7 +106,7 @@ def get_litellm_response(
     messages: List[Any],
     model: str = NERIF_DEFAULT_LLM_MODEL,
     temperature: float = 0,
-    max_tokens: int = 300,
+    max_tokens: int | None = None,
     stream: bool = False,
     api_key: Optional[str] = None,
     base_url: Optional[str] = None,
@@ -128,7 +129,7 @@
     Returns:
     - list: A list of generated text responses if messages is a list, otherwise a single text response.
     """
-    print(model)
+
     if model in OPENAI_MODEL:
         if api_key is None or api_key == "":
             api_key = OPENAI_API_KEY
@@ -150,6 +151,16 @@
             "model": model,
             "messages": messages,
         }
+    elif model.startswith("vllm"):
+        kargs = {
+            "model": model,
+            "messages": messages,
+        }
+    elif model.startswith("sllm"):
+        kargs = {
+            "model": model,
+            "messages": messages,
+        }
     else:
         raise ValueError(f"Model {model} not supported")
 
@@ -171,20 +182,20 @@
 
 
 def get_ollama_response(
-    prompt: Union[str, List[str]],
-    url: str = "http://localhost:11434/v1/",
+    messages: List[Any],
+    url: str = OLLAMA_URL,
     model: str = "llama3.1",
-    max_tokens: int = 300,
+    max_tokens: int | None = None,
     temperature: float = 0,
     stream: bool = False,
-    api_key: Optional[str] = "ollama",
+    api_key: Optional[str] = None,
     counter: Optional[NerifTokenCounter] = None,
 ) -> Union[str, List[str]]:
     """
     Get a text response from an Ollama model.
     Parameters:
-    - prompt (str or list): The input prompt(s) for the model.
+    - messages (str or list): The input messages for the model.
     - url (str): The URL of the Ollama API. Default is "http://localhost:11434/api/generate".
     - model (str): The name of the Ollama model. Default is "llama3.1".
     - max_tokens (int): The maximum number of tokens to generate in the response. Default is 300.
@@ -198,9 +209,12 @@ def get_ollama_response(
     """
 
     # todo: support batch ollama inference
+    if url is None or url == "":
+        # default ollama url
+        url = "http://localhost:11434/v1/"
 
     response = get_litellm_response(
-        prompt,
+        messages,
         model=model,
         temperature=temperature,
         max_tokens=max_tokens,
@@ -214,3 +228,117 @@
         counter.count_from_response(response)
 
     return response
+
+
+def get_vllm_response(
+    messages: List[Any],
+    url: str = VLLM_URL,
+    model: str = "llama3.1",
+    max_tokens: int | None = None,
+    temperature: float = 0,
+    stream: bool = False,
+    api_key: Optional[str] = None,
+    counter: Optional[NerifTokenCounter] = None,
+) -> Union[str, List[str]]:
+    """
+    Get a text response from a vLLM model.
+    Parameters:
+    - messages (str or list): The input messages for the model.
+    - url (str): The URL of the vLLM API. Default is "http://localhost:8000/v1".
+    - model (str): The name of the vLLM model. Default is "llama3.1".
+    - max_tokens (int or None): The maximum number of tokens to generate in the response. Default is None.
+    - temperature (float): The temperature setting for response generation. Default is 0.
+    - stream (bool): Whether to stream the response. Default is False.
+    - api_key (str): The API key for accessing the vLLM API. Default is None.
+    - counter (NerifTokenCounter): Optional token counter updated from the response. Default is None.
+    Returns:
+    - str or list: The generated text response(s).
+    """
+
+    # todo: support batch vllm inference
+    if url is None or url == "":
+        # default vllm url
+        url = "http://localhost:8000/v1"
+    if api_key is None or api_key == "":
+        # default vllm api key from vllm document example
+        api_key = "token-abc123"
+
+    # drop the "vllm/" prefix so only the served model name is sent to the endpoint
+    model = "/".join(model.split("/")[1:])
+
+    client = OpenAI(
+        base_url=url,
+        api_key=api_key,
+    )
+
+    response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        stream=stream,
+    )
+
+    if counter is not None:
+        counter.set_parser_based_on_model(model)
+        counter.count_from_response(response)
+
+    return response
+
+
+def get_sllm_response(
+    messages: List[Any],
+    url: str = SLLM_URL,
+    model: str = "llama3.1",
+    max_tokens: int | None = None,
+    temperature: float = 0,
+    stream: bool = False,
+    api_key: Optional[str] = None,
+    counter: Optional[NerifTokenCounter] = None,
+) -> Union[str, List[str]]:
+    """
+    Get a text response from an SLLM model.
+    Parameters:
+    - messages (str or list): The input messages for the model.
+    - url (str): The URL of the SLLM API. Default is "http://localhost:8343/v1".
+    - model (str): The name of the SLLM model. Default is "llama3.1".
+    - max_tokens (int or None): The maximum number of tokens to generate in the response. Default is None.
+    - temperature (float): The temperature setting for response generation. Default is 0.
+    - stream (bool): Whether to stream the response. Default is False.
+    - api_key (str): The API key for accessing the SLLM API. Default is None.
+    - counter (NerifTokenCounter): Optional token counter updated from the response. Default is None.
+    Returns:
+    - str or list: The generated text response(s).
+    """
+
+    # todo: support batch sllm inference
+    if url is None or url == "":
+        # default sllm url
+        url = "http://localhost:8343/v1"
+    if api_key is None or api_key == "":
+        # default api key placeholder (same token as the vllm example)
+        api_key = "token-abc123"
+
+    # drop the "sllm/" prefix so only the served model name is sent to the endpoint
+    model = "/".join(model.split("/")[1:])
+
+    client = OpenAI(
+        base_url=url,
+        api_key=api_key,
+    )
+
+    response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        stream=stream,
+    )
+
+    if counter is not None:
+        counter.set_parser_based_on_model(model)
+        counter.count_from_response(response)
+
+    return response
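
As a quick sanity check of the two new helpers, a minimal sketch of calling get_vllm_response directly. The served model name and the locally running vLLM server are assumptions, not part of this diff; the helper strips the "vllm/" prefix before sending the request and returns the raw OpenAI-style response object.

# Hypothetical direct call -- assumes a vLLM OpenAI-compatible server is already running on localhost:8000
# (for example via `vllm serve meta-llama/Llama-3.1-8B-Instruct --api-key token-abc123`).
from nerif.agent.utils import get_vllm_response  # assumed import path, mirrors the file location in this diff

messages = [{"role": "user", "content": "Say hello in one short sentence."}]

response = get_vllm_response(
    messages,
    url="http://localhost:8000/v1",
    model="vllm/meta-llama/Llama-3.1-8B-Instruct",  # "vllm/" prefix is stripped inside the helper
    max_tokens=64,
)
print(response.choices[0].message.content)

# get_sllm_response is used the same way against an sllm endpoint (in-code fallback port 8343).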
