Skip to content

Commit

Permalink
Fix out of order error
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 committed Dec 20, 2024
1 parent a80f74b commit 3556fd1
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 11 deletions.
13 changes: 8 additions & 5 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def update_memory_if_change(self, new_memory: Memory) -> bool:
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
# rebuild memory - this records the last edited timestamp of the memory
# TODO: pass in update timestamp from block edit time
self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)

return True
return False
Expand Down Expand Up @@ -565,7 +565,7 @@ def _handle_ai_response(

# rebuild memory
# TODO: @charles please check this
self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)

# Update ToolRulesSolver state with last called function
self.tool_rules_solver.update_tool_usage(function_name)
Expand Down Expand Up @@ -597,6 +597,7 @@ def step(
messages=next_input_message,
**kwargs,
)

heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
Expand Down Expand Up @@ -748,7 +749,9 @@ def inner_step(
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
)

self.agent_manager.append_to_in_context_messages(all_new_messages, agent_id=self.agent_state.id, actor=self.user)
self.agent_state = self.agent_manager.append_to_in_context_messages(
all_new_messages, agent_id=self.agent_state.id, actor=self.user
)

return AgentStepResponse(
messages=all_new_messages,
Expand Down Expand Up @@ -917,9 +920,9 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
printd(f"Packaged into message: {summary_message}")

prior_len = len(in_context_messages_openai)
self.agent_manager.trim_older_in_context_messages(cutoff, agent_id=self.agent_state.id, actor=self.user)
self.agent_state = self.agent_manager.trim_older_in_context_messages(cutoff, agent_id=self.agent_state.id, actor=self.user)
packed_summary_message = {"role": "user", "content": summary_message}
self.agent_manager.prepend_to_in_context_messages(
self.agent_state = self.agent_manager.prepend_to_in_context_messages(
messages=[
Message.dict_to_message(
agent_id=self.agent_state.id,
Expand Down
2 changes: 1 addition & 1 deletion letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def _step(
)

# save agent after step
save_agent(letta_agent)
# save_agent(letta_agent)

except Exception as e:
logger.error(f"Error in server._step: {e}")
Expand Down
11 changes: 9 additions & 2 deletions letta/services/message_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,18 @@ def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[Py

@enforce_types
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
"""Fetch a message by ID."""
"""Fetch messages by ID and return them in the requested order."""
with self.session_maker() as session:
results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id)

return [msg.to_pydantic() for msg in results]
if len(results) != len(message_ids):
raise NoResultFound(
f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
)

# Sort results directly based on message_ids
result_dict = {msg.id: msg.to_pydantic() for msg in results}
return [result_dict[msg_id] for msg_id in message_ids]

@enforce_types
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def summarize_message_exists(messages: List[Message]) -> bool:

# check if the summarize message is inside the messages
assert isinstance(client, LocalClient), "Test only works with LocalClient"
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
print("SUMMARY", summarize_message_exists(agent_obj._messages))
if summarize_message_exists(agent_obj._messages):
in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=client.user)
print("SUMMARY", summarize_message_exists(in_context_messages))
if summarize_message_exists(in_context_messages):
break

if message_count > MAX_ATTEMPTS:
Expand Down

0 comments on commit 3556fd1

Please sign in to comment.