diff --git a/alembic/versions/1c8880d671ee_make_an_blocks_agents_mapping_table.py b/alembic/versions/1c8880d671ee_make_an_blocks_agents_mapping_table.py new file mode 100644 index 0000000000..ffcb0b6715 --- /dev/null +++ b/alembic/versions/1c8880d671ee_make_an_blocks_agents_mapping_table.py @@ -0,0 +1,52 @@ +"""Make an blocks agents mapping table + +Revision ID: 1c8880d671ee +Revises: f81ceea2c08d +Create Date: 2024-11-22 15:42:47.209229 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "1c8880d671ee" +down_revision: Union[str, None] = "f81ceea2c08d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint("unique_block_id_label", "block", ["id", "label"]) + + op.create_table( + "blocks_agents", + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("block_id", sa.String(), nullable=False), + sa.Column("block_label", sa.String(), nullable=False), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["agent_id"], + ["agents.id"], + ), + sa.ForeignKeyConstraint(["block_id", "block_label"], ["block.id", "block.label"], name="fk_block_id_label"), + sa.PrimaryKeyConstraint("agent_id", "block_id", "block_label", "id"), + sa.UniqueConstraint("agent_id", "block_label", name="unique_label_per_agent"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("unique_block_id_label", "block", type_="unique") + op.drop_table("blocks_agents") + # ### end Alembic commands ### diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index eeed7c2e86..cd682f9955 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -1,5 +1,6 @@ from letta.orm.base import Base from letta.orm.block import Block +from letta.orm.blocks_agents import BlocksAgents from letta.orm.file import FileMetadata from letta.orm.organization import Organization from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable diff --git a/letta/orm/block.py b/letta/orm/block.py index dddb5631cc..ab7e40802e 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Optional, Type -from sqlalchemy import JSON, BigInteger, Integer +from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT @@ -18,6 +18,8 @@ class Block(OrganizationMixin, SqlalchemyBase): __tablename__ = "block" __pydantic_model__ = PydanticBlock + # This may seem redundant, but is necessary for the BlocksAgents composite FK relationship + __table_args__ = (UniqueConstraint("id", "label", name="unique_block_id_label"),) template_name: Mapped[Optional[str]] = mapped_column( nullable=True, doc="the unique name that identifies a block in a human-readable way" diff --git a/letta/orm/blocks_agents.py b/letta/orm/blocks_agents.py new file mode 100644 index 0000000000..31f0fa9d34 --- /dev/null +++ b/letta/orm/blocks_agents.py @@ -0,0 +1,29 @@ +from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents + + +class BlocksAgents(SqlalchemyBase): + """Agents must have one or many blocks to make up their core memory.""" + + __tablename__ = "blocks_agents" + __pydantic_model__ = PydanticBlocksAgents + __table_args__ = ( + UniqueConstraint( + "agent_id", + "block_label", + name="unique_label_per_agent", + ), + ForeignKeyConstraint( + ["block_id", "block_label"], + ["block.id", "block.label"], + name="fk_block_id_label", + ), + ) + + # unique agent + block label + agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True) + block_id: Mapped[str] = mapped_column(String, primary_key=True) + block_label: Mapped[str] = mapped_column(String, primary_key=True) diff --git a/letta/schemas/blocks_agents.py b/letta/schemas/blocks_agents.py new file mode 100644 index 0000000000..8b33925a8c --- /dev/null +++ b/letta/schemas/blocks_agents.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from pydantic import Field + +from letta.schemas.letta_base import LettaBase + + +class BlocksAgentsBase(LettaBase): + __id_prefix__ = "blocks_agents" + + +class BlocksAgents(BlocksAgentsBase): + """ + Schema representing the relationship between blocks and agents. + + Parameters: + agent_id (str): The ID of the associated agent. + block_id (str): The ID of the associated block. + block_label (str): The label of the block. + created_at (datetime): The date this relationship was created. + updated_at (datetime): The date this relationship was last updated. + is_deleted (bool): Whether this block-agent relationship is deleted or not. + """ + + id: str = BlocksAgentsBase.generate_id_field() + agent_id: str = Field(..., description="The ID of the associated agent.") + block_id: str = Field(..., description="The ID of the associated block.") + block_label: str = Field(..., description="The label of the block.") + created_at: Optional[datetime] = Field(None, description="The creation date of the association.") + updated_at: Optional[datetime] = Field(None, description="The update date of the association.") + is_deleted: bool = Field(False, description="Whether this block-agent relationship is deleted or not.") diff --git a/letta/server/server.py b/letta/server/server.py index 99267176dd..be4b827ae8 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -77,6 +77,7 @@ from letta.schemas.user import User from letta.services.agents_tags_manager import AgentsTagsManager from letta.services.block_manager import BlockManager +from letta.services.blocks_agents_manager import BlocksAgentsManager from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager @@ -248,6 +249,7 @@ def __init__( self.block_manager = BlockManager() self.source_manager = SourceManager() self.agents_tags_manager = AgentsTagsManager() + self.blocks_agents_manager = BlocksAgentsManager() self.sandbox_config_manager = SandboxConfigManager(tool_settings) # Make default user and org diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 3855f5a61d..c559d05ac9 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -36,7 +36,7 @@ def create_or_update_block(self, block: Block, actor: PydanticUser) -> PydanticB return block.to_pydantic() @enforce_types - def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser, limit: Optional[int] = None) -> PydanticBlock: + def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: """Update a block by its ID with the given BlockUpdate object.""" with self.session_maker() as session: block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) diff --git a/letta/services/blocks_agents_manager.py b/letta/services/blocks_agents_manager.py new file mode 100644 index 0000000000..bbc5bfc042 --- /dev/null +++ b/letta/services/blocks_agents_manager.py @@ -0,0 +1,84 @@ +import warnings +from typing import List + +from letta.orm.blocks_agents import BlocksAgents as BlocksAgentsModel +from letta.orm.errors import NoResultFound +from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents +from letta.utils import enforce_types + + +# TODO: DELETE THIS ASAP +# TODO: So we have a patch where we manually specify CRUD operations +# TODO: This is because Agent is NOT migrated to the ORM yet +# TODO: Once we migrate Agent to the ORM, we should deprecate any agents relationship table managers +class BlocksAgentsManager: + """Manager class to handle business logic related to Blocks and Agents.""" + + def __init__(self): + from letta.server.server import db_context + + self.session_maker = db_context + + @enforce_types + def add_block_to_agent(self, agent_id: str, block_id: str, block_label: str) -> PydanticBlocksAgents: + """Add a block to an agent. If the label already exists on that agent, this will error.""" + with self.session_maker() as session: + try: + # Check if the block-label combination already exists for this agent + blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label) + warnings.warn(f"Block label '{block_label}' already exists for agent '{agent_id}'.") + except NoResultFound: + blocks_agents_record = PydanticBlocksAgents(agent_id=agent_id, block_id=block_id, block_label=block_label) + blocks_agents_record = BlocksAgentsModel(**blocks_agents_record.model_dump(exclude_none=True)) + blocks_agents_record.create(session) + + return blocks_agents_record.to_pydantic() + + @enforce_types + def remove_block_with_label_from_agent(self, agent_id: str, block_label: str) -> PydanticBlocksAgents: + """Remove a block with a label from an agent.""" + with self.session_maker() as session: + try: + # Find and delete the block-label association for the agent + blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label) + blocks_agents_record.hard_delete(session) + return blocks_agents_record.to_pydantic() + except NoResultFound: + raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.") + + @enforce_types + def remove_block_with_id_from_agent(self, agent_id: str, block_id: str) -> PydanticBlocksAgents: + """Remove a block with a label from an agent.""" + with self.session_maker() as session: + try: + # Find and delete the block-label association for the agent + blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_id=block_id) + blocks_agents_record.hard_delete(session) + return blocks_agents_record.to_pydantic() + except NoResultFound: + raise ValueError(f"Block id '{block_id}' not found for agent '{agent_id}'.") + + @enforce_types + def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_id: str) -> PydanticBlocksAgents: + """Update the block ID for a specific block label for an agent.""" + with self.session_maker() as session: + try: + blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label) + blocks_agents_record.block_id = new_block_id + return blocks_agents_record.to_pydantic() + except NoResultFound: + raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.") + + @enforce_types + def list_block_ids_for_agent(self, agent_id: str) -> List[str]: + """List all blocks associated with a specific agent.""" + with self.session_maker() as session: + blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id) + return [record.block_id for record in blocks_agents_record] + + @enforce_types + def list_agent_ids_with_block(self, block_id: str) -> List[str]: + """List all agents associated with a specific block.""" + with self.session_maker() as session: + blocks_agents_record = BlocksAgentsModel.list(db_session=session, block_id=block_id) + return [record.agent_id for record in blocks_agents_record] diff --git a/tests/test_managers.py b/tests/test_managers.py index 946d4edbd3..7ac9b3d23e 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1,10 +1,13 @@ import pytest from sqlalchemy import delete +from sqlalchemy.exc import DBAPIError import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code +from letta.metadata import AgentModel from letta.orm import ( Block, + BlocksAgents, FileMetadata, Organization, SandboxConfig, @@ -13,6 +16,7 @@ Tool, User, ) +from letta.orm.agents_tags import AgentsTags from letta.schemas.agent import CreateAgent from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate @@ -60,12 +64,15 @@ def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: + session.execute(delete(BlocksAgents)) + session.execute(delete(AgentsTags)) session.execute(delete(SandboxEnvironmentVariable)) session.execute(delete(SandboxConfig)) session.execute(delete(Block)) session.execute(delete(FileMetadata)) session.execute(delete(Source)) session.execute(delete(Tool)) # Clear all records from the Tool table + session.execute(delete(AgentModel)) session.execute(delete(User)) # Clear all records from the user table session.execute(delete(Organization)) # Clear all records from the organization table session.commit() # Commit the deletion @@ -121,8 +128,6 @@ def sarah_agent(server: SyncServer, default_user, default_organization): ) yield agent_state - server.delete_agent(user_id=default_user.id, agent_id=agent_state.id) - @pytest.fixture def charles_agent(server: SyncServer, default_user, default_organization): @@ -141,8 +146,6 @@ def charles_agent(server: SyncServer, default_user, default_organization): ) yield agent_state - server.delete_agent(user_id=default_user.id, agent_id=agent_state.id) - @pytest.fixture def tool_fixture(server: SyncServer, default_user, default_organization): @@ -200,6 +203,36 @@ def sandbox_env_var_fixture(server: SyncServer, sandbox_config_fixture, default_ yield created_env_var +@pytest.fixture +def default_block(server: SyncServer, default_user): + """Fixture to create and return a default block.""" + block_manager = BlockManager() + block_data = PydanticBlock( + label="default_label", + value="Default Block Content", + description="A default test block", + limit=1000, + metadata_={"type": "test"}, + ) + block = block_manager.create_or_update_block(block_data, actor=default_user) + yield block + + +@pytest.fixture +def other_block(server: SyncServer, default_user): + """Fixture to create and return another block.""" + block_manager = BlockManager() + block_data = PydanticBlock( + label="other_label", + value="Other Block Content", + description="Another test block", + limit=500, + metadata_={"type": "test"}, + ) + block = block_manager.create_or_update_block(block_data, actor=default_user) + yield block + + @pytest.fixture(scope="module") def server(): config = LettaConfig.load() @@ -561,7 +594,7 @@ def test_update_block_limit(server: SyncServer, default_user): except Exception: pass - block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user, limit=limit) + block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user) # Retrieve the updated block updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0] # Assertions to verify the update @@ -1018,3 +1051,85 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture, # Assertions to verify correct retrieval assert retrieved_env_var.id == sandbox_env_var_fixture.id assert retrieved_env_var.key == sandbox_env_var_fixture.key + + +# ====================================================================================================================== +# BlocksAgentsManager Tests +# ====================================================================================================================== +def test_add_block_to_agent(server, sarah_agent, default_user, default_block): + block_association = server.blocks_agents_manager.add_block_to_agent( + agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label + ) + + assert block_association.agent_id == sarah_agent.id + assert block_association.block_id == default_block.id + assert block_association.block_label == default_block.label + + +def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user): + with pytest.raises(DBAPIError, match="violates foreign key constraint .*fk_block_id_label"): + server.blocks_agents_manager.add_block_to_agent( + agent_id=sarah_agent.id, block_id="nonexistent_block", block_label="nonexistent_label" + ) + + +def test_add_block_to_agent_duplicate_label(server, sarah_agent, default_user, default_block, other_block): + server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) + + with pytest.warns(UserWarning, match=f"Block label '{default_block.label}' already exists for agent '{sarah_agent.id}'"): + server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=default_block.label) + + +def test_remove_block_with_label_from_agent(server, sarah_agent, default_user, default_block): + server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) + + removed_block = server.blocks_agents_manager.remove_block_with_label_from_agent( + agent_id=sarah_agent.id, block_label=default_block.label + ) + + assert removed_block.block_label == default_block.label + assert removed_block.block_id == default_block.id + assert removed_block.agent_id == sarah_agent.id + + with pytest.raises(ValueError, match=f"Block label '{default_block.label}' not found for agent '{sarah_agent.id}'"): + server.blocks_agents_manager.remove_block_with_label_from_agent(agent_id=sarah_agent.id, block_label=default_block.label) + + +def test_update_block_id_for_agent(server, sarah_agent, default_user, default_block, other_block): + server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) + + updated_block = server.blocks_agents_manager.update_block_id_for_agent( + agent_id=sarah_agent.id, block_label=default_block.label, new_block_id=other_block.id + ) + + assert updated_block.block_id == other_block.id + assert updated_block.block_label == default_block.label + assert updated_block.agent_id == sarah_agent.id + + +def test_list_block_ids_for_agent(server, sarah_agent, default_user, default_block, other_block): + server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) + server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=other_block.id, block_label=other_block.label) + + retrieved_block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=sarah_agent.id) + + assert set(retrieved_block_ids) == {default_block.id, other_block.id} + + +def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_user, default_block): + server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label) + server.blocks_agents_manager.add_block_to_agent(agent_id=charles_agent.id, block_id=default_block.id, block_label=default_block.label) + + agent_ids = server.blocks_agents_manager.list_agent_ids_with_block(block_id=default_block.id) + + assert sarah_agent.id in agent_ids + assert charles_agent.id in agent_ids + assert len(agent_ids) == 2 + + +def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block): + block_manager = BlockManager() + block_manager.delete_block(block_id=default_block.id, actor=default_user) + + with pytest.raises(DBAPIError, match='insert or update on table "blocks_agents" violates foreign key constraint'): + server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)