From db96c9d2803b51370baaf681ec49c5845ae9ee86 Mon Sep 17 00:00:00 2001 From: Caren Thomas Date: Fri, 20 Dec 2024 15:54:03 -0800 Subject: [PATCH] add message type literal to usage stats --- letta/client/streaming.py | 4 ++-- letta/schemas/usage.py | 3 ++- letta/server/rest_api/utils.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/letta/client/streaming.py b/letta/client/streaming.py index 4a258cdc7a..a364ada6f0 100644 --- a/letta/client/streaming.py +++ b/letta/client/streaming.py @@ -59,8 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe yield ToolCallMessage(**chunk_data) elif "tool_return" in chunk_data: yield ToolReturnMessage(**chunk_data) - elif "usage" in chunk_data: - yield LettaUsageStatistics(**chunk_data["usage"]) + elif "step_count" in chunk_data: + yield LettaUsageStatistics(**chunk_data) else: raise ValueError(f"Unknown message type in chunk_data: {chunk_data}") diff --git a/letta/schemas/usage.py b/letta/schemas/usage.py index 804d63831d..53cda8b25a 100644 --- a/letta/schemas/usage.py +++ b/letta/schemas/usage.py @@ -1,3 +1,4 @@ +from typing import Literal from pydantic import BaseModel, Field @@ -11,7 +12,7 @@ class LettaUsageStatistics(BaseModel): total_tokens (int): The total number of tokens processed by the agent. step_count (int): The number of steps taken by the agent. """ - + message_type: Literal["usage_statistics"] = "usage_statistics" completion_tokens: int = Field(0, description="The number of tokens generated by the agent.") prompt_tokens: int = Field(0, description="The number of tokens in the prompt.") total_tokens: int = Field(0, description="The total number of tokens processed by the agent.") diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 64d46a5d3f..86a8899043 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -61,7 +61,7 @@ async def sse_async_generator( # Double-check the type if not isinstance(usage, LettaUsageStatistics): raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}") - yield sse_formatter({"usage": usage.model_dump()}) + yield sse_formatter(usage.model_dump()) except ContextWindowExceededError as e: log_error_to_sentry(e)