Commit c170c25
fix: patch use_assistant_message flag on the server (#704)
Co-authored-by: Sarah Wooders <[email protected]>
cpacker and sarahwooders authored Jan 20, 2025
1 parent 5419a94 commit c170c25
Showing 6 changed files with 166 additions and 63 deletions.
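
For orientation, this is the user-visible effect of the flag (a hedged sketch; the payloads below are simplified illustrations, not the exact wire schema). With `use_assistant_message` off, a `send_message` call from the agent streams as raw tool-call chunks plus a tool-return receipt; with it on, the same call streams as plain assistant content and the receipt is suppressed:

# Illustrative chunk sequences for one agent step that calls
# send_message(message="Hi there!"); shapes are simplified, not the exact schema.

without_flag = [
    {"reasoning": "User greeted me, I should respond."},
    {"tool_call": {"name": "send_message", "arguments": '{"message": "Hi there!"}'}},
    {"tool_return": "None", "status": "success"},
]

with_flag = [
    {"reasoning": "User greeted me, I should respond."},
    {"assistant_message": "Hi there!"},
    # the matching tool-return receipt is skipped by the interface
]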

letta/client/streaming.py (4 changes: 3 additions & 1 deletion)

@@ -7,7 +7,7 @@
 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
 from letta.errors import LLMError
 from letta.schemas.enums import MessageStreamStatus
-from letta.schemas.letta_message import ReasoningMessage, ToolCallMessage, ToolReturnMessage
+from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage
 from letta.schemas.letta_response import LettaStreamingResponse
 from letta.schemas.usage import LettaUsageStatistics

@@ -50,6 +50,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
                 chunk_data = json.loads(sse.data)
                 if "reasoning" in chunk_data:
                     yield ReasoningMessage(**chunk_data)
+                elif "assistant_message" in chunk_data:
+                    yield AssistantMessage(**chunk_data)
                 elif "tool_call" in chunk_data:
                     yield ToolCallMessage(**chunk_data)
                 elif "tool_return" in chunk_data:
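A quick usage sketch of the new dispatch (the `assistant_message` key comes from the diff; the other payload fields here are assumptions):

import json

# One SSE data frame as the patched client would now route it (assumed payload shape).
sse_data = '{"id": "message-123", "date": "2025-01-20T12:00:00+00:00", "assistant_message": "Hi there!"}'

chunk_data = json.loads(sse_data)
if "reasoning" in chunk_data:
    kind = "ReasoningMessage"
elif "assistant_message" in chunk_data:
    kind = "AssistantMessage"  # new branch: previously this payload fell through to the tool-call checks
elif "tool_call" in chunk_data:
    kind = "ToolCallMessage"
elif "tool_return" in chunk_data:
    kind = "ToolReturnMessage"
print(kind)  # AssistantMessage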

letta/server/rest_api/app.py (8 changes: 4 additions & 4 deletions)

@@ -290,8 +290,8 @@ def start_server(
         server_logger.addHandler(stream_handler)

     if (os.getenv("LOCAL_HTTPS") == "true") or "--localhttps" in sys.argv:
-        print(f"▶ Server running at: https://{host or 'localhost'}:{port or REST_DEFAULT_PORT}\n")
-        print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
+        print(f"▶ Server running at: https://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
+        print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
         uvicorn.run(
             app,
             host=host or "localhost",
@@ -300,8 +300,8 @@ def start_server(
             ssl_certfile="certs/localhost.pem",
         )
     else:
-        print(f"▶ Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}\n")
-        print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
+        print(f"▶ Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
+        print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
         uvicorn.run(
             app,
             host=host or "localhost",

letta/server/rest_api/interface.py (181 changes: 135 additions & 46 deletions)

@@ -264,6 +264,7 @@ def __init__(
         self,
         multi_step=True,
         # Related to if we want to try and pass back the AssistantMessage as a special case function
+        use_assistant_message=False,
         assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
         assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
         # Related to if we expect inner_thoughts to be in the kwargs
@@ -295,9 +296,10 @@ def __init__(
         # self.multi_step_gen_indicator = MessageStreamStatus.done_generation

         # Support for AssistantMessage
-        self.use_assistant_message = False  # TODO: Remove this
+        self.use_assistant_message = use_assistant_message  # TODO: Remove this (actually? @charles)
         self.assistant_message_tool_name = assistant_message_tool_name
         self.assistant_message_tool_kwarg = assistant_message_tool_kwarg
+        self.prev_assistant_message_id = None  # Used to skip tool call response receipts for `send_message`

         # Support for inner_thoughts_in_kwargs
         self.inner_thoughts_in_kwargs = inner_thoughts_in_kwargs
@@ -308,6 +310,8 @@ def __init__(
         self.function_name_buffer = None
         self.function_args_buffer = None
         self.function_id_buffer = None
+        # A buffer used to store the last flushed function name
+        self.last_flushed_function_name = None

         # extra prints
         self.debug = False
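
Taken together, the two new attributes carry the cross-chunk state the rest of the patch relies on: `last_flushed_function_name` lets later argument deltas know whether they belong to the `send_message` tool, and `prev_assistant_message_id` lets the eventual tool return be recognized, and skipped, once its call has already been surfaced as an `AssistantMessage`.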
@@ -434,7 +438,8 @@ def _process_chunk_to_letta_style(

         # TODO(charles) merge into logic for internal_monologue
         # special case for trapping `send_message`
-        if self.use_assistant_message and tool_call.function:
+        # if self.use_assistant_message and tool_call.function:
+        if not self.inner_thoughts_in_kwargs and self.use_assistant_message and tool_call.function:
             if self.inner_thoughts_in_kwargs:
                 raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported")

@@ -535,15 +540,28 @@ def _process_chunk_to_letta_style(
                 # however the frontend may expect name first, then args, so to be
                 # safe we'll output name first in a separate chunk
                 if self.function_name_buffer:
-                    processed_chunk = ToolCallMessage(
-                        id=message_id,
-                        date=message_date,
-                        tool_call=ToolCallDelta(
-                            name=self.function_name_buffer,
-                            arguments=None,
-                            tool_call_id=self.function_id_buffer,
-                        ),
-                    )
+
+                    # use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
+                    if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
+                        processed_chunk = None
+
+                        # Store the ID of the tool call to allow skipping the corresponding response
+                        if self.function_id_buffer:
+                            self.prev_assistant_message_id = self.function_id_buffer
+
+                    else:
+                        processed_chunk = ToolCallMessage(
+                            id=message_id,
+                            date=message_date,
+                            tool_call=ToolCallDelta(
+                                name=self.function_name_buffer,
+                                arguments=None,
+                                tool_call_id=self.function_id_buffer,
+                            ),
+                        )
+
+                    # Record what the last function name we flushed was
+                    self.last_flushed_function_name = self.function_name_buffer
                     # Clear the buffer
                     self.function_name_buffer = None
                     self.function_id_buffer = None
@@ -559,35 +577,94 @@ def _process_chunk_to_letta_style(
                 # If there was nothing in the name buffer, we can proceed to
                 # output the arguments chunk as a ToolCallMessage
                 else:
-                    # There may be a buffer from a previous chunk, for example
-                    # if the previous chunk had arguments but we needed to flush name
-                    if self.function_args_buffer:
-                        # In this case, we should release the buffer + new data at once
-                        combined_chunk = self.function_args_buffer + updates_main_json
-                        processed_chunk = ToolCallMessage(
-                            id=message_id,
-                            date=message_date,
-                            tool_call=ToolCallDelta(
-                                name=None,
-                                arguments=combined_chunk,
-                                tool_call_id=self.function_id_buffer,
-                            ),
-                        )
-                        # clear buffer
-                        self.function_args_buffer = None
-                        self.function_id_buffer = None
-                    else:
-                        # If there's no buffer to clear, just output a new chunk with new data
-                        processed_chunk = ToolCallMessage(
-                            id=message_id,
-                            date=message_date,
-                            tool_call=ToolCallDelta(
-                                name=None,
-                                arguments=updates_main_json,
-                                tool_call_id=self.function_id_buffer,
-                            ),
-                        )
-                        self.function_id_buffer = None
+                    # use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
+                    if self.use_assistant_message and (
+                        self.last_flushed_function_name is not None
+                        and self.last_flushed_function_name == self.assistant_message_tool_name
+                    ):
+                        # do an additional parse on the updates_main_json
+                        if self.function_args_buffer:
+
+                            updates_main_json = self.function_args_buffer + updates_main_json
+                            self.function_args_buffer = None
+
+                            # Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix
+                            match_str = '{"' + self.assistant_message_tool_kwarg + '":"'
+                            if updates_main_json == match_str:
+                                updates_main_json = None
+
+                        else:
+                            # Some hardcoding to strip off the trailing "}"
+                            if updates_main_json in ["}", '"}']:
+                                updates_main_json = None
+                            if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"':
+                                updates_main_json = updates_main_json[:-1]
+
+                        if not updates_main_json:
+                            # early exit to turn into content mode
+                            return None
+
+                        # There may be a buffer from a previous chunk, for example
+                        # if the previous chunk had arguments but we needed to flush name
+                        if self.function_args_buffer:
+                            # In this case, we should release the buffer + new data at once
+                            combined_chunk = self.function_args_buffer + updates_main_json
+                            processed_chunk = AssistantMessage(
+                                id=message_id,
+                                date=message_date,
+                                assistant_message=combined_chunk,
+                            )
+                            # Store the ID of the tool call to allow skipping the corresponding response
+                            if self.function_id_buffer:
+                                self.prev_assistant_message_id = self.function_id_buffer
+                            # clear buffer
+                            self.function_args_buffer = None
+                            self.function_id_buffer = None
+
+                        else:
+                            # If there's no buffer to clear, just output a new chunk with new data
+                            processed_chunk = AssistantMessage(
+                                id=message_id,
+                                date=message_date,
+                                assistant_message=updates_main_json,
+                            )
+                            # Store the ID of the tool call to allow skipping the corresponding response
+                            if self.function_id_buffer:
+                                self.prev_assistant_message_id = self.function_id_buffer
+                            # clear buffers
+                            self.function_id_buffer = None
+                    else:
+                        # There may be a buffer from a previous chunk, for example
+                        # if the previous chunk had arguments but we needed to flush name
+                        if self.function_args_buffer:
+                            # In this case, we should release the buffer + new data at once
+                            combined_chunk = self.function_args_buffer + updates_main_json
+                            processed_chunk = ToolCallMessage(
+                                id=message_id,
+                                date=message_date,
+                                tool_call=ToolCallDelta(
+                                    name=None,
+                                    arguments=combined_chunk,
+                                    tool_call_id=self.function_id_buffer,
+                                ),
+                            )
+                            # clear buffer
+                            self.function_args_buffer = None
+                            self.function_id_buffer = None
+                        else:
+                            # If there's no buffer to clear, just output a new chunk with new data
+                            processed_chunk = ToolCallMessage(
+                                id=message_id,
+                                date=message_date,
+                                tool_call=ToolCallDelta(
+                                    name=None,
+                                    arguments=updates_main_json,
+                                    tool_call_id=self.function_id_buffer,
+                                ),
+                            )
+                            self.function_id_buffer = None

                 # # If there's something in the main_json buffer, we should add it to the arguments and release it together
                 # tool_call_delta = {}
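
The unwrapping above is easier to see in isolation. Below is a standalone sketch that mirrors (rather than imports) the hunk's logic, assuming `assistant_message_tool_kwarg` is the default "message":

def unwrap_send_message_args(chunks, kwarg="message"):
    """Reassemble plain content from streamed send_message argument fragments."""
    match_str = '{"' + kwarg + '":"'
    out = []
    for chunk in chunks:
        if chunk == match_str:
            continue  # swallow the opening {"message":" prefix (assumes it arrives whole)
        if chunk in ["}", '"}']:
            continue  # swallow the closing brace/quote fragment
        if chunk.endswith('"'):
            chunk = chunk[:-1]  # strip a trailing quote right before the close
        if chunk:
            out.append(chunk)
    return "".join(out)

print(unwrap_send_message_args(['{"message":"', "Hi", " there", '!"', "}"]))  # Hi there!

As the in-code comment admits, the prefix match is brittle: it assumes the entire `{"message":"` prefix lands in a single fragment.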
@@ -906,6 +983,8 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):
                         date=msg_obj.created_at,
                         assistant_message=func_args[self.assistant_message_tool_kwarg],
                     )
+                    # Store the ID of the tool call to allow skipping the corresponding response
+                    self.prev_assistant_message_id = function_call.id
                 else:
                     processed_chunk = ToolCallMessage(
                         id=msg_obj.id,
@@ -938,13 +1017,23 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):
msg = msg.replace("Success: ", "")
# new_message = {"function_return": msg, "status": "success"}
assert msg_obj.tool_call_id is not None
new_message = ToolReturnMessage(
id=msg_obj.id,
date=msg_obj.created_at,
tool_return=msg,
status="success",
tool_call_id=msg_obj.tool_call_id,
)

print(f"YYY printing the function call - {msg_obj.tool_call_id} == {self.prev_assistant_message_id} ???")

# Skip this is use_assistant_message is on
if self.use_assistant_message and msg_obj.tool_call_id == self.prev_assistant_message_id:
# Wipe the cache
self.prev_assistant_message_id = None
# Skip this tool call receipt
return
else:
new_message = ToolReturnMessage(
id=msg_obj.id,
date=msg_obj.created_at,
tool_return=msg,
status="success",
tool_call_id=msg_obj.tool_call_id,
)

elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "", 1)
Expand Down
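For the non-streaming path above, the conversion plus the new receipt-skip can be distilled as follows (a sketch with hypothetical shapes; the real method works on `Message` objects):

import json

def emit(interface_state, function_call_name, function_call_args, tool_call_id, kwarg="message"):
    """Convert a completed send_message call and remember its ID so the receipt is skipped."""
    if function_call_name == "send_message":
        func_args = json.loads(function_call_args)
        interface_state["prev_assistant_message_id"] = tool_call_id
        return {"assistant_message": func_args[kwarg]}
    return {"tool_call": {"name": function_call_name, "arguments": function_call_args}}

state = {"prev_assistant_message_id": None}
print(emit(state, "send_message", '{"message": "Done!"}', "call-1"))
# Later, when the tool return for "call-1" arrives, function_message compares its
# tool_call_id against state["prev_assistant_message_id"] and drops the receipt.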

letta/server/server.py (23 changes: 16 additions & 7 deletions)

@@ -1215,7 +1215,6 @@ async def send_message_to_agent(
     stream_tokens: bool,
     # related to whether or not we return `LettaMessage`s or `Message`s
     chat_completion_mode: bool = False,
-    timestamp: Optional[datetime] = None,
     # Support for AssistantMessage
     use_assistant_message: bool = True,
     assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
@@ -1249,7 +1248,16 @@ async def send_message_to_agent(
         stream_tokens = False

     # Create a new interface per request
-    letta_agent.interface = StreamingServerInterface(use_assistant_message)
+    letta_agent.interface = StreamingServerInterface(
+        # multi_step=True,  # would we ever want to disable this?
+        use_assistant_message=use_assistant_message,
+        assistant_message_tool_name=assistant_message_tool_name,
+        assistant_message_tool_kwarg=assistant_message_tool_kwarg,
+        inner_thoughts_in_kwargs=(
+            llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
+        ),
+        # inner_thoughts_kwarg=INNER_THOUGHTS_KWARG,
+    )
     streaming_interface = letta_agent.interface
     if not isinstance(streaming_interface, StreamingServerInterface):
         raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
@@ -1263,13 +1271,14 @@ async def send_message_to_agent(
     # streaming_interface.function_call_legacy_mode = stream

     # Allow AssistantMessage if desired by client
-    streaming_interface.assistant_message_tool_name = assistant_message_tool_name
-    streaming_interface.assistant_message_tool_kwarg = assistant_message_tool_kwarg
+    # streaming_interface.use_assistant_message = use_assistant_message
+    # streaming_interface.assistant_message_tool_name = assistant_message_tool_name
+    # streaming_interface.assistant_message_tool_kwarg = assistant_message_tool_kwarg

     # Related to JSON buffer reader
-    streaming_interface.inner_thoughts_in_kwargs = (
-        llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
-    )
+    # streaming_interface.inner_thoughts_in_kwargs = (
+    #     llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
+    # )

     # Offload the synchronous message_func to a separate thread
     streaming_interface.stream_start()
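This hunk is the heart of the fix. The old call `StreamingServerInterface(use_assistant_message)` passed the flag positionally, so it bound to `multi_step`, the first parameter of `__init__`, while `self.use_assistant_message` stayed hardcoded to False; the flag never reached the interface. A distilled repro of the pitfall (not letta code):

class Iface:
    def __init__(self, multi_step=True, use_assistant_message=False):
        self.multi_step = multi_step
        self.use_assistant_message = use_assistant_message

broken = Iface(True)  # True silently binds to multi_step, not use_assistant_message
fixed = Iface(use_assistant_message=True)  # keyword argument makes the intent explicit
print(broken.use_assistant_message, fixed.use_assistant_message)  # False True

Hence the patched call site passes every option by keyword, and the duplicate post-construction assignments are commented out rather than left to fight the constructor.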

letta/streaming_utils.py (6 changes: 4 additions & 2 deletions)
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple

 from letta.constants import DEFAULT_MESSAGE_TOOL_KWARG

@@ -48,7 +48,7 @@ def __init__(self, inner_thoughts_key="inner_thoughts", wait_for_first_key=False
         self.hold_main_json = wait_for_first_key
         self.main_json_held_buffer = ""

-    def process_fragment(self, fragment):
+    def process_fragment(self, fragment: str) -> Tuple[str, str]:
         updates_main_json = ""
         updates_inner_thoughts = ""
         i = 0
@@ -263,8 +263,10 @@ def process_json_chunk(self, chunk: str) -> Optional[str]:
                 self.key_buffer = ""
                 self.accumulating = True
                 return None
+
         if chunk.strip() == "}":
             self.in_message = False
+            self.message_started = False
             return None

         return None
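
Why the added reset matters is clearer in a toy version of the reader. The field names mirror the diff, but the class below is a distilled sketch, not the letta implementation: without clearing `message_started` at the closing brace, the first content fragment of the next message would be treated as mid-message (here, its opening quote would leak through):

class ToyMessageReader:
    """Extract the "message" value from a stream of JSON fragments."""

    def __init__(self):
        self.in_message = False
        self.message_started = False

    def process_json_chunk(self, chunk):
        if chunk == '"message":':
            self.in_message = True  # subsequent fragments are message content
            return None
        if chunk.strip() == "}":
            self.in_message = False
            self.message_started = False  # the fix: reset per-message state
            return None
        if self.in_message:
            # Strip the opening quote only from the first content fragment.
            out = chunk if self.message_started else chunk.lstrip('"')
            self.message_started = True
            return out.rstrip('"')
        return None

reader = ToyMessageReader()
stream = ['{', '"message":', '"Hello ', 'world"', '}', '{', '"message":', '"Again"', '}']
print([reader.process_json_chunk(c) for c in stream])
# [None, None, 'Hello ', 'world', None, None, None, 'Again', None]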