init
khangich committed Nov 18, 2024
1 parent 82058d2 commit f4a964c
Showing 1 changed file with 53 additions and 99 deletions.
152 changes: 53 additions & 99 deletions letta/llm_api/llm_api_tools.py
@@ -101,63 +101,55 @@ def wrapper(*args, **kwargs):

@retry_with_exponential_backoff
def create(
# agent_state: AgentState,
llm_config: LLMConfig,
messages: List[Message],
user_id: Optional[str] = None, # option UUID to associate request with
user_id: Optional[str] = None,
functions: Optional[list] = None,
functions_python: Optional[dict] = None,
function_call: str = "auto",
# hint
first_message: bool = False,
# use tool naming?
# if false, will use deprecated 'functions' style
use_tool_naming: bool = True,
# streaming?
stream: bool = False,
stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
max_tokens: Optional[int] = None,
model_settings: Optional[dict] = None, # TODO: eventually pass from server
model_settings: Optional[dict] = None,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
"""Return response to chat completion with backoff."""
from letta.utils import printd

if not model_settings:
from letta.settings import model_settings

model_settings = model_settings

model_settings = model_settings
printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}")

if function_call and not functions:
printd("unsetting function_call because functions is None")
printd("Unsetting function_call because functions is None")
function_call = None

# openai
if llm_config.model_endpoint_type == "openai":
def handle_openai():
if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# only is a problem if we are *not* using an openai proxy
raise ValueError(f"OpenAI key is missing from letta config file")

data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens)
if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
stream_interface, AgentRefreshStreamingInterface
), type(stream_interface)
raise ValueError("OpenAI key is missing from letta config file")

data = build_openai_chat_completions_request(
llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens
)
data.stream = stream

if stream:
assert isinstance(stream_interface, (AgentChunkStreamingInterface, AgentRefreshStreamingInterface))
response = openai_chat_completions_process_stream(
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
url=llm_config.model_endpoint,
api_key=model_settings.openai_api_key,
chat_completion_request=data,
stream_interface=stream_interface,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
else:
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_start()
try:
response = openai_chat_completions_request(
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
url=llm_config.model_endpoint,
api_key=model_settings.openai_api_key,
chat_completion_request=data,
)
@@ -170,22 +162,13 @@ def create(

return response

# azure
elif llm_config.model_endpoint_type == "azure":
def handle_azure():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
raise NotImplementedError("Streaming not yet implemented for Azure")

if model_settings.azure_api_key is None:
raise ValueError(f"Azure API key is missing. Did you set AZURE_API_KEY in your env?")
if not all([model_settings.azure_api_key, model_settings.azure_base_url, model_settings.azure_api_version]):
raise ValueError("Azure API key, base URL, or version is missing. Check your environment variables.")

if model_settings.azure_base_url is None:
raise ValueError(f"Azure base url is missing. Did you set AZURE_BASE_URL in your env?")

if model_settings.azure_api_version is None:
raise ValueError(f"Azure API version is missing. Did you set AZURE_API_VERSION in your env?")

# Set the llm config model_endpoint from model_settings
# For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config
llm_config.model_endpoint = model_settings.azure_base_url
chat_completion_request = build_openai_chat_completions_request(
llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens
@@ -203,34 +186,28 @@ def create(

return response

elif llm_config.model_endpoint_type == "google_ai":
def handle_google_ai():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
raise NotImplementedError("Streaming not yet implemented for Google AI")
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Google AI API requests")

if functions is not None:
tools = [{"type": "function", "function": f} for f in functions]
tools = [Tool(**t) for t in tools]
tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs)
else:
tools = None
tools = convert_tools_to_google_ai_format(
[{"type": "function", "function": f} for f in functions] if functions else None,
inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs
)

return google_ai_chat_completions_request(
base_url=llm_config.model_endpoint,
model=llm_config.model,
api_key=model_settings.gemini_api_key,
# see structure of payload here: https://ai.google.dev/docs/function_calling
data=dict(
contents=[m.to_google_ai_dict() for m in messages],
tools=tools,
),
data=dict(contents=[m.to_google_ai_dict() for m in messages], tools=tools),
inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
)

elif llm_config.model_endpoint_type == "anthropic":
def handle_anthropic():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
raise NotImplementedError("Streaming not yet implemented for Anthropic")
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Anthropic API requests")

@@ -241,56 +218,25 @@ def create(
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
# tool_choice=function_call,
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
max_tokens=1024, # TODO make dynamic
max_tokens=1024,
),
)

# elif llm_config.model_endpoint_type == "cohere":
# if stream:
# raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
# if not use_tool_naming:
# raise NotImplementedError("Only tool calling supported on Cohere API requests")
#
# if functions is not None:
# tools = [{"type": "function", "function": f} for f in functions]
# tools = [Tool(**t) for t in tools]
# else:
# tools = None
#
# return cohere_chat_completions_request(
# # url=llm_config.model_endpoint,
# url="https://api.cohere.ai/v1", # TODO
# api_key=os.getenv("COHERE_API_KEY"), # TODO remove
# chat_completion_request=ChatCompletionRequest(
# model="command-r-plus", # TODO
# messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
# tools=tools,
# tool_choice=function_call,
# # user=str(user_id),
# # NOTE: max_tokens is required for Anthropic API
# # max_tokens=1024, # TODO make dynamic
# ),
# )

elif llm_config.model_endpoint_type == "groq":
def handle_groq():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for Groq.")
raise NotImplementedError("Streaming not yet implemented for Groq")

if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions":
raise ValueError(f"Groq key is missing from letta config file")
raise ValueError("Groq key is missing from letta config file")

# force to true for groq, since they don't support 'content' is non-null
if llm_config.put_inner_thoughts_in_kwargs:
functions = add_inner_thoughts_to_functions(
functions=functions,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION,
)

tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None
tools = [{"type": "function", "function": f} for f in functions] if functions else None
data = ChatCompletionRequest(
model=llm_config.model,
messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) for m in messages],
@@ -299,19 +245,15 @@ def create(
user=str(user_id),
)

# https://console.groq.com/docs/openai
# "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:"
assert data.top_logprobs is None
assert data.logit_bias is None
assert data.logprobs == False
assert data.n == 1
# They mention that none of the messages can have names, but it seems to not error out (for now)

data.stream = False
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_start()
try:
# groq uses the openai chat completions API, so this component should be reusable
response = openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=model_settings.groq_api_key,
@@ -326,10 +268,9 @@ def create(

return response

# local model
else:
def handle_local():
if stream:
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
raise NotImplementedError("Streaming not yet implemented for local models")
return get_chat_completion(
model=llm_config.model,
messages=messages,
@@ -341,9 +282,22 @@ def create(
endpoint_type=llm_config.model_endpoint_type,
wrapper=llm_config.model_wrapper,
user=str(user_id),
# hint
first_message=first_message,
# auth-related
auth_type=model_settings.openllm_auth_type,
auth_key=model_settings.openllm_api_key,
)

handlers = {
"openai": handle_openai,
"azure": handle_azure,
"google_ai": handle_google_ai,
"anthropic": handle_anthropic,
"groq": handle_groq,
"local": handle_local,
}

handler = handlers.get(llm_config.model_endpoint_type)
if handler:
return handler()
else:
raise NotImplementedError(f"Model endpoint type '{llm_config.model_endpoint_type}' is not supported.")
