feat: add new RateLimitExceededError (#2277)
Co-authored-by: Caren Thomas <[email protected]>
carenthomas and Caren Thomas authored Dec 19, 2024
1 parent cef408a commit e5f230e
Showing 4 changed files with 96 additions and 48 deletions.
22 changes: 17 additions & 5 deletions letta/agent.py
@@ -20,7 +20,7 @@
REQ_HEARTBEAT_MESSAGE,
STRUCTURED_OUTPUT_MODELS,
)
from letta.errors import LLMError
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
@@ -1094,6 +1094,7 @@ def inner_step(

# If we got a context alert, try trimming the messages length, then try again
if is_context_overflow_error(e):
printd(f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages")
# A separate API call to run a summarizer
self.summarize_messages_inplace()

@@ -1169,8 +1170,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,

# If at this point there's nothing to summarize, throw an error
if len(candidate_messages_to_summarize) == 0:
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]"
raise ContextWindowExceededError(
"Not enough messages to compress for summarization",
details={
"num_candidate_messages": len(candidate_messages_to_summarize),
"num_total_messages": len(self.messages),
"preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
},
)

# Walk down the message buffer (front-to-back) until we hit the target token count
@@ -1204,8 +1210,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
message_sequence_to_summarize = self._messages[1:cutoff] # do NOT get rid of the system message
if len(message_sequence_to_summarize) <= 1:
# This prevents a potential infinite loop of summarizing the same message over and over
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(message_sequence_to_summarize)} <= 1]"
raise ContextWindowExceededError(
"Not enough messages to compress for summarization after determining cutoff",
details={
"num_candidate_messages": len(message_sequence_to_summarize),
"num_total_messages": len(self.messages),
"preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
},
)
else:
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self._messages)}")
@@ -1218,6 +1229,7 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
self.agent_state.llm_config.context_window = (
LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"]
)

summary = summarize_messages(agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize)
printd(f"Got summary: {summary}")

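For illustration (not part of the commit), the summarizer path above condenses to the following sketch. The error type and its details payload match the diff; the message buffer, the constant's value, and the summarizer call are simplified stand-ins.

from letta.errors import ContextWindowExceededError

MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST = 3  # assumed value for the constant referenced above


def summarize_or_raise(messages, summarize_fn, preserve_n=MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST):
    """Condensed sketch: compress the middle of the buffer, or raise the new typed error."""
    # Keep the system message (index 0) and the last N messages out of the summary
    candidates = messages[1:-preserve_n] if preserve_n else messages[1:]
    if len(candidates) <= 1:
        # Mirrors the diff: a structured error with machine-readable details,
        # replacing the old generic LLMError and its f-string message
        raise ContextWindowExceededError(
            "Not enough messages to compress for summarization",
            details={
                "num_candidate_messages": len(candidates),
                "num_total_messages": len(messages),
                "preserve_N": preserve_n,
            },
        )
    summary = summarize_fn(candidates)  # a separate LLM API call in the real code
    return [messages[0], summary] + messages[-preserve_n:]
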
73 changes: 54 additions & 19 deletions letta/errors.py
@@ -1,49 +1,63 @@
import json
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union

# Avoid circular imports
if TYPE_CHECKING:
from letta.schemas.message import Message


class ErrorCode(Enum):
"""Enum for error codes used by client."""

INTERNAL_SERVER_ERROR = "INTERNAL_SERVER_ERROR"
CONTEXT_WINDOW_EXCEEDED = "CONTEXT_WINDOW_EXCEEDED"
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"


class LettaError(Exception):
"""Base class for all Letta related errors."""

def __init__(self, message: str, code: Optional[ErrorCode] = None, details: dict = {}):
self.message = message
self.code = code
self.details = details
super().__init__(message)

def __str__(self) -> str:
if self.code:
return f"{self.code.value}: {self.message}"
return self.message

def __repr__(self) -> str:
return f"{self.__class__.__name__}(message='{self.message}', code='{self.code}', details={self.details})"


class LettaToolCreateError(LettaError):
"""Error raised when a tool cannot be created."""

default_error_message = "Error creating tool."

def __init__(self, message=None):
if message is None:
message = self.default_error_message
self.message = message
super().__init__(self.message)
super().__init__(message=message or self.default_error_message)


class LettaConfigurationError(LettaError):
"""Error raised when there are configuration-related issues."""

def __init__(self, message: str, missing_fields: Optional[List[str]] = None):
self.missing_fields = missing_fields or []
super().__init__(message)
super().__init__(message=message, details={"missing_fields": self.missing_fields})


class LettaAgentNotFoundError(LettaError):
"""Error raised when an agent is not found."""

def __init__(self, message: str):
self.message = message
super().__init__(self.message)
pass


class LettaUserNotFoundError(LettaError):
"""Error raised when a user is not found."""

def __init__(self, message: str):
self.message = message
super().__init__(self.message)
pass


class LLMError(LettaError):
@@ -54,24 +68,45 @@ class LLMJSONParsingError(LettaError):
"""Exception raised for errors in the JSON parsing process."""

def __init__(self, message="Error parsing JSON generated by LLM"):
self.message = message
super().__init__(self.message)
super().__init__(message=message)


class LocalLLMError(LettaError):
"""Generic catch-all error for local LLM problems"""

def __init__(self, message="Encountered an error while running local LLM"):
self.message = message
super().__init__(self.message)
super().__init__(message=message)


class LocalLLMConnectionError(LettaError):
"""Error for when local LLM cannot be reached with provided IP/port"""

def __init__(self, message="Could not connect to local LLM"):
self.message = message
super().__init__(self.message)
super().__init__(message=message)


class ContextWindowExceededError(LettaError):
"""Error raised when the context window is exceeded but further summarization fails."""

def __init__(self, message: str, details: dict = {}):
error_message = f"{message} ({details})"
super().__init__(
message=error_message,
code=ErrorCode.CONTEXT_WINDOW_EXCEEDED,
details=details,
)


class RateLimitExceededError(LettaError):
"""Error raised when the llm rate limiter throttles api requests."""

def __init__(self, message: str, max_retries: int):
error_message = f"{message} ({max_retries})"
super().__init__(
message=error_message,
code=ErrorCode.RATE_LIMIT_EXCEEDED,
details={"max_retries": max_retries},
)


class LettaMessageError(LettaError):
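For illustration (not part of the commit), the new hierarchy behaves as follows; the printed values follow directly from LettaError.__str__ and the ErrorCode members defined above.

from letta.errors import ContextWindowExceededError, RateLimitExceededError

try:
    raise RateLimitExceededError("Maximum number of retries exceeded", max_retries=3)
except RateLimitExceededError as e:
    print(e)          # RATE_LIMIT_EXCEEDED: Maximum number of retries exceeded (3)
    print(e.details)  # {'max_retries': 3}

err = ContextWindowExceededError(
    "Not enough messages to compress for summarization",
    details={"num_total_messages": 12, "preserve_N": 3},
)
print(err.code)  # ErrorCode.CONTEXT_WINDOW_EXCEEDED
print(err)       # CONTEXT_WINDOW_EXCEEDED: Not enough messages to compress for summarization ({'num_total_messages': 12, 'preserve_N': 3})
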
4 changes: 2 additions & 2 deletions letta/llm_api/llm_api_tools.py
@@ -5,7 +5,7 @@
import requests

from letta.constants import CLI_WARNING_PREFIX
from letta.errors import LettaConfigurationError
from letta.errors import LettaConfigurationError, RateLimitExceededError
from letta.llm_api.anthropic import anthropic_chat_completions_request
from letta.llm_api.azure_openai import azure_openai_chat_completions_request
from letta.llm_api.google_ai import (
@@ -80,7 +80,7 @@ def wrapper(*args, **kwargs):

# Check if max retries has been reached
if num_retries > max_retries:
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
raise RateLimitExceededError("Maximum number of retries exceeded", max_retries=max_retries)

# Increment the delay
delay *= exponential_base * (1 + jitter * random.random())
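For illustration (not part of the commit), the retry decorator surrounding this change is not shown in full, so here is a minimal reconstruction of the pattern. The parameter defaults and the RateLimitError stand-in are assumptions; the real wrapper in llm_api_tools.py decides which provider exceptions count as rate limiting.

import random
import time
from functools import wraps

from letta.errors import RateLimitExceededError


class RateLimitError(Exception):
    """Hypothetical stand-in for a provider's rate-limit exception (e.g. HTTP 429)."""


def retry_with_exponential_backoff(func, initial_delay=1.0, exponential_base=2.0, jitter=True, max_retries=3):
    @wraps(func)
    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay
        while True:
            try:
                return func(*args, **kwargs)
            except RateLimitError:
                # Check if max retries has been reached; raise the typed error instead of a bare Exception
                num_retries += 1
                if num_retries > max_retries:
                    raise RateLimitExceededError("Maximum number of retries exceeded", max_retries=max_retries)
                # Increment the delay and sleep before the next attempt
                delay *= exponential_base * (1 + jitter * random.random())
                time.sleep(delay)
    return wrapper
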
45 changes: 23 additions & 22 deletions letta/server/rest_api/utils.py
@@ -8,6 +8,7 @@
from fastapi import Header
from pydantic import BaseModel

from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.schemas.usage import LettaUsageStatistics
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.server import SyncServer
@@ -61,34 +62,21 @@ async def sse_async_generator(
if not isinstance(usage, LettaUsageStatistics):
raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
yield sse_formatter({"usage": usage.model_dump()})
except Exception as e:
import traceback

traceback.print_exc()
warnings.warn(f"SSE stream generator failed: {e}")

# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
# Print the stack trace
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk

sentry_sdk.capture_exception(e)

except ContextWindowExceededError as e:
log_error_to_sentry(e)
yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})

except RateLimitExceededError as e:
log_error_to_sentry(e)
yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})

except Exception as e:
log_error_to_sentry(e)
yield sse_formatter({"error": "Stream failed (internal error occurred)"})

except Exception as e:
import traceback

traceback.print_exc()
warnings.warn(f"SSE stream generator failed: {e}")

# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
# Print the stack trace
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk

sentry_sdk.capture_exception(e)

log_error_to_sentry(e)
yield sse_formatter({"error": "Stream failed (decoder encountered an error)"})

finally:
@@ -113,3 +101,16 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optional[str]:

def get_current_interface() -> StreamingServerInterface:
return StreamingServerInterface

def log_error_to_sentry(e):
import traceback

traceback.print_exc()
warnings.warn(f"SSE stream generator failed: {e}")

# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
# Print the stack trace
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk

sentry_sdk.capture_exception(e)

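For illustration (not part of the commit), these changes mean a rate-limit or context-window failure now reaches SSE clients as a JSON error event carrying a machine-readable code rather than only a generic message. Assuming the conventional "data: <json>" framing from sse_formatter and a hypothetical streaming endpoint, a client could branch on the code like this:

import json

import requests

# Hypothetical URL and payload; only the error-event shape is taken from the diff
resp = requests.post(
    "http://localhost:8283/v1/agents/my-agent/messages/stream",
    json={"messages": [{"role": "user", "text": "hello"}]},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # skip blank keep-alive lines between events
    event = json.loads(line[len("data: "):])  # assumes every data event is JSON
    if "error" not in event:
        continue  # normal message/usage chunks
    if event.get("code") == "RATE_LIMIT_EXCEEDED":
        print("Throttled by the LLM provider; back off and retry later")
    elif event.get("code") == "CONTEXT_WINDOW_EXCEEDED":
        print("Context window exhausted and summarization failed; trim history")
    else:
        print("Stream failed:", event["error"])
    break
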