diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 64d46a5d3f..35047ff698 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from letta.errors import ContextWindowExceededError, RateLimitExceededError +from letta.schemas.letta_message import UsageMessage from letta.schemas.usage import LettaUsageStatistics from letta.server.rest_api.interface import StreamingServerInterface from letta.server.server import SyncServer @@ -59,9 +60,9 @@ async def sse_async_generator( try: usage = await usage_task # Double-check the type - if not isinstance(usage, LettaUsageStatistics): - raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}") - yield sse_formatter({"usage": usage.model_dump()}) + if not isinstance(usage, UsageMessage): + raise ValueError(f"Expected UsageMessage, got {type(usage)}") + yield sse_formatter(usage.model_dump()) except ContextWindowExceededError as e: log_error_to_sentry(e) diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 0e047d9030..00b8a4e58b 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -260,10 +260,10 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent send_message_ran = True if isinstance(chunk, UsageMessage): # Some rough metrics for a reasonable usage pattern - assert chunk.step_count == 1 - assert chunk.completion_tokens > 10 - assert chunk.prompt_tokens > 1000 - assert chunk.total_tokens > 1000 + assert chunk.usage.step_count == 1 + assert chunk.usage.completion_tokens > 10 + assert chunk.usage.prompt_tokens > 1000 + assert chunk.usage.total_tokens > 1000 elif chunk == OPENAI_SSE_DONE: assert not done, "Message stream already done" done = True