feat: create new usage message type #2294

Status: Closed · wants to merge 4 commits
12 changes: 8 additions & 4 deletions letta/agent.py
@@ -33,6 +33,7 @@
 from letta.schemas.block import BlockUpdate
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageRole
+from letta.schemas.letta_message import UsageMessage
 from letta.schemas.memory import ContextWindowOverview, Memory
 from letta.schemas.message import Message, MessageUpdate
 from letta.schemas.openai.chat_completion_request import (
@@ -232,7 +233,7 @@ class BaseAgent(ABC):
     def step(
         self,
         messages: Union[Message, List[Message]],
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """
        Top-level event message handler for the agent.
        """
@@ -917,7 +918,7 @@ def step(
         chaining: bool = True,
         max_chaining_steps: Optional[int] = None,
         **kwargs,
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""
         next_input_message = messages if isinstance(messages, list) else [messages]
         counter = 0
@@ -991,8 +992,11 @@ def step(
             # Letta no-op / yield
             else:
                 break

-        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
+        return UsageMessage(
+            id="null",
+            date=get_utc_time(),
+            usage=LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
+        )

     def inner_step(
         self,
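Reviewer note: with this hunk, Agent.step returns a UsageMessage envelope rather than a bare LettaUsageStatistics, so callers reach the counters one level down. A minimal caller-side sketch (not part of this diff; it assumes LettaUsageStatistics keeps its step_count/total_tokens fields):

from letta.schemas.letta_message import UsageMessage

def summarize_step(result: UsageMessage) -> str:
    # Before this PR: result was a LettaUsageStatistics, i.e. result.total_tokens.
    # After this PR: the statistics sit under the message envelope, result.usage.
    stats = result.usage
    return f"steps={stats.step_count}, total_tokens={stats.total_tokens}"
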
4 changes: 2 additions & 2 deletions letta/chat_only_agent.py
@@ -6,10 +6,10 @@
 from letta.prompts import gpt_system
 from letta.schemas.agent import AgentState, AgentType
 from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.letta_message import UsageMessage
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import BasicBlockMemory, Block
 from letta.schemas.message import Message
-from letta.schemas.usage import LettaUsageStatistics
 from letta.schemas.user import User
 from letta.utils import get_persona_text

@@ -36,7 +36,7 @@ def step(
         chaining: bool = True,
         max_chaining_steps: Optional[int] = None,
         **kwargs,
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         letta_statistics = super().step(messages=messages, chaining=chaining, max_chaining_steps=max_chaining_steps, **kwargs)

         if self.always_rethink_memory:
10 changes: 5 additions & 5 deletions letta/client/streaming.py
@@ -6,14 +6,14 @@

 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
 from letta.errors import LLMError
-from letta.schemas.enums import MessageStreamStatus
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.schemas.letta_message import (
     ToolCallMessage,
     ToolReturnMessage,
     ReasoningMessage,
+    UsageMessage,
 )
 from letta.schemas.letta_response import LettaStreamingResponse
-from letta.schemas.usage import LettaUsageStatistics


 def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
@@ -47,10 +47,10 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
             # if sse.data == OPENAI_SSE_DONE:
             #     print("finished")
             #     break
-            if sse.data in [status.value for status in MessageStreamStatus]:
+            if sse.data == OPENAI_SSE_DONE:
                 # break
                 # print("sse.data::", sse.data)
-                yield MessageStreamStatus(sse.data)
+                yield sse.data
             else:
                 chunk_data = json.loads(sse.data)
                 if "reasoning" in chunk_data:
@@ -60,7 +60,7 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
             elif "tool_return" in chunk_data:
                 yield ToolReturnMessage(**chunk_data)
             elif "usage" in chunk_data:
-                yield LettaUsageStatistics(**chunk_data["usage"])
+                yield UsageMessage(**chunk_data)
             else:
                 raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")

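For context, a sketch of how a client consumes this generator after the change: the [DONE] sentinel now arrives as a plain string and usage arrives as a UsageMessage chunk. The endpoint URL and request body below are placeholders, not part of this diff:

from letta.client.streaming import _sse_post
from letta.schemas.letta_message import UsageMessage

url = "http://localhost:8283/v1/agents/agent-id/messages"  # placeholder endpoint
for chunk in _sse_post(url, data={"messages": []}, headers={}):  # placeholder payload
    if chunk == "[DONE]":  # OPENAI_SSE_DONE is now yielded through as a raw string
        break
    if isinstance(chunk, UsageMessage):
        print("usage:", chunk.usage.model_dump())
    else:
        print(chunk.message_type)
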
10 changes: 8 additions & 2 deletions letta/o1_agent.py
@@ -3,10 +3,12 @@
 from letta.agent import Agent, save_agent
 from letta.interface import AgentInterface
 from letta.schemas.agent import AgentState
+from letta.schemas.letta_message import UsageMessage
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import UsageStatistics
 from letta.schemas.usage import LettaUsageStatistics
 from letta.schemas.user import User
+from letta.utils import get_utc_time


 def send_thinking_message(self: "Agent", message: str) -> Optional[str]:
@@ -56,7 +58,7 @@ def step(
         chaining: bool = True,
         max_chaining_steps: Optional[int] = None,
         **kwargs,
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """Run Agent.inner_step in a loop, terminate when final thinking message is sent or max_thinking_steps is reached"""
         # assert ms is not None, "MetadataStore is required"
         next_input_message = messages if isinstance(messages, list) else [messages]
@@ -83,4 +85,8 @@ def step(
                 break
             save_agent(self)

-        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
+        return UsageMessage(
+            id="null",
+            date=get_utc_time(),
+            usage=LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
+        )
6 changes: 0 additions & 6 deletions letta/schemas/enums.py
@@ -29,12 +29,6 @@ class JobStatus(str, Enum):
     pending = "pending"


-class MessageStreamStatus(str, Enum):
-    done_generation = "[DONE_GEN]"
-    done_step = "[DONE_STEP]"
-    done = "[DONE]"
-
-
 class ToolRuleType(str, Enum):
     """
    Type of tool rule.
14 changes: 14 additions & 0 deletions letta/schemas/letta_message.py
@@ -4,6 +4,8 @@

 from pydantic import BaseModel, Field, field_serializer, field_validator

+from letta.schemas.usage import LettaUsageStatistics
+
 # Letta API style responses (intended to be easier to use vs getting true Message types)

@@ -162,6 +164,18 @@ class ToolReturnMessage(LettaMessage):
     stderr: Optional[List[str]] = None


+class UsageMessage(LettaMessage):
+    """
+    A message representing the usage statistics for the agent interaction.
+
+    Attributes:
+        usage (LettaUsageStatistics): Usage statistics for the agent interaction.
+    """
+
+    message_type: Literal["usage_message"] = "usage_message"
+    usage: LettaUsageStatistics
+
+
 # Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string

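A quick sketch of the new model in isolation. The id and date fields come from the LettaMessage base class, and the statistics values below are made up; it assumes LettaUsageStatistics exposes the usual token counters:

from datetime import datetime, timezone

from letta.schemas.letta_message import UsageMessage
from letta.schemas.usage import LettaUsageStatistics

msg = UsageMessage(
    id="message-00000000",  # made-up id
    date=datetime.now(timezone.utc),
    usage=LettaUsageStatistics(
        completion_tokens=50, prompt_tokens=1200, total_tokens=1250, step_count=2
    ),
)
# Serializes with the new discriminator: "message_type": "usage_message"
print(msg.model_dump_json(indent=2))
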
19 changes: 9 additions & 10 deletions letta/schemas/letta_response.py
@@ -1,17 +1,16 @@
 import html
 import json
 import re
-from typing import List, Union
+from typing import List, Literal, Union

 from pydantic import BaseModel, Field

-from letta.schemas.enums import MessageStreamStatus
-from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
-from letta.schemas.usage import LettaUsageStatistics
+from letta.schemas.letta_message import LettaMessage, LettaMessageUnion, UsageMessage
 from letta.utils import json_dumps

 # TODO: consider moving into own file

+StreamDoneStatus = Literal["[DONE]"]
+
 class LettaResponse(BaseModel):
     """
@@ -24,14 +23,14 @@ class LettaResponse(BaseModel):
     """

     messages: List[LettaMessageUnion] = Field(..., description="The messages returned by the agent.")
-    usage: LettaUsageStatistics = Field(..., description="The usage statistics of the agent.")
+    usage: UsageMessage = Field(..., description="The usage statistics of the agent.")

     def __str__(self):
         return json_dumps(
             {
                 "messages": [message.model_dump() for message in self.messages],
-                # Assume `Message` and `LettaMessage` have a `dict()` method
-                "usage": self.usage.model_dump(),  # Assume `LettaUsageStatistics` has a `dict()` method
+                # Assume `Message` and `UsageMessage` have a `dict()` method
+                "usage": self.usage.model_dump(),  # Assume `UsageMessage` has a `dict()` method
             },
             indent=4,
         )
@@ -139,7 +138,7 @@ def format_json(json_str):
         html_output += "</div>"

         # Formatting the usage statistics
-        usage_html = json.dumps(self.usage.model_dump(), indent=2)
+        usage_html = json.dumps(self.usage.usage.model_dump(), indent=2)
         html_output += f"""
         <div class="usage-container">
             <div class="usage-stats">
@@ -152,5 +151,5 @@ def format_json(json_str):
         return html_output


-# The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage
-LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics]
+# The streaming response is either [DONE], an error, or a LettaMessage
+LettaStreamingResponse = Union[LettaMessage, StreamDoneStatus]
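Because StreamDoneStatus is now a bare Literal["[DONE]"] rather than an enum member, consumers narrow LettaStreamingResponse with an ordinary string comparison; a sketch (helper name is ours, not from the PR):

from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_response import LettaStreamingResponse

def handle(item: LettaStreamingResponse) -> bool:
    """Process one streamed item; return True once the stream is done."""
    if item == "[DONE]":  # matches the StreamDoneStatus literal
        return True
    assert isinstance(item, LettaMessage)  # the only other arm of the Union
    print(item.message_type)
    return False
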
22 changes: 12 additions & 10 deletions letta/server/rest_api/interface.py
@@ -6,10 +6,11 @@
 from datetime import datetime
 from typing import AsyncGenerator, Literal, Optional, Union

+from letta.llm_api.openai import OPENAI_SSE_DONE
+from letta.schemas.letta_response import StreamDoneStatus
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.interface import AgentInterface
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import (
     AssistantMessage,
     ToolCall,
@@ -295,8 +296,8 @@ def __init__(
         # if multi_step = True, the stream ends when the agent yields
         # if multi_step = False, the stream ends when the step ends
         self.multi_step = multi_step
-        self.multi_step_indicator = MessageStreamStatus.done_step
-        self.multi_step_gen_indicator = MessageStreamStatus.done_generation
+        # self.multi_step_indicator = MessageStreamStatus.done_step
+        # self.multi_step_gen_indicator = MessageStreamStatus.done_generation

         # Support for AssistantMessage
         self.use_assistant_message = False  # TODO: Remove this
@@ -325,7 +326,7 @@ def _reset_inner_thoughts_json_reader(self):
         self.function_args_buffer = None
         self.function_id_buffer = None

-    async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
+    async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, StreamDoneStatus], None]:
         """An asynchronous generator that yields chunks as they become available."""
         while self._active:
             try:
@@ -351,7 +352,8 @@ def _push_to_buffer(
         self,
         item: Union[
             # signal on SSE stream status [DONE_GEN], [DONE_STEP], [DONE]
-            MessageStreamStatus,
+            # MessageStreamStatus,
+            StreamDoneStatus,
             # the non-streaming message types
             LettaMessage,
             LegacyLettaMessage,
@@ -362,7 +364,7 @@ def _push_to_buffer(
         """Add an item to the deque"""
         assert self._active, "Generator is inactive"
         assert (
-            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or isinstance(item, MessageStreamStatus)
+            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or item == OPENAI_SSE_DONE
         ), f"Wrong type: {type(item)}"

         self._chunks.append(item)
@@ -381,8 +383,8 @@ def stream_end(self):
         """Clean up the stream by deactivating and clearing chunks."""
         self.streaming_chat_completion_mode_function_name = None

-        if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
-            self._push_to_buffer(self.multi_step_gen_indicator)
+        # if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
+        #     self._push_to_buffer(self.multi_step_gen_indicator)

         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()
@@ -393,9 +395,9 @@ def step_complete(self):
             # end the stream
             self._active = False
             self._event.set()  # Unblock the generator if it's waiting to allow it to complete
-        elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
-            # signal that a new step has started in the stream
-            self._push_to_buffer(self.multi_step_indicator)
+        # elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
+        #     # signal that a new step has started in the stream
+        #     self._push_to_buffer(self.multi_step_indicator)

         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()
8 changes: 4 additions & 4 deletions letta/server/rest_api/routers/v1/agents.py
@@ -18,14 +18,14 @@

 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.log import get_logger
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.orm.errors import NoResultFound
 from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
 from letta.schemas.block import (  # , BlockLabelUpdate, BlockLimitUpdate
     Block,
     BlockUpdate,
     CreateBlock,
 )
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.job import Job, JobStatus, JobUpdate
 from letta.schemas.letta_message import (
     LegacyLettaMessage,
@@ -732,14 +732,14 @@ async def send_message_to_agent(
             generated_stream = []
             async for message in streaming_interface.get_generator():
                 assert (
-                    isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or isinstance(message, MessageStreamStatus)
+                    isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or message == OPENAI_SSE_DONE
                 ), type(message)
                 generated_stream.append(message)
-                if message == MessageStreamStatus.done:
+                if message == OPENAI_SSE_DONE:
                     break

             # Get rid of the stream status messages
-            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
+            filtered_stream = [d for d in generated_stream if d != OPENAI_SSE_DONE]
             usage = await task

             # By default the stream will be messages of type LettaMessage or LettaLegacyMessage
7 changes: 4 additions & 3 deletions letta/server/rest_api/utils.py
@@ -9,6 +9,7 @@
 from pydantic import BaseModel

 from letta.errors import ContextWindowExceededError, RateLimitExceededError
+from letta.schemas.letta_message import UsageMessage
 from letta.schemas.usage import LettaUsageStatistics
 from letta.server.rest_api.interface import StreamingServerInterface
 from letta.server.server import SyncServer
@@ -59,9 +60,9 @@ async def sse_async_generator(
     try:
         usage = await usage_task
         # Double-check the type
-        if not isinstance(usage, LettaUsageStatistics):
-            raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
-        yield sse_formatter({"usage": usage.model_dump()})
+        if not isinstance(usage, UsageMessage):
+            raise ValueError(f"Expected UsageMessage, got {type(usage)}")
+        yield sse_formatter(usage.model_dump())

     except ContextWindowExceededError as e:
         log_error_to_sentry(e)
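Note the wire-format change this hunk implies: the usage SSE event is now the serialized UsageMessage itself, with its message_type discriminator at the top level, instead of a bare {"usage": ...} wrapper. A sketch of the new event body, with illustrative values and assuming the remaining LettaUsageStatistics counters default to zero:

from letta.schemas.letta_message import UsageMessage
from letta.schemas.usage import LettaUsageStatistics
from letta.utils import get_utc_time

usage = UsageMessage(
    id="null",  # mirrors the placeholder id used in agent.py above
    date=get_utc_time(),
    usage=LettaUsageStatistics(total_tokens=1250, step_count=2),
)
body = usage.model_dump()
# Old event body: {"usage": {"total_tokens": ..., ...}}
# New event body carries the discriminator at the top level:
assert body["message_type"] == "usage_message"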