feat: support togetherAI via /completions (#2045)
cpacker authored Nov 18, 2024
1 parent 004cd69 commit 05045de
Showing 14 changed files with 364 additions and 6 deletions.
104 changes: 104 additions & 0 deletions .github/workflows/test_together.yml
@@ -0,0 +1,104 @@
name: Together Llama 3.1 70b Capabilities Test

env:
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v4

- name: "Setup Python, Poetry and Dependencies"
uses: packetcoders/action-setup-cache-python-poetry@main
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev -E external-tools"

- name: Test first message contains expected function call and inner monologue
id: test_first_message
env:
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_together_llama_3_70b_returns_valid_first_message
echo "TEST_FIRST_MESSAGE_EXIT_CODE=$?" >> $GITHUB_ENV
continue-on-error: true

- name: Test model sends message with keyword
id: test_keyword_message
env:
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_together_llama_3_70b_returns_keyword
echo "TEST_KEYWORD_MESSAGE_EXIT_CODE=$?" >> $GITHUB_ENV
continue-on-error: true

- name: Test model uses external tool correctly
id: test_external_tool
env:
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_together_llama_3_70b_uses_external_tool
echo "TEST_EXTERNAL_TOOL_EXIT_CODE=$?" >> $GITHUB_ENV
continue-on-error: true

- name: Test model recalls chat memory
id: test_chat_memory
env:
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_together_llama_3_70b_recall_chat_memory
echo "TEST_CHAT_MEMORY_EXIT_CODE=$?" >> $GITHUB_ENV
continue-on-error: true

- name: Test model uses 'archival_memory_search' to find secret
id: test_archival_memory
env:
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_together_llama_3_70b_archival_memory_retrieval
echo "TEST_ARCHIVAL_MEMORY_EXIT_CODE=$?" >> $GITHUB_ENV
continue-on-error: true

- name: Test model can edit core memories
id: test_core_memory
env:
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_model_letta_perfomance.py::test_together_llama_3_70b_edit_core_memory
echo "TEST_CORE_MEMORY_EXIT_CODE=$?" >> $GITHUB_ENV
continue-on-error: true

- name: Summarize test results
if: always()
run: |
echo "Test Results Summary:"
# If the exit code is empty, treat it as a failure (❌)
echo "Test first message: $([[ -z $TEST_FIRST_MESSAGE_EXIT_CODE || $TEST_FIRST_MESSAGE_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
echo "Test model sends message with keyword: $([[ -z $TEST_KEYWORD_MESSAGE_EXIT_CODE || $TEST_KEYWORD_MESSAGE_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
echo "Test model uses external tool: $([[ -z $TEST_EXTERNAL_TOOL_EXIT_CODE || $TEST_EXTERNAL_TOOL_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
echo "Test model recalls chat memory: $([[ -z $TEST_CHAT_MEMORY_EXIT_CODE || $TEST_CHAT_MEMORY_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
echo "Test model uses 'archival_memory_search' to find secret: $([[ -z $TEST_ARCHIVAL_MEMORY_EXIT_CODE || $TEST_ARCHIVAL_MEMORY_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
echo "Test model can edit core memories: $([[ -z $TEST_CORE_MEMORY_EXIT_CODE || $TEST_CORE_MEMORY_EXIT_CODE -ne 0 ]] && echo ❌ || echo ✅)"
# Check if any test failed (either non-zero or unset exit code)
if [[ -z $TEST_FIRST_MESSAGE_EXIT_CODE || $TEST_FIRST_MESSAGE_EXIT_CODE -ne 0 || \
-z $TEST_KEYWORD_MESSAGE_EXIT_CODE || $TEST_KEYWORD_MESSAGE_EXIT_CODE -ne 0 || \
-z $TEST_EXTERNAL_TOOL_EXIT_CODE || $TEST_EXTERNAL_TOOL_EXIT_CODE -ne 0 || \
-z $TEST_CHAT_MEMORY_EXIT_CODE || $TEST_CHAT_MEMORY_EXIT_CODE -ne 0 || \
-z $TEST_ARCHIVAL_MEMORY_EXIT_CODE || $TEST_ARCHIVAL_MEMORY_EXIT_CODE -ne 0 || \
-z $TEST_CORE_MEMORY_EXIT_CODE || $TEST_CORE_MEMORY_EXIT_CODE -ne 0 ]]; then
echo "Some tests failed."
exit 78
fi
continue-on-error: true
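
The workflow above just shells out to pytest for each capability test, so the same checks can be reproduced locally. A minimal sketch, assuming TOGETHER_API_KEY is exported and the dev extras are installed (poetry install -E dev -E external-tools):

# Run one of the Together capability tests locally, mirroring a workflow step above.
# Assumes TOGETHER_API_KEY is set in the environment.
import os
import sys

import pytest

assert os.environ.get("TOGETHER_API_KEY"), "export TOGETHER_API_KEY first"

exit_code = pytest.main(
    [
        "-s",
        "-vv",
        "tests/test_model_letta_perfomance.py::test_together_llama_3_70b_returns_valid_first_message",
    ]
)
sys.exit(exit_code)
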
2 changes: 1 addition & 1 deletion letta/constants.py
@@ -19,7 +19,7 @@
TOOL_CALL_ID_MAX_LEN = 29

# minimum context window size
MIN_CONTEXT_WINDOW = 4000
MIN_CONTEXT_WINDOW = 4096

# embeddings
MAX_EMBEDDING_DIM = 4096 # maximum supported embedding size - do NOT change or else DBs will need to be reset
29 changes: 29 additions & 0 deletions letta/llm_api/llm_api_tools.py
@@ -33,6 +33,7 @@
cast_message_to_subtype,
)
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import ModelSettings
from letta.streaming_interface import (
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
@@ -126,6 +127,7 @@ def create(
from letta.settings import model_settings

model_settings = model_settings
assert isinstance(model_settings, ModelSettings)

printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}")

@@ -326,6 +328,33 @@ def create(

return response

elif llm_config.model_endpoint_type == "together":
"""TogetherAI endpoint that goes via /completions instead of /chat/completions"""

if stream:
raise NotImplementedError(f"Streaming not yet implemented for TogetherAI (via the /completions endpoint).")

if model_settings.together_api_key is None and llm_config.model_endpoint == "https://api.together.ai/v1/completions":
raise ValueError(f"TogetherAI key is missing from letta config file")

return get_chat_completion(
model=llm_config.model,
messages=messages,
functions=functions,
functions_python=functions_python,
function_call=function_call,
context_window=llm_config.context_window,
endpoint=llm_config.model_endpoint,
endpoint_type="vllm", # NOTE: use the vLLM path through /completions
wrapper=llm_config.model_wrapper,
user=str(user_id),
# hint
first_message=first_message,
# auth-related
auth_type="bearer_token", # NOTE: Together expects bearer token auth
auth_key=model_settings.together_api_key,
)

# local model
else:
if stream:
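
Because the new branch reuses the local-LLM /completions machinery (endpoint_type="vllm") with bearer-token auth, the request it ultimately issues is an ordinary OpenAI-style completions call against the Together endpoint. An illustrative sketch of that request shape (not Letta's internal code; the model id and prompt are example values):

# Illustrative bearer-token /completions request against Together
# (example model id and prompt; Letta builds the real prompt via the model wrapper).
import os

import requests

resp = requests.post(
    "https://api.together.ai/v1/completions",
    headers={"Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}"},
    json={
        "model": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",  # example model id
        "prompt": "You are a helpful assistant.\nUSER: hello\nASSISTANT:",
        "max_tokens": 256,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
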
1 change: 0 additions & 1 deletion letta/llm_api/openai.py
@@ -536,7 +536,6 @@ def openai_chat_completions_request(
tool["function"] = convert_to_structured_output(tool["function"])

response_json = make_post_request(url, headers, data)

return ChatCompletionResponse(**response_json)


1 change: 1 addition & 0 deletions letta/local_llm/utils.py
@@ -184,6 +184,7 @@ def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int:
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/11
"""
try:
# Attempt to search for the encoding based on the model string
encoding = tiktoken.encoding_for_model(model)
except KeyError:
# print("Warning: model not found. Using cl100k_base encoding.")
144 changes: 141 additions & 3 deletions letta/providers.py
@@ -2,7 +2,7 @@

from pydantic import BaseModel, Field, model_validator

from letta.constants import LLM_MAX_TOKENS
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.llm_api.azure_openai import (
get_azure_chat_completions_endpoint,
get_azure_embeddings_endpoint,
@@ -67,10 +67,15 @@ def list_llm_models(self) -> List[LLMConfig]:
extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
response = openai_get_model_list(self.base_url, api_key=self.api_key, extra_params=extra_params)

assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
# TogetherAI's response is missing the 'data' field
# assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
if "data" in response:
data = response["data"]
else:
data = response

configs = []
for model in response["data"]:
for model in data:
assert "id" in model, f"OpenAI model missing 'id' field: {model}"
model_name = model["id"]

@@ -82,6 +87,32 @@

if not context_window_size:
continue

# TogetherAI includes the type, which we can use to filter out embedding models
if self.base_url == "https://api.together.ai/v1":
if "type" in model and model["type"] != "chat":
continue

# for TogetherAI, we need to skip the models that don't support JSON mode / function calling
# requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: {
# "error": {
# "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling",
# "type": "invalid_request_error",
# "param": null,
# "code": "constraints_model"
# }
# }
if "config" not in model:
continue
if "chat_template" not in model["config"]:
continue
if model["config"]["chat_template"] is None:
continue
if "tools" not in model["config"]["chat_template"]:
continue
# if "config" in data and "chat_template" in data["config"] and "tools" not in data["config"]["chat_template"]:
# continue

configs.append(
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size)
)
@@ -325,6 +356,113 @@ def get_model_context_window_size(self, model_name: str):
raise NotImplementedError


class TogetherProvider(OpenAIProvider):
"""TogetherAI provider that uses the /completions API
TogetherAI can also be used via the /chat/completions API
by setting OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key
and API URL, however /completions is preferred because their /chat/completions
function calling support is limited.
"""

name: str = "together"
base_url: str = "https://api.together.ai/v1"
api_key: str = Field(..., description="API key for the TogetherAI API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")

def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list

response = openai_get_model_list(self.base_url, api_key=self.api_key)

# TogetherAI's response is missing the 'data' field
# assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
if "data" in response:
data = response["data"]
else:
data = response

configs = []
for model in data:
assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
model_name = model["id"]

if "context_length" in model:
# Context length is returned in OpenRouter as "context_length"
context_window_size = model["context_length"]
else:
context_window_size = self.get_model_context_window_size(model_name)

# We need the context length for embeddings too
if not context_window_size:
continue

# Skip models that are too small for Letta
if context_window_size <= MIN_CONTEXT_WINDOW:
continue

# TogetherAI includes the type, which we can use to filter for embedding models
if "type" in model and model["type"] not in ["chat", "language"]:
continue

configs.append(
LLMConfig(
model=model_name,
model_endpoint_type="together",
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window_size,
)
)

return configs

def list_embedding_models(self) -> List[EmbeddingConfig]:
# TODO re-enable once we figure out how to pass API keys through properly
return []

# from letta.llm_api.openai import openai_get_model_list

# response = openai_get_model_list(self.base_url, api_key=self.api_key)

# # TogetherAI's response is missing the 'data' field
# # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
# if "data" in response:
# data = response["data"]
# else:
# data = response

# configs = []
# for model in data:
# assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
# model_name = model["id"]

# if "context_length" in model:
# # Context length is returned in OpenRouter as "context_length"
# context_window_size = model["context_length"]
# else:
# context_window_size = self.get_model_context_window_size(model_name)

# if not context_window_size:
# continue

# # TogetherAI includes the type, which we can use to filter out embedding models
# if "type" in model and model["type"] not in ["embedding"]:
# continue

# configs.append(
# EmbeddingConfig(
# embedding_model=model_name,
# embedding_endpoint_type="openai",
# embedding_endpoint=self.base_url,
# embedding_dim=context_window_size,
# embedding_chunk_size=300, # TODO: change?
# )
# )

# return configs


class GoogleAIProvider(Provider):
# gemini
api_key: str = Field(..., description="API key for the Google AI API.")
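
A minimal sketch of exercising the new provider class directly; the prompt-formatter name below is only an example value, not something mandated by the commit:

# List Together chat models via the new provider (wrapper name is an example value).
import os

from letta.providers import TogetherProvider

provider = TogetherProvider(
    api_key=os.environ["TOGETHER_API_KEY"],
    default_prompt_formatter="chatml",  # example formatter; use whatever wrapper your setup expects
)

for cfg in provider.list_llm_models():
    print(cfg.model, cfg.context_window)
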
1 change: 1 addition & 0 deletions letta/schemas/llm_config.py
@@ -35,6 +35,7 @@ class LLMConfig(BaseModel):
"vllm",
"hugging-face",
"mistral",
"together", # completions endpoint
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
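
With the new endpoint type registered, an LLMConfig that routes through the /completions branch might look like the sketch below; the model id, wrapper, and context window are illustrative values, not defaults from the commit:

# Illustrative LLMConfig using the new "together" endpoint type.
from letta.schemas.llm_config import LLMConfig

config = LLMConfig(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",  # example model id
    model_endpoint_type="together",
    model_endpoint="https://api.together.ai/v1",
    model_wrapper="chatml",  # example prompt formatter
    context_window=8192,  # example context window
)
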
1 change: 1 addition & 0 deletions letta/schemas/openai/chat_completion_response.py
@@ -46,6 +46,7 @@ class Choice(BaseModel):
index: int
message: Message
logprobs: Optional[Dict[str, Union[List[MessageContentLogProb], None]]] = None
seed: Optional[int] = None # found in TogetherAI


class UsageStatistics(BaseModel):
14 changes: 13 additions & 1 deletion letta/server/server.py
@@ -49,6 +49,7 @@
OllamaProvider,
OpenAIProvider,
Provider,
TogetherProvider,
VLLMChatCompletionsProvider,
VLLMCompletionsProvider,
)
@@ -303,7 +304,18 @@ def __init__(
)
)
if model_settings.groq_api_key:
self._enabled_providers.append(GroqProvider(api_key=model_settings.groq_api_key))
self._enabled_providers.append(
GroqProvider(
api_key=model_settings.groq_api_key,
)
)
if model_settings.together_api_key:
self._enabled_providers.append(
TogetherProvider(
api_key=model_settings.together_api_key,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
)
if model_settings.vllm_api_base:
# vLLM exposes both a /chat/completions and a /completions endpoint
self._enabled_providers.append(
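
The provider is only registered when a Together API key is present in the model settings. A quick way to check, assuming the key is supplied through the TOGETHER_API_KEY environment variable (as in the workflow above) and picked up by letta.settings:

# Confirm the Together provider would be enabled by the server.
# Assumes model_settings reads TOGETHER_API_KEY from the environment.
from letta.settings import model_settings

if model_settings.together_api_key:
    print("Together provider will be registered")
    print("default prompt formatter:", model_settings.default_prompt_formatter)
else:
    print("Set TOGETHER_API_KEY to enable the Together provider")
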