Skip to content

Commit

Permalink
fix: redo the PR
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker committed Nov 22, 2024
1 parent b7b22a6 commit 181ea02
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 1 deletion.
15 changes: 15 additions & 0 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,6 +1600,18 @@ def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
raise ValueError(f"Failed to remove agent memory block: {response.text}")
return Memory(**response.json())

def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory:

# @router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit")
response = requests.patch(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/limit",
headers=self.headers,
json={"label": block_label, "limit": limit},
)
if response.status_code != 200:
raise ValueError(f"Failed to update agent memory limit: {response.text}")
return Memory(**response.json())


class LocalClient(AbstractClient):
"""
Expand Down Expand Up @@ -2823,3 +2835,6 @@ def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Me

def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label)

def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory:
return self.server.update_agent_memory_limit(user_id=self.user_id, agent_id=agent_id, block_label=block_label, limit=limit)
7 changes: 7 additions & 0 deletions letta/schemas/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ class Config:
extra = "ignore" # Ignores extra fields


class BlockLimitUpdate(BaseModel):
"""Update the limit of a block"""

label: str = Field(..., description="Label of the block.")
limit: int = Field(..., description="New limit of the block.")


class UpdatePersona(BlockUpdate):
"""Update a persona block"""

Expand Down
13 changes: 13 additions & 0 deletions letta/schemas/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,19 @@ def update_block_label(self, current_label: str, new_label: str):
# Then swap the block to the new label
self.memory[new_label] = self.memory.pop(current_label)

def update_block_limit(self, label: str, limit: int):
"""Update the limit of a block"""
if label not in self.memory:
raise ValueError(f"Block with label {label} does not exist")
if not isinstance(limit, int):
raise ValueError(f"Provided limit must be an integer")

# Check to make sure the new limit is greater than the current length of the block
if len(self.memory[label].value) > limit:
raise ValueError(f"New limit {limit} is less than the current length of the block {len(self.memory[label].value)}")

self.memory[label].limit = limit


# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
class BasicBlockMemory(Memory):
Expand Down
24 changes: 23 additions & 1 deletion letta/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate
from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate, BlockLimitUpdate
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
LegacyLettaMessage,
Expand Down Expand Up @@ -218,6 +218,7 @@ def update_agent_memory(
):
"""
Update the core memory of a specific agent.
This endpoint accepts new memory contents (labels as keys, and values as values) and updates the core memory of the agent identified by the user ID and agent ID.
This endpoint accepts new memory contents to update the core memory of the agent.
This endpoint only supports modifying existing blocks; it does not support deleting/unlinking or creating/linking blocks.
"""
Expand Down Expand Up @@ -287,6 +288,27 @@ def remove_agent_memory_block(
return updated_memory


@router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit")
def update_agent_memory_limit(
agent_id: str,
update_label: BlockLimitUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the limit of a block in an agent's memory.
"""
actor = server.get_user_or_default(user_id=user_id)

memory = server.update_agent_memory_limit(
user_id=actor.id,
agent_id=agent_id,
block_label=update_label.label,
limit=update_label.limit,
)
return memory


@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")
def get_agent_recall_memory_summary(
agent_id: str,
Expand Down
30 changes: 30 additions & 0 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1929,3 +1929,33 @@ def unlink_block_from_agent_memory(self, user_id: str, agent_id: str, block_labe
raise ValueError(f"Agent with id {agent_id} not found after linking block")
assert unlinked_block.label not in updated_agent.memory.list_block_labels()
return updated_agent.memory

def update_agent_memory_limit(self, user_id: str, agent_id: str, block_label: str, limit: int) -> Memory:
"""Update the limit of a block in an agent's memory"""

# Get the user
user = self.user_manager.get_user_by_id(user_id=user_id)

# Link a block to an agent's memory
letta_agent = self._get_or_load_agent(agent_id=agent_id)
letta_agent.memory.update_block_limit(label=block_label, limit=limit)
assert block_label in letta_agent.memory.list_block_labels()

# write out the update the database
self.block_manager.create_or_update_block(block=letta_agent.memory.get_block(block_label), actor=user)

# check that the block was updated
updated_block = self.block_manager.get_block_by_id(block_id=letta_agent.memory.get_block(block_label).id, actor=user)
assert updated_block and updated_block.limit == limit

# Recompile the agent memory
letta_agent.rebuild_memory(force=True, ms=self.ms)

# save agent
save_agent(letta_agent, self.ms)

updated_agent = self.ms.get_agent(agent_id=agent_id)
if updated_agent is None:
raise ValueError(f"Agent with id {agent_id} not found after linking block")
assert updated_agent.memory.get_block(label=block_label).limit == limit
return updated_agent.memory
31 changes: 31 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,3 +810,34 @@ def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], a

# finally:
# client.delete_agent(new_agent.id)


def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can update the limit of a block in an agent's memory"""

agent = client.create_agent(name=create_random_username())

try:
current_labels = agent.memory.list_block_labels()
example_label = current_labels[0]
example_new_limit = 1
current_block = agent.memory.get_block(label=example_label)
current_block_length = len(current_block.value)

assert example_new_limit != agent.memory.get_block(label=example_label).limit
assert example_new_limit < current_block_length

# We expect this to throw a value error
with pytest.raises(ValueError):
client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit)

# Now try the same thing with a higher limit
example_new_limit = current_block_length + 10000
assert example_new_limit > current_block_length
client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit)

updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit

finally:
client.delete_agent(agent.id)
13 changes: 13 additions & 0 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,16 @@ def test_update_block_label(sample_memory: Memory):
sample_memory.update_block_label(current_label=test_old_label, new_label=test_new_label)
assert test_new_label in sample_memory.list_block_labels()
assert test_old_label not in sample_memory.list_block_labels()


def test_update_block_limit(sample_memory: Memory):
"""Test updating the limit of a block"""

test_new_limit = 1000
current_labels = sample_memory.list_block_labels()
test_old_label = current_labels[0]

assert sample_memory.get_block(label=test_old_label).limit != test_new_limit

sample_memory.update_block_limit(label=test_old_label, limit=test_new_limit)
assert sample_memory.get_block(label=test_old_label).limit == test_new_limit

0 comments on commit 181ea02

Please sign in to comment.