From b326ed140655d2149de6aa4d31a0c6da1f62323c Mon Sep 17 00:00:00 2001
From: Caren Thomas
Date: Wed, 18 Dec 2024 17:50:54 -0800
Subject: [PATCH 1/4] remove incremental done steps

---
 letta/client/streaming.py                  |  6 ++--
 letta/schemas/enums.py                     |  6 ----
 letta/schemas/letta_response.py            |  8 ++---
 letta/server/rest_api/interface.py         | 16 ++-------
 letta/server/rest_api/routers/v1/agents.py |  8 ++---
 tests/test_client_legacy.py                | 41 +++++++++-------------
 6 files changed, 30 insertions(+), 55 deletions(-)

diff --git a/letta/client/streaming.py b/letta/client/streaming.py
index 4a258cdc7a..ef043ea928 100644
--- a/letta/client/streaming.py
+++ b/letta/client/streaming.py
@@ -6,7 +6,7 @@
 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
 from letta.errors import LLMError
-from letta.schemas.enums import MessageStreamStatus
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.schemas.letta_message import (
     ToolCallMessage,
     ToolReturnMessage,
@@ -47,10 +47,10 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
             # if sse.data == OPENAI_SSE_DONE:
             #     print("finished")
             #     break
-            if sse.data in [status.value for status in MessageStreamStatus]:
+            if sse.data == OPENAI_SSE_DONE:
                 # break
                 # print("sse.data::", sse.data)
-                yield MessageStreamStatus(sse.data)
+                yield sse.data
             else:
                 chunk_data = json.loads(sse.data)
                 if "reasoning" in chunk_data:

diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py
index 6183033f54..04d69c5bb8 100644
--- a/letta/schemas/enums.py
+++ b/letta/schemas/enums.py
@@ -29,12 +29,6 @@ class JobStatus(str, Enum):
     pending = "pending"
 
 
-class MessageStreamStatus(str, Enum):
-    done_generation = "[DONE_GEN]"
-    done_step = "[DONE_STEP]"
-    done = "[DONE]"
-
-
 class ToolRuleType(str, Enum):
     """
     Type of tool rule.

diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py
index c6a1e8be58..30d53643f6 100644
--- a/letta/schemas/letta_response.py
+++ b/letta/schemas/letta_response.py
@@ -1,17 +1,17 @@
 import html
 import json
 import re
-from typing import List, Union
+from typing import List, Literal, Union
 
 from pydantic import BaseModel, Field
 
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
 from letta.schemas.usage import LettaUsageStatistics
 from letta.utils import json_dumps
 
 # TODO: consider moving into own file
+StreamDoneStatus = Literal["[DONE]"]
 
 class LettaResponse(BaseModel):
     """
@@ -152,5 +152,5 @@ def format_json(json_str):
     return html_output
 
 
-# The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage
-LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics]
+# The streaming response is either [DONE], an error, or a LettaMessage
+LettaStreamingResponse = Union[LettaMessage, StreamDoneStatus, LettaUsageStatistics]

diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py
index b3d344798b..96ee50d6a9 100644
--- a/letta/server/rest_api/interface.py
+++ b/letta/server/rest_api/interface.py
@@ -6,10 +6,10 @@
 from datetime import datetime
 from typing import AsyncGenerator, Literal, Optional, Union
 
+from letta.schemas.letta_response import StreamDoneStatus
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.interface import AgentInterface
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import (
     AssistantMessage,
     ToolCall,
@@ -295,8 +295,6 @@ def __init__(
         # if multi_step = True, the stream ends when the agent yields
         # if multi_step = False, the stream ends when the step ends
         self.multi_step = multi_step
-        self.multi_step_indicator = MessageStreamStatus.done_step
-        self.multi_step_gen_indicator = MessageStreamStatus.done_generation
 
         # Support for AssistantMessage
         self.use_assistant_message = False  # TODO: Remove this
@@ -325,7 +323,7 @@ def _reset_inner_thoughts_json_reader(self):
         self.function_args_buffer = None
         self.function_id_buffer = None
 
-    async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
+    async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, StreamDoneStatus], None]:
         """An asynchronous generator that yields chunks as they become available."""
         while self._active:
             try:
@@ -350,8 +348,6 @@ def get_generator(self) -> AsyncGenerator:
     def _push_to_buffer(
         self,
         item: Union[
-            # signal on SSE stream status [DONE_GEN], [DONE_STEP], [DONE]
-            MessageStreamStatus,
             # the non-streaming message types
             LettaMessage,
             LegacyLettaMessage,
@@ -362,7 +358,7 @@ def _push_to_buffer(
         """Add an item to the deque"""
         assert self._active, "Generator is inactive"
         assert (
-            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or isinstance(item, MessageStreamStatus)
+            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage)
         ), f"Wrong type: {type(item)}"
 
         self._chunks.append(item)
@@ -381,9 +377,6 @@ def stream_end(self):
         """Clean up the stream by deactivating and clearing chunks."""
         self.streaming_chat_completion_mode_function_name = None
 
-        if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
-            self._push_to_buffer(self.multi_step_gen_indicator)
-
         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()
 
@@ -393,9 +386,6 @@ def step_complete(self):
             # end the stream
             self._active = False
             self._event.set()  # Unblock the generator if it's waiting to allow it to complete
-        elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
-            # signal that a new step has started in the stream
-            self._push_to_buffer(self.multi_step_indicator)
 
         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()

diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index 69b97c764e..c75a8777cf 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -18,6 +18,7 @@
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.log import get_logger
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.orm.errors import NoResultFound
 from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
 from letta.schemas.block import (  # , BlockLabelUpdate, BlockLimitUpdate
@@ -25,7 +26,6 @@
     BlockUpdate,
     CreateBlock,
 )
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.job import Job, JobStatus, JobUpdate
 from letta.schemas.letta_message import (
     LegacyLettaMessage,
@@ -732,14 +732,14 @@ async def send_message_to_agent(
             generated_stream = []
             async for message in streaming_interface.get_generator():
                 assert (
-                    isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or isinstance(message, MessageStreamStatus)
+                    isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or message == OPENAI_SSE_DONE
                 ), type(message)
                 generated_stream.append(message)
-                if message == MessageStreamStatus.done:
+                if message == OPENAI_SSE_DONE:
                     break
 
             # Get rid of the stream status messages
-            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
+            filtered_stream = [d for d in generated_stream if d != OPENAI_SSE_DONE]
             usage = await task
 
             # By default the stream will be messages of type LettaMessage or LettaLegacyMessage

diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py
index 16dc1cd6b4..f7d9e24200 100644
--- a/tests/test_client_legacy.py
+++ b/tests/test_client_legacy.py
@@ -12,10 +12,11 @@
 from letta import create_client
 from letta.client.client import LocalClient, RESTClient
 from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.orm import FileMetadata, Source
 from letta.schemas.agent import AgentState
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import MessageRole, MessageStreamStatus
+from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message import (
     AssistantMessage,
     ToolCallMessage,
@@ -245,45 +246,35 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
     inner_thoughts_count = 0
     # 2. Check that the agent runs `send_message`
     send_message_ran = False
-    # 3. Check that we get all the start/stop/end tokens we want
-    # This includes all of the MessageStreamStatus enums
-    done_gen = False
-    done_step = False
+    # 3. Check that we get the end token we want (StreamDoneStatus)
     done = False
-    # print(response)
+    print(response)
 
     assert response, "Sending message failed"
     for chunk in response:
-        assert isinstance(chunk, LettaStreamingResponse)
-        if isinstance(chunk, ReasoningMessage) and chunk.reasoning and chunk.reasoning != "":
-            inner_thoughts_exist = True
-            inner_thoughts_count += 1
-        if isinstance(chunk, ToolCallMessage) and chunk.tool_call and chunk.tool_call.name == "send_message":
-            send_message_ran = True
-        if isinstance(chunk, MessageStreamStatus):
-            if chunk == MessageStreamStatus.done:
-                assert not done, "Message stream already done"
-                done = True
-            elif chunk == MessageStreamStatus.done_step:
-                assert not done_step, "Message stream already done step"
-                done_step = True
-            elif chunk == MessageStreamStatus.done_generation:
-                assert not done_gen, "Message stream already done generation"
-                done_gen = True
-        if isinstance(chunk, LettaUsageStatistics):
+        if isinstance(chunk, LettaMessage):
+            if isinstance(chunk, ReasoningMessage) and chunk.reasoning and chunk.reasoning != "":
+                inner_thoughts_exist = True
+                inner_thoughts_count += 1
+            if isinstance(chunk, ToolReturnMessage) and chunk.tool_call and chunk.tool_call.name == "send_message":
+                send_message_ran = True
+        elif chunk == OPENAI_SSE_DONE:
+            assert not done, "Message stream already done"
+            done = True
+        elif isinstance(chunk, LettaUsageStatistics):
             # Some rough metrics for a reasonable usage pattern
             assert chunk.step_count == 1
             assert chunk.completion_tokens > 10
             assert chunk.prompt_tokens > 1000
             assert chunk.total_tokens > 1000
+        else:
+            assert isinstance(chunk, LettaStreamingResponse)
 
     # If stream tokens, we expect at least one inner thought
     assert inner_thoughts_count >= 1, "Expected more than one inner thought"
     assert inner_thoughts_exist, "No inner thoughts found"
     assert send_message_ran, "send_message function call not found"
     assert done, "Message stream not done"
-    assert done_step, "Message stream not done step"
-    assert done_gen, "Message stream not done generation"
 
 
 def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):
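
Note: after this patch, the client stream in letta/client/streaming.py yields the
raw "[DONE]" string (OPENAI_SSE_DONE) as the only terminal sentinel; the
incremental [DONE_GEN] and [DONE_STEP] signals no longer exist. A minimal sketch
of consuming the new stream shape; the consume_stream helper and the inlined
sentinel value are illustrative, not part of the patch:

    from typing import Iterable, List

    OPENAI_SSE_DONE = "[DONE]"  # mirrors letta.llm_api.openai.OPENAI_SSE_DONE

    def consume_stream(stream: Iterable) -> List:
        """Collect streamed chunks until the single terminal sentinel."""
        collected = []
        for chunk in stream:
            # No MessageStreamStatus.done_step / done_generation markers
            # arrive between agent steps anymore; "[DONE]" ends everything.
            if chunk == OPENAI_SSE_DONE:
                break
            collected.append(chunk)
        return collected
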
From c1309cfbbfe8e763e9064ecb4320300937e4f723 Mon Sep 17 00:00:00 2001
From: Caren Thomas
Date: Fri, 20 Dec 2024 10:19:12 -0800
Subject: [PATCH 2/4] comment out steps instead of deleting

---
 letta/server/rest_api/interface.py | 14 +++++++++++++-
 tests/test_client_legacy.py        |  4 ++--
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py
index 96ee50d6a9..dfe4651ff8 100644
--- a/letta/server/rest_api/interface.py
+++ b/letta/server/rest_api/interface.py
@@ -6,6 +6,7 @@
 from datetime import datetime
 from typing import AsyncGenerator, Literal, Optional, Union
 
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.schemas.letta_response import StreamDoneStatus
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.interface import AgentInterface
@@ -295,6 +296,8 @@ def __init__(
         # if multi_step = True, the stream ends when the agent yields
         # if multi_step = False, the stream ends when the step ends
         self.multi_step = multi_step
+        # self.multi_step_indicator = MessageStreamStatus.done_step
+        # self.multi_step_gen_indicator = MessageStreamStatus.done_generation
 
         # Support for AssistantMessage
         self.use_assistant_message = False  # TODO: Remove this
@@ -348,6 +351,9 @@ def get_generator(self) -> AsyncGenerator:
     def _push_to_buffer(
         self,
         item: Union[
+            # signal on SSE stream status [DONE_GEN], [DONE_STEP], [DONE]
+            # MessageStreamStatus,
+            StreamDoneStatus,
             # the non-streaming message types
             LettaMessage,
             LegacyLettaMessage,
@@ -358,7 +364,7 @@ def _push_to_buffer(
         """Add an item to the deque"""
         assert self._active, "Generator is inactive"
         assert (
-            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage)
+            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or item == OPENAI_SSE_DONE
         ), f"Wrong type: {type(item)}"
 
         self._chunks.append(item)
@@ -377,6 +383,9 @@ def stream_end(self):
         """Clean up the stream by deactivating and clearing chunks."""
         self.streaming_chat_completion_mode_function_name = None
 
+        # if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
+        #     self._push_to_buffer(self.multi_step_gen_indicator)
+
         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()
 
@@ -386,6 +395,9 @@ def step_complete(self):
             # end the stream
             self._active = False
             self._event.set()  # Unblock the generator if it's waiting to allow it to complete
+        # elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
+        #     signal that a new step has started in the stream
+        #     self._push_to_buffer(self.multi_step_indicator)
 
         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()

diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py
index f7d9e24200..4bfe4f49e2 100644
--- a/tests/test_client_legacy.py
+++ b/tests/test_client_legacy.py
@@ -249,14 +249,14 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
     # 3. Check that we get the end token we want (StreamDoneStatus)
     done = False
-    print(response)
+    # print(response)
 
     assert response, "Sending message failed"
     for chunk in response:
         if isinstance(chunk, LettaMessage):
             if isinstance(chunk, ReasoningMessage) and chunk.reasoning and chunk.reasoning != "":
                 inner_thoughts_exist = True
                 inner_thoughts_count += 1
-            if isinstance(chunk, ToolReturnMessage) and chunk.tool_call and chunk.tool_call.name == "send_message":
+            if isinstance(chunk, ToolCallMessage) and chunk.tool_call and chunk.tool_call.name == "send_message":
                 send_message_ran = True
         elif chunk == OPENAI_SSE_DONE:
             assert not done, "Message stream already done"
             done = True
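
Note: the widened assert in _push_to_buffer compares by value (item ==
OPENAI_SSE_DONE) rather than by type because StreamDoneStatus is a
typing.Literal, and subscripted generics cannot be used with isinstance().
A standalone illustration of that constraint (not code from the patch):

    from typing import Literal

    StreamDoneStatus = Literal["[DONE]"]
    OPENAI_SSE_DONE = "[DONE]"

    def is_done_sentinel(item: object) -> bool:
        # isinstance(item, StreamDoneStatus) raises TypeError at runtime,
        # so the guard falls back to string equality against the sentinel.
        return item == OPENAI_SSE_DONE

    assert is_done_sentinel("[DONE]")
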
From 8e9dc859e031e0923a448625923c33bb35faab7d Mon Sep 17 00:00:00 2001
From: Caren Thomas
Date: Fri, 20 Dec 2024 12:04:25 -0800
Subject: [PATCH 3/4] wrap usage stats with new UsageMessage

---
 letta/agent.py                  | 12 ++++++++----
 letta/chat_only_agent.py        |  4 ++--
 letta/client/streaming.py       |  4 ++--
 letta/o1_agent.py               | 10 ++++++++--
 letta/schemas/letta_message.py  | 14 ++++++++++++++
 letta/schemas/letta_response.py | 13 ++++++------
 letta/server/server.py          | 16 ++++++++--------
 tests/test_client_legacy.py     | 14 +++++++-------
 8 files changed, 55 insertions(+), 32 deletions(-)

diff --git a/letta/agent.py b/letta/agent.py
index 82958acda5..9e19a3a9fb 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -33,6 +33,7 @@
 from letta.schemas.block import BlockUpdate
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageRole
+from letta.schemas.letta_message import UsageMessage
 from letta.schemas.memory import ContextWindowOverview, Memory
 from letta.schemas.message import Message, MessageUpdate
 from letta.schemas.openai.chat_completion_request import (
@@ -232,7 +233,7 @@ class BaseAgent(ABC):
     def step(
         self,
         messages: Union[Message, List[Message]],
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """
         Top-level event message handler for the agent.
         """
@@ -917,7 +918,7 @@ def step(
         chaining: bool = True,
         max_chaining_steps: Optional[int] = None,
         **kwargs,
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""
         next_input_message = messages if isinstance(messages, list) else [messages]
         counter = 0
@@ -991,8 +992,11 @@ def step(
             # Letta no-op / yield
             else:
                 break
-
-        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
+        return UsageMessage(
+            id="null",
+            date=get_utc_time(),
+            usage=LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
+        )
 
     def inner_step(
         self,

diff --git a/letta/chat_only_agent.py b/letta/chat_only_agent.py
index e340673eba..d0f755eef4 100644
--- a/letta/chat_only_agent.py
+++ b/letta/chat_only_agent.py
@@ -6,10 +6,10 @@
 from letta.prompts import gpt_system
 from letta.schemas.agent import AgentState, AgentType
 from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.letta_message import UsageMessage
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import BasicBlockMemory, Block
 from letta.schemas.message import Message
-from letta.schemas.usage import LettaUsageStatistics
 from letta.schemas.user import User
 from letta.utils import get_persona_text
@@ -36,7 +36,7 @@ def step(
         chaining: bool = True,
         max_chaining_steps: Optional[int] = None,
         **kwargs,
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         letta_statistics = super().step(messages=messages, chaining=chaining, max_chaining_steps=max_chaining_steps, **kwargs)
 
         if self.always_rethink_memory:

diff --git a/letta/client/streaming.py b/letta/client/streaming.py
index ef043ea928..bfd33a7f24 100644
--- a/letta/client/streaming.py
+++ b/letta/client/streaming.py
@@ -11,9 +11,9 @@
     ToolCallMessage,
     ToolReturnMessage,
     ReasoningMessage,
+    UsageMessage,
 )
 from letta.schemas.letta_response import LettaStreamingResponse
-from letta.schemas.usage import LettaUsageStatistics
 
 
 def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
@@ -60,7 +60,7 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
                 elif "tool_return" in chunk_data:
                     yield ToolReturnMessage(**chunk_data)
                 elif "usage" in chunk_data:
-                    yield LettaUsageStatistics(**chunk_data["usage"])
+                    yield UsageMessage(**chunk_data)
                 else:
                     raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")

diff --git a/letta/o1_agent.py b/letta/o1_agent.py
index 285ed966fa..e77cf50ce0 100644
--- a/letta/o1_agent.py
+++ b/letta/o1_agent.py
@@ -3,10 +3,12 @@
 from letta.agent import Agent, save_agent
 from letta.interface import AgentInterface
 from letta.schemas.agent import AgentState
+from letta.schemas.letta_message import UsageMessage
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_response import UsageStatistics
 from letta.schemas.usage import LettaUsageStatistics
 from letta.schemas.user import User
+from letta.utils import get_utc_time
 
 
 def send_thinking_message(self: "Agent", message: str) -> Optional[str]:
@@ -56,7 +58,7 @@ def step(
         chaining: bool = True,
         max_chaining_steps: Optional[int] = None,
         **kwargs,
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """Run Agent.inner_step in a loop, terminate when final thinking message is sent or max_thinking_steps is reached"""
         # assert ms is not None, "MetadataStore is required"
         next_input_message = messages if isinstance(messages, list) else [messages]
@@ -83,4 +85,8 @@ def step(
                 break
 
         save_agent(self)
-        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
+        return UsageMessage(
+            id="null",
+            date=get_utc_time(),
+            usage=LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
+        )

diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py
index 45fcf36180..7afc9a646d 100644
--- a/letta/schemas/letta_message.py
+++ b/letta/schemas/letta_message.py
@@ -4,6 +4,8 @@
 from pydantic import BaseModel, Field, field_serializer, field_validator
 
+from letta.schemas.usage import LettaUsageStatistics
+
 # Letta API style responses (intended to be easier to use vs getting true Message types)
@@ -162,6 +164,18 @@ class ToolReturnMessage(LettaMessage):
     stderr: Optional[List[str]] = None
 
 
+class UsageMessage(LettaMessage):
+    """
+    A message representing the usage statistics for the agent interaction.
+
+    Attributes:
+        usage (LettaUsageStatistics): Usage statistics for the agent interaction.
+    """
+
+    message_type: Literal["usage_message"] = "usage_message"
+    usage: LettaUsageStatistics
+
+
 # Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string

diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py
index 30d53643f6..013117ff2b 100644
--- a/letta/schemas/letta_response.py
+++ b/letta/schemas/letta_response.py
@@ -5,8 +5,7 @@
 from pydantic import BaseModel, Field
 
-from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
-from letta.schemas.usage import LettaUsageStatistics
+from letta.schemas.letta_message import LettaMessage, LettaMessageUnion, UsageMessage
 from letta.utils import json_dumps
 
 # TODO: consider moving into own file
@@ -24,14 +23,14 @@ class LettaResponse(BaseModel):
     """
 
     messages: List[LettaMessageUnion] = Field(..., description="The messages returned by the agent.")
-    usage: LettaUsageStatistics = Field(..., description="The usage statistics of the agent.")
+    usage: UsageMessage = Field(..., description="The usage statistics of the agent.")
 
     def __str__(self):
         return json_dumps(
             {
                 "messages": [message.model_dump() for message in self.messages],
-                # Assume `Message` and `LettaMessage` have a `dict()` method
-                "usage": self.usage.model_dump(),  # Assume `LettaUsageStatistics` has a `dict()` method
+                # Assume `Message` and `UsageMessage` have a `dict()` method
+                "usage": self.usage.model_dump(),  # Assume `UsageMessage` has a `dict()` method
             },
             indent=4,
         )
@@ -139,7 +138,7 @@ def format_json(json_str):
         html_output += ""
 
     # Formatting the usage statistics
-    usage_html = json.dumps(self.usage.model_dump(), indent=2)
+    usage_html = json.dumps(self.usage.usage.model_dump(), indent=2)
     html_output += f"""
@@ -153,4 +152,4 @@ def format_json(json_str):
 
 
 # The streaming response is either [DONE], an error, or a LettaMessage
-LettaStreamingResponse = Union[LettaMessage, StreamDoneStatus, LettaUsageStatistics]
+LettaStreamingResponse = Union[LettaMessage, StreamDoneStatus]

diff --git a/letta/server/server.py b/letta/server/server.py
index 617283345a..7abaccf5dd 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -47,7 +47,7 @@
 # openai schemas
 from letta.schemas.enums import JobStatus
 from letta.schemas.job import Job, JobUpdate
-from letta.schemas.letta_message import ToolReturnMessage, LettaMessage
+from letta.schemas.letta_message import ToolReturnMessage, LettaMessage, UsageMessage
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import (
     ArchivalMemorySummary,
@@ -427,7 +427,7 @@ def _step(
         input_messages: Union[Message, List[Message]],
         interface: Union[AgentInterface, None] = None,  # needed to getting responses
         # timestamp: Optional[datetime],
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """Send the input message through the agent"""
         # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
         # Input validation
@@ -469,7 +469,7 @@ def _step(
 
         return usage_stats
 
-    def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
+    def _command(self, user_id: str, agent_id: str, command: str) -> UsageMessage:
         """Process a CLI command"""
         # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
         actor = self.user_manager.get_user_or_default(user_id=user_id)
@@ -583,7 +583,7 @@ def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStati
             usage = self._step(actor=actor, agent_id=agent_id, input_message=input_message)
 
         if not usage:
-            usage = LettaUsageStatistics()
+            usage = UsageMessage(id="null", date=get_utc_time(), usage=LettaUsageStatistics())
 
         return usage
 
@@ -593,7 +593,7 @@ def user_message(
         agent_id: str,
         message: Union[str, Message],
         timestamp: Optional[datetime] = None,
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """Process an incoming user message and feed it through the Letta agent"""
         try:
             actor = self.user_manager.get_user_by_id(user_id=user_id)
@@ -645,7 +645,7 @@ def system_message(
         agent_id: str,
         message: Union[str, Message],
         timestamp: Optional[datetime] = None,
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """Process an incoming system message and feed it through the Letta agent"""
         try:
             actor = self.user_manager.get_user_by_id(user_id=user_id)
@@ -712,7 +712,7 @@ def send_messages(
         wrap_user_message: bool = True,
         wrap_system_message: bool = True,
         interface: Union[AgentInterface, None] = None,  # needed to getting responses
-    ) -> LettaUsageStatistics:
+    ) -> UsageMessage:
         """Send a list of messages to the agent
 
         If the messages are of type MessageCreate, we need to turn them into
@@ -761,7 +761,7 @@ def send_messages(
         return self._step(actor=actor, agent_id=agent_id, input_messages=message_objects, interface=interface)
 
     # @LockingServer.agent_lock_decorator
-    def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
+    def run_command(self, user_id: str, agent_id: str, command: str) -> UsageMessage:
         """Run a command on the agent"""
         # If the input begins with a command prefix, attempt to process it as a command
         if command.startswith("/"):

diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py
index 4bfe4f49e2..0e047d9030 100644
--- a/tests/test_client_legacy.py
+++ b/tests/test_client_legacy.py
@@ -24,12 +24,12 @@
     ReasoningMessage,
     LettaMessage,
     SystemMessage,
+    UsageMessage,
     UserMessage,
 )
 from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import MessageCreate
-from letta.schemas.usage import LettaUsageStatistics
 from letta.services.organization_manager import OrganizationManager
 from letta.services.user_manager import UserManager
 from letta.settings import model_settings
@@ -258,15 +258,15 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
                 inner_thoughts_count += 1
             if isinstance(chunk, ToolCallMessage) and chunk.tool_call and chunk.tool_call.name == "send_message":
                 send_message_ran = True
+        if isinstance(chunk, UsageMessage):
+            # Some rough metrics for a reasonable usage pattern
+            assert chunk.step_count == 1
+            assert chunk.completion_tokens > 10
+            assert chunk.prompt_tokens > 1000
+            assert chunk.total_tokens > 1000
         elif chunk == OPENAI_SSE_DONE:
             assert not done, "Message stream already done"
             done = True
-        elif isinstance(chunk, LettaUsageStatistics):
-            # Some rough metrics for a reasonable usage pattern
-            assert chunk.step_count == 1
-            assert chunk.completion_tokens > 10
-            assert chunk.prompt_tokens > 1000
-            assert chunk.total_tokens > 1000
         else:
             assert isinstance(chunk, LettaStreamingResponse)
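
Note: UsageMessage subclasses LettaMessage, so it carries the id and date
fields of any other message type and nests the statistics one level down
under .usage. A construction sketch using only fields visible in this patch;
the token counts are placeholder values:

    from letta.schemas.letta_message import UsageMessage
    from letta.schemas.usage import LettaUsageStatistics
    from letta.utils import get_utc_time

    msg = UsageMessage(
        id="null",  # the agent loop currently fills in a placeholder id
        date=get_utc_time(),
        usage=LettaUsageStatistics(
            step_count=1,
            completion_tokens=20,
            prompt_tokens=1500,
            total_tokens=1520,
        ),
    )
    assert msg.message_type == "usage_message"
    print(msg.usage.total_tokens)  # stats are read through .usage
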
From 285aad482eb12f4f7e4d8cb4dd0b47ac4f570466 Mon Sep 17 00:00:00 2001
From: Caren Thomas
Date: Fri, 20 Dec 2024 14:12:10 -0800
Subject: [PATCH 4/4] fix callsites

---
 letta/server/rest_api/utils.py | 7 ++++---
 tests/test_client_legacy.py    | 8 ++++----
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py
index 64d46a5d3f..35047ff698 100644
--- a/letta/server/rest_api/utils.py
+++ b/letta/server/rest_api/utils.py
@@ -9,6 +9,7 @@
 from pydantic import BaseModel
 
 from letta.errors import ContextWindowExceededError, RateLimitExceededError
+from letta.schemas.letta_message import UsageMessage
 from letta.schemas.usage import LettaUsageStatistics
 from letta.server.rest_api.interface import StreamingServerInterface
 from letta.server.server import SyncServer
@@ -59,9 +60,9 @@ async def sse_async_generator(
         try:
             usage = await usage_task
             # Double-check the type
-            if not isinstance(usage, LettaUsageStatistics):
-                raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
-            yield sse_formatter({"usage": usage.model_dump()})
+            if not isinstance(usage, UsageMessage):
+                raise ValueError(f"Expected UsageMessage, got {type(usage)}")
+            yield sse_formatter(usage.model_dump())
 
         except ContextWindowExceededError as e:
             log_error_to_sentry(e)

diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py
index 0e047d9030..00b8a4e58b 100644
--- a/tests/test_client_legacy.py
+++ b/tests/test_client_legacy.py
@@ -260,10 +260,10 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
             send_message_ran = True
         if isinstance(chunk, UsageMessage):
             # Some rough metrics for a reasonable usage pattern
-            assert chunk.step_count == 1
-            assert chunk.completion_tokens > 10
-            assert chunk.prompt_tokens > 1000
-            assert chunk.total_tokens > 1000
+            assert chunk.usage.step_count == 1
+            assert chunk.usage.completion_tokens > 10
+            assert chunk.usage.prompt_tokens > 1000
+            assert chunk.usage.total_tokens > 1000
         elif chunk == OPENAI_SSE_DONE:
             assert not done, "Message stream already done"
             done = True
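
Note: taken together, the series also changes the shape of the usage event on
the SSE wire: sse_async_generator now emits the full UsageMessage dump instead
of a bare {"usage": ...} object, and the client branch in
letta/client/streaming.py rebuilds the wrapper with UsageMessage(**chunk_data).
A sketch of the resulting event, assuming sse_formatter JSON-encodes the dict
into a data: line; the field values are placeholders:

    import json

    usage_event = {
        "message_type": "usage_message",
        "id": "null",
        "date": "2024-12-20T22:00:00+00:00",
        "usage": {
            "step_count": 1,
            "completion_tokens": 20,
            "prompt_tokens": 1500,
            "total_tokens": 1520,
        },
    }
    # The client still keys off the "usage" field to pick this branch.
    assert "usage" in usage_event
    print(f"data: {json.dumps(usage_event)}")
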