diff --git a/.github/workflows/docker-integration-tests.yaml b/.github/workflows/docker-integration-tests.yaml index 6357afec4b..a6683446ef 100644 --- a/.github/workflows/docker-integration-tests.yaml +++ b/.github/workflows/docker-integration-tests.yaml @@ -56,7 +56,7 @@ jobs: run: | pipx install poetry==1.8.2 poetry install -E dev -E postgres - poetry run pytest -s tests/test_client.py + poetry run pytest -s tests/test_client_legacy.py - name: Print docker logs if tests fail if: failure() diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1d658c660e..42ae9e6d3b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -6,12 +6,12 @@ env: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} + E2B_API_KEY: ${{ secrets.E2B_API_KEY }} on: push: branches: [ main ] pull_request: - branches: [ main ] jobs: run-core-unit-tests: @@ -21,14 +21,15 @@ jobs: fail-fast: false matrix: test_suite: - - "test_local_client.py" - "test_client.py" + - "test_local_client.py" + - "test_client_legacy.py" - "test_server.py" - "test_managers.py" - - "test_tools.py" - "test_o1_agent.py" - "test_tool_rule_solver.py" - "test_agent_tool_graph.py" + - "test_tool_execution_sandbox.py" - "test_utils.py" services: qdrant: @@ -58,7 +59,7 @@ jobs: with: python-version: "3.12" poetry-version: "1.8.2" - install-args: "-E dev -E postgres -E milvus -E external-tools -E tests" + install-args: "-E dev -E postgres -E milvus -E external-tools -E tests -E cloud-tool-sandbox" - name: Migrate database env: LETTA_PG_PORT: 5432 @@ -111,7 +112,7 @@ jobs: with: python-version: "3.12" poetry-version: "1.8.2" - install-args: "-E dev -E postgres -E milvus -E external-tools -E tests" + install-args: "-E dev -E postgres -E milvus -E external-tools -E tests -E cloud-tool-sandbox" - name: Migrate database env: LETTA_PG_PORT: 5432 @@ -132,4 +133,4 @@ jobs: LETTA_SERVER_PASS: test_server_token PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} run: | - poetry run pytest -s -vv -k "not test_utils.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_perfomance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests + poetry run pytest -s -vv -k "not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests diff --git a/.gitignore b/.gitignore index f9330dd9c8..1fcffd8a4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1018,8 +1018,3 @@ pgdata/ letta/.pytest_cache/ memgpy/pytest.ini **/**/pytest_cache - - -# local sandbox venvs -letta/services/tool_sandbox_env/* -tests/test_tool_sandbox/* diff --git a/alembic/versions/f81ceea2c08d_create_sandbox_config_and_sandbox_env_.py b/alembic/versions/f81ceea2c08d_create_sandbox_config_and_sandbox_env_.py new file mode 
100644 index 0000000000..55332bfc15 --- /dev/null +++ b/alembic/versions/f81ceea2c08d_create_sandbox_config_and_sandbox_env_.py @@ -0,0 +1,73 @@ +"""Create sandbox config and sandbox env var tables + +Revision ID: f81ceea2c08d +Revises: c85a3d07c028 +Create Date: 2024-11-14 17:51:27.263561 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f81ceea2c08d" +down_revision: Union[str, None] = "f7507eab4bb9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "sandbox_configs", + sa.Column("id", sa.String(), nullable=False), + sa.Column("type", sa.Enum("E2B", "LOCAL", name="sandboxtype"), nullable=False), + sa.Column("config", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("type", "organization_id", name="uix_type_organization"), + ) + op.create_table( + "sandbox_environment_variables", + sa.Column("id", sa.String(), nullable=False), + sa.Column("key", sa.String(), nullable=False), + sa.Column("value", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("sandbox_config_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.ForeignKeyConstraint( + ["sandbox_config_id"], + ["sandbox_configs.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("key", "sandbox_config_id", name="uix_key_sandbox_config"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("sandbox_environment_variables") + op.drop_table("sandbox_configs") + # ### end Alembic commands ### diff --git a/examples/docs/tools.py b/examples/docs/tools.py index 382e4520e4..3682d24825 100644 --- a/examples/docs/tools.py +++ b/examples/docs/tools.py @@ -8,7 +8,7 @@ # define a function with a docstring -def roll_d20(self) -> str: +def roll_d20() -> str: """ Simulate the roll of a 20-sided die (d20). 
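The `examples/docs/tools.py` hunk above drops the stray `self` parameter so the example tool is a plain, stateless function that can be schema-generated and executed in the new sandbox. The function body is cut off by the diff context; a plausible completed version follows, where the implementation details are an assumption for illustration:

```python
import random


def roll_d20() -> str:
    """
    Simulate the roll of a 20-sided die (d20).

    Returns:
        str: A message reporting the rolled result.
    """
    # Assumed body: draw a uniform integer in [1, 20] and report it as text,
    # since sandboxed tools return plain values rather than touching agent state.
    dice_roll_outcome = random.randint(1, 20)
    return f"You rolled a {dice_roll_outcome}"
```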
diff --git a/examples/tool_rule_usage.py b/examples/tool_rule_usage.py index c1079bc1bf..63ae5248dc 100644 --- a/examples/tool_rule_usage.py +++ b/examples/tool_rule_usage.py @@ -31,14 +31,14 @@ """Contrived tools for this test case""" -def first_secret_word(self: "Agent"): +def first_secret_word(): """ Call this to retrieve the first secret word, which you will need for the second_secret_word function. """ return "v0iq020i0g" -def second_secret_word(self: "Agent", prev_secret_word: str): +def second_secret_word(prev_secret_word: str): """ Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error. @@ -51,7 +51,7 @@ def second_secret_word(self: "Agent", prev_secret_word: str): return "4rwp2b4gxq" -def third_secret_word(self: "Agent", prev_secret_word: str): +def third_secret_word(prev_secret_word: str): """ Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error. @@ -64,7 +64,7 @@ def third_secret_word(self: "Agent", prev_secret_word: str): return "hj2hwibbqm" -def fourth_secret_word(self: "Agent", prev_secret_word: str): +def fourth_secret_word(prev_secret_word: str): """ Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error. @@ -77,7 +77,7 @@ def fourth_secret_word(self: "Agent", prev_secret_word: str): return "banana" -def auto_error(self: "Agent"): +def auto_error(): """ If you call this function, it will throw an error automatically. """ diff --git a/letta/agent.py b/letta/agent.py index d80fff045d..a850a8da0c 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -9,6 +9,7 @@ from letta.agent_store.storage import StorageConnector from letta.constants import ( + BASE_TOOLS, CLI_WARNING_PREFIX, FIRST_MESSAGE_ATTEMPTS, FUNC_FAILED_HEARTBEAT_MESSAGE, @@ -49,6 +50,7 @@ from letta.schemas.usage import LettaUsageStatistics from letta.services.block_manager import BlockManager from letta.services.source_manager import SourceManager +from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.user_manager import UserManager from letta.streaming_interface import StreamingRefreshCLIInterface from letta.system import ( @@ -725,9 +727,27 @@ def _handle_ai_response( if isinstance(function_args[name], dict): function_args[name] = spec[name](**function_args[name]) - function_args["self"] = self # need to attach self to arg since it's dynamically linked + # TODO: This needs to be rethought, how do we allow functions that modify agent state/db? 
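The `letta/agent.py` hunk that begins here (and continues below) is the core behavioral change of this diff: tools in `BASE_TOOLS` keep running in-process with `self` attached, while every other tool executes inside a `ToolExecutionSandbox`. A condensed, hedged sketch of the new control flow, pulled out of `_handle_ai_response` into a standalone helper for readability (`agent` stands in for `self`; error handling and message packaging are omitted):

```python
from letta.constants import BASE_TOOLS
from letta.services.tool_execution_sandbox import ToolExecutionSandbox


def dispatch_tool_call(agent, function_to_call, function_name: str, function_args: dict):
    if function_name in BASE_TOOLS:
        # Base tools read/write agent state directly, so they stay in-process
        # with the agent instance injected as the dynamically linked `self`.
        function_args["self"] = agent
        return function_to_call(**function_args)

    # Everything else runs in a sandbox keyed to the agent's user.
    result = ToolExecutionSandbox(function_name, function_args, agent.agent_state.user_id).run(
        agent_state=agent.agent_state
    )
    if result.agent_state is not None and result.agent_state != agent.agent_state:
        # A sandboxed tool may have mutated agent state (e.g. core memory),
        # so adopt the returned copy and rebuild the in-context memory view.
        agent.agent_state = result.agent_state
        agent.memory = agent.agent_state.memory
        agent.rebuild_memory()
    return result.func_return
```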
+ # TODO: There should probably be two types of tools: stateless/stateful + + if function_name in BASE_TOOLS: + function_args["self"] = self # need to attach self to arg since it's dynamically linked + function_response = function_to_call(**function_args) + else: + # execute tool in a sandbox + # TODO: allow agent_state to specify which sandbox to execute tools in + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run( + agent_state=self.agent_state + ) + function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state + # update agent state + if self.agent_state != updated_agent_state and updated_agent_state is not None: + self.agent_state = updated_agent_state + self.memory = self.agent_state.memory # TODO: don't duplicate + + # rebuild memory + self.rebuild_memory() - function_response = function_to_call(**function_args) if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: # with certain functions we rely on the paging mechanism to handle overflow truncate = False @@ -747,6 +767,7 @@ def _handle_ai_response( error_msg_user = f"{error_msg}\n{traceback.format_exc()}" printd(error_msg_user) function_response = package_function_response(False, error_msg) + # TODO: truncate error message somehow messages.append( Message.dict_to_message( agent_id=self.agent_state.id, diff --git a/letta/client/client.py b/letta/client/client.py index 8a3d053850..d7b7320f74 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -39,6 +39,16 @@ from letta.schemas.openai.chat_completions import ToolCall from letta.schemas.organization import Organization from letta.schemas.passage import Passage +from letta.schemas.sandbox_config import ( + E2BSandboxConfig, + LocalSandboxConfig, + SandboxConfig, + SandboxConfigCreate, + SandboxConfigUpdate, + SandboxEnvironmentVariable, + SandboxEnvironmentVariableCreate, + SandboxEnvironmentVariableUpdate, +) from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.schemas.tool_rule import BaseToolRule @@ -296,6 +306,112 @@ def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> def delete_org(self, org_id: str) -> Organization: raise NotImplementedError + def create_sandbox_config(self, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: + """ + Create a new sandbox configuration. + + Args: + config (Union[LocalSandboxConfig, E2BSandboxConfig]): The sandbox settings. + + Returns: + SandboxConfig: The created sandbox configuration. + """ + raise NotImplementedError + + def update_sandbox_config(self, sandbox_config_id: str, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: + """ + Update an existing sandbox configuration. + + Args: + sandbox_config_id (str): The ID of the sandbox configuration to update. + config (Union[LocalSandboxConfig, E2BSandboxConfig]): The updated sandbox settings. + + Returns: + SandboxConfig: The updated sandbox configuration. + """ + raise NotImplementedError + + def delete_sandbox_config(self, sandbox_config_id: str) -> None: + """ + Delete a sandbox configuration. + + Args: + sandbox_config_id (str): The ID of the sandbox configuration to delete. + """ + raise NotImplementedError + + def list_sandbox_configs(self, limit: int = 50, cursor: Optional[str] = None) -> List[SandboxConfig]: + """ + List all sandbox configurations. 
+ + Args: + limit (int, optional): The maximum number of sandbox configurations to return. Defaults to 50. + cursor (Optional[str], optional): The pagination cursor for retrieving the next set of results. + + Returns: + List[SandboxConfig]: A list of sandbox configurations. + """ + raise NotImplementedError + + def create_sandbox_env_var( + self, sandbox_config_id: str, key: str, value: str, description: Optional[str] = None + ) -> SandboxEnvironmentVariable: + """ + Create a new environment variable for a sandbox configuration. + + Args: + sandbox_config_id (str): The ID of the sandbox configuration to associate the environment variable with. + key (str): The name of the environment variable. + value (str): The value of the environment variable. + description (Optional[str], optional): A description of the environment variable. Defaults to None. + + Returns: + SandboxEnvironmentVariable: The created environment variable. + """ + raise NotImplementedError + + def update_sandbox_env_var( + self, env_var_id: str, key: Optional[str] = None, value: Optional[str] = None, description: Optional[str] = None + ) -> SandboxEnvironmentVariable: + """ + Update an existing environment variable. + + Args: + env_var_id (str): The ID of the environment variable to update. + key (Optional[str], optional): The updated name of the environment variable. Defaults to None. + value (Optional[str], optional): The updated value of the environment variable. Defaults to None. + description (Optional[str], optional): The updated description of the environment variable. Defaults to None. + + Returns: + SandboxEnvironmentVariable: The updated environment variable. + """ + raise NotImplementedError + + def delete_sandbox_env_var(self, env_var_id: str) -> None: + """ + Delete an environment variable by its ID. + + Args: + env_var_id (str): The ID of the environment variable to delete. + """ + raise NotImplementedError + + def list_sandbox_env_vars( + self, sandbox_config_id: str, limit: int = 50, cursor: Optional[str] = None + ) -> List[SandboxEnvironmentVariable]: + """ + List all environment variables associated with a sandbox configuration. + + Args: + sandbox_config_id (str): The ID of the sandbox configuration to retrieve environment variables for. + limit (int, optional): The maximum number of environment variables to return. Defaults to 50. + cursor (Optional[str], optional): The pagination cursor for retrieving the next set of results. + + Returns: + List[SandboxEnvironmentVariable]: A list of environment variables. + """ + raise NotImplementedError + class RESTClient(AbstractClient): """ @@ -1565,6 +1681,114 @@ def delete_org(self, org_id: str) -> Organization: # Parse and return the deleted organization return Organization(**response.json()) + def create_sandbox_config(self, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: + """ + Create a new sandbox configuration. + """ + payload = { + "config": config.model_dump(), + } + response = requests.post(f"{self.base_url}/{self.api_prefix}/sandbox-config", headers=self.headers, json=payload) + if response.status_code != 200: + raise ValueError(f"Failed to create sandbox config: {response.text}") + return SandboxConfig(**response.json()) + + def update_sandbox_config(self, sandbox_config_id: str, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: + """ + Update an existing sandbox configuration. 
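A hedged usage sketch for the sandbox-config client API defined in this file. `create_client` is assumed to return a `LocalClient` or `RESTClient` as elsewhere in Letta, and the template id and sandbox directory are placeholders:

```python
from letta import create_client
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig

client = create_client()

# The ORM enforces one config per (type, organization), so "creating" an
# E2B config a second time updates the existing row in place.
e2b_config = client.create_sandbox_config(
    E2BSandboxConfig(timeout=300, template="my-template-id", pip_requirements=["requests"])
)
local_config = client.create_sandbox_config(LocalSandboxConfig(sandbox_dir="/tmp/letta-sandbox"))

for cfg in client.list_sandbox_configs(limit=10):
    print(cfg.id, cfg.type)
```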
+ """ + payload = { + "config": config.model_dump(), + } + response = requests.patch( + f"{self.base_url}/{self.api_prefix}/sandbox-config/{sandbox_config_id}", + headers=self.headers, + json=payload, + ) + if response.status_code != 200: + raise ValueError(f"Failed to update sandbox config with ID '{sandbox_config_id}': {response.text}") + return SandboxConfig(**response.json()) + + def delete_sandbox_config(self, sandbox_config_id: str) -> None: + """ + Delete a sandbox configuration. + """ + response = requests.delete(f"{self.base_url}/{self.api_prefix}/sandbox-config/{sandbox_config_id}", headers=self.headers) + if response.status_code == 404: + raise ValueError(f"Sandbox config with ID '{sandbox_config_id}' does not exist") + elif response.status_code != 204: + raise ValueError(f"Failed to delete sandbox config with ID '{sandbox_config_id}': {response.text}") + + def list_sandbox_configs(self, limit: int = 50, cursor: Optional[str] = None) -> List[SandboxConfig]: + """ + List all sandbox configurations. + """ + params = {"limit": limit, "cursor": cursor} + response = requests.get(f"{self.base_url}/{self.api_prefix}/sandbox-config", headers=self.headers, params=params) + if response.status_code != 200: + raise ValueError(f"Failed to list sandbox configs: {response.text}") + return [SandboxConfig(**config_data) for config_data in response.json()] + + def create_sandbox_env_var( + self, sandbox_config_id: str, key: str, value: str, description: Optional[str] = None + ) -> SandboxEnvironmentVariable: + """ + Create a new environment variable for a sandbox configuration. + """ + payload = {"key": key, "value": value, "description": description} + response = requests.post( + f"{self.base_url}/{self.api_prefix}/sandbox-config/{sandbox_config_id}/environment-variable", + headers=self.headers, + json=payload, + ) + if response.status_code != 200: + raise ValueError(f"Failed to create environment variable for sandbox config ID '{sandbox_config_id}': {response.text}") + return SandboxEnvironmentVariable(**response.json()) + + def update_sandbox_env_var( + self, env_var_id: str, key: Optional[str] = None, value: Optional[str] = None, description: Optional[str] = None + ) -> SandboxEnvironmentVariable: + """ + Update an existing environment variable. + """ + payload = {k: v for k, v in {"key": key, "value": value, "description": description}.items() if v is not None} + response = requests.patch( + f"{self.base_url}/{self.api_prefix}/sandbox-config/environment-variable/{env_var_id}", + headers=self.headers, + json=payload, + ) + if response.status_code != 200: + raise ValueError(f"Failed to update environment variable with ID '{env_var_id}': {response.text}") + return SandboxEnvironmentVariable(**response.json()) + + def delete_sandbox_env_var(self, env_var_id: str) -> None: + """ + Delete an environment variable by its ID. + """ + response = requests.delete( + f"{self.base_url}/{self.api_prefix}/sandbox-config/environment-variable/{env_var_id}", headers=self.headers + ) + if response.status_code == 404: + raise ValueError(f"Environment variable with ID '{env_var_id}' does not exist") + elif response.status_code != 204: + raise ValueError(f"Failed to delete environment variable with ID '{env_var_id}': {response.text}") + + def list_sandbox_env_vars( + self, sandbox_config_id: str, limit: int = 50, cursor: Optional[str] = None + ) -> List[SandboxEnvironmentVariable]: + """ + List all environment variables associated with a sandbox configuration. 
+ """ + params = {"limit": limit, "cursor": cursor} + response = requests.get( + f"{self.base_url}/{self.api_prefix}/sandbox-config/{sandbox_config_id}/environment-variable", + headers=self.headers, + params=params, + ) + if response.status_code != 200: + raise ValueError(f"Failed to list environment variables for sandbox config ID '{sandbox_config_id}': {response.text}") + return [SandboxEnvironmentVariable(**var_data) for var_data in response.json()] + def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: # @router.patch("/{agent_id}/memory/label", response_model=Memory, operation_id="update_agent_memory_label") @@ -2821,6 +3045,72 @@ def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> def delete_org(self, org_id: str) -> Organization: return self.server.organization_manager.delete_organization_by_id(org_id=org_id) + def create_sandbox_config(self, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: + """ + Create a new sandbox configuration. + """ + config_create = SandboxConfigCreate(config=config) + return self.server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create=config_create, actor=self.user) + + def update_sandbox_config(self, sandbox_config_id: str, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: + """ + Update an existing sandbox configuration. + """ + sandbox_update = SandboxConfigUpdate(config=config) + return self.server.sandbox_config_manager.update_sandbox_config( + sandbox_config_id=sandbox_config_id, sandbox_update=sandbox_update, actor=self.user + ) + + def delete_sandbox_config(self, sandbox_config_id: str) -> None: + """ + Delete a sandbox configuration. + """ + return self.server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id=sandbox_config_id, actor=self.user) + + def list_sandbox_configs(self, limit: int = 50, cursor: Optional[str] = None) -> List[SandboxConfig]: + """ + List all sandbox configurations. + """ + return self.server.sandbox_config_manager.list_sandbox_configs(actor=self.user, limit=limit, cursor=cursor) + + def create_sandbox_env_var( + self, sandbox_config_id: str, key: str, value: str, description: Optional[str] = None + ) -> SandboxEnvironmentVariable: + """ + Create a new environment variable for a sandbox configuration. + """ + env_var_create = SandboxEnvironmentVariableCreate(key=key, value=value, description=description) + return self.server.sandbox_config_manager.create_sandbox_env_var( + env_var_create=env_var_create, sandbox_config_id=sandbox_config_id, actor=self.user + ) + + def update_sandbox_env_var( + self, env_var_id: str, key: Optional[str] = None, value: Optional[str] = None, description: Optional[str] = None + ) -> SandboxEnvironmentVariable: + """ + Update an existing environment variable. + """ + env_var_update = SandboxEnvironmentVariableUpdate(key=key, value=value, description=description) + return self.server.sandbox_config_manager.update_sandbox_env_var( + env_var_id=env_var_id, env_var_update=env_var_update, actor=self.user + ) + + def delete_sandbox_env_var(self, env_var_id: str) -> None: + """ + Delete an environment variable by its ID. 
+ """ + return self.server.sandbox_config_manager.delete_sandbox_env_var(env_var_id=env_var_id, actor=self.user) + + def list_sandbox_env_vars( + self, sandbox_config_id: str, limit: int = 50, cursor: Optional[str] = None + ) -> List[SandboxEnvironmentVariable]: + """ + List all environment variables associated with a sandbox configuration. + """ + return self.server.sandbox_config_manager.list_sandbox_env_vars( + sandbox_config_id=sandbox_config_id, actor=self.user, limit=limit, cursor=cursor + ) + def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: return self.server.update_agent_memory_label( user_id=self.user_id, agent_id=agent_id, current_block_label=current_label, new_block_label=new_label diff --git a/letta/constants.py b/letta/constants.py index 436b9dcc4a..0cafeb14b6 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -36,14 +36,8 @@ DEFAULT_HUMAN = "basic" DEFAULT_PRESET = "memgpt_chat" -# Tools -BASE_TOOLS = [ - "send_message", - "conversation_search", - "conversation_search_date", - "archival_memory_insert", - "archival_memory_search", -] +# Base tools that cannot be edited, as they access agent state directly +BASE_TOOLS = ["send_message", "conversation_search", "conversation_search_date", "archival_memory_insert", "archival_memory_search"] # The name of the tool used to send message to the user # May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index 430720034b..92894d4c45 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -93,9 +93,14 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[ for param in sig.parameters.values(): # Exclude 'self' parameter + # TODO: eventually remove this (only applies to BASE_TOOLS) if param.name == "self": continue + # exclude 'agent_state' parameter + if param.name == "agent_state": + continue + # Assert that the parameter has a type annotation if param.annotation == inspect.Parameter.empty: raise TypeError(f"Parameter '{param.name}' in function '{function.__name__}' lacks a type annotation") @@ -129,6 +134,7 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[ # append the heartbeat # TODO: don't hard-code + # TODO: if terminal, don't include this if function.__name__ not in ["send_message", "pause_heartbeats"]: schema["parameters"]["properties"]["request_heartbeat"] = { "type": "boolean", diff --git a/letta/log.py b/letta/log.py index 8d947c509c..fbac3830b0 100644 --- a/letta/log.py +++ b/letta/log.py @@ -23,12 +23,10 @@ def _setup_logfile() -> "Path": # TODO: production logging should be much less invasive DEVELOPMENT_LOGGING = { "version": 1, - "disable_existing_loggers": True, + "disable_existing_loggers": False, # Allow capturing from all loggers "formatters": { "standard": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}, - "no_datetime": { - "format": "%(name)s - %(levelname)s - %(message)s", - }, + "no_datetime": {"format": "%(name)s - %(levelname)s - %(message)s"}, }, "handlers": { "console": { @@ -46,14 +44,14 @@ def _setup_logfile() -> "Path": "formatter": "standard", }, }, + "root": { # Root logger handles all logs + "level": logging.DEBUG if settings.debug else logging.INFO, + "handlers": ["console", "file"], + }, "loggers": { "Letta": { "level": logging.DEBUG if settings.debug else logging.INFO, 
- "handlers": [ - "console", - "file", - ], - "propagate": False, + "propagate": True, # Let logs bubble up to root }, "uvicorn": { "level": "CRITICAL", diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 1b1df14985..eeed7c2e86 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -2,6 +2,7 @@ from letta.orm.block import Block from letta.orm.file import FileMetadata from letta.orm.organization import Organization +from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.tool import Tool from letta.orm.user import User diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py index d49b868b19..0d0b576f5b 100644 --- a/letta/orm/mixins.py +++ b/letta/orm/mixins.py @@ -37,3 +37,11 @@ class SourceMixin(Base): __abstract__ = True source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id")) + + +class SandboxConfigMixin(Base): + """Mixin for models that belong to a SandboxConfig.""" + + __abstract__ = True + + sandbox_config_id: Mapped[str] = mapped_column(String, ForeignKey("sandbox_configs.id")) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index c4a059c555..0cd32f9891 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -2,12 +2,12 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship -from letta.orm.file import FileMetadata from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.organization import Organization as PydanticOrganization if TYPE_CHECKING: + from letta.orm.file import FileMetadata from letta.orm.tool import Tool from letta.orm.user import User @@ -27,6 +27,13 @@ class Organization(SqlalchemyBase): sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan") files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan") + sandbox_configs: Mapped[List["SandboxConfig"]] = relationship( + "SandboxConfig", back_populates="organization", cascade="all, delete-orphan" + ) + sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship( + "SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan" + ) + # TODO: Map these relationships later when we actually make these models # below is just a suggestion # agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/orm/sandbox_config.py b/letta/orm/sandbox_config.py new file mode 100644 index 0000000000..aa8e07dc93 --- /dev/null +++ b/letta/orm/sandbox_config.py @@ -0,0 +1,56 @@ +from typing import TYPE_CHECKING, Dict, List, Optional + +from sqlalchemy import JSON +from sqlalchemy import Enum as SqlEnum +from sqlalchemy import String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig +from letta.schemas.sandbox_config import ( + SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable, +) +from letta.schemas.sandbox_config import SandboxType + +if TYPE_CHECKING: + from letta.orm.organization import Organization + + +class SandboxConfig(SqlalchemyBase, 
OrganizationMixin): + """ORM model for sandbox configurations with JSON storage for arbitrary config data.""" + + __tablename__ = "sandbox_configs" + __pydantic_model__ = PydanticSandboxConfig + + # For now, we only allow one type of sandbox config per organization + __table_args__ = (UniqueConstraint("type", "organization_id", name="uix_type_organization"),) + + id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False) + type: Mapped[SandboxType] = mapped_column(SqlEnum(SandboxType), nullable=False, doc="The type of sandbox.") + config: Mapped[Dict] = mapped_column(JSON, nullable=False, doc="The JSON configuration data.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="sandbox_configs") + sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship( + "SandboxEnvironmentVariable", back_populates="sandbox_config", cascade="all, delete-orphan" + ) + + +class SandboxEnvironmentVariable(SqlalchemyBase, OrganizationMixin, SandboxConfigMixin): + """ORM model for environment variables associated with sandboxes.""" + + __tablename__ = "sandbox_environment_variables" + __pydantic_model__ = PydanticSandboxEnvironmentVariable + + # We cannot have duplicate key names in the same sandbox, the env var would get overwritten + __table_args__ = (UniqueConstraint("key", "sandbox_config_id", name="uix_key_sandbox_config"),) + + id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False) + key: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the environment variable.") + value: Mapped[str] = mapped_column(String, nullable=False, doc="The value of the environment variable.") + description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="An optional description of the environment variable.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="sandbox_environment_variables") + sandbox_config: Mapped["SandboxConfig"] = relationship("SandboxConfig", back_populates="sandbox_environment_variables") diff --git a/letta/schemas/letta_base.py b/letta/schemas/letta_base.py index 9f1af6e99a..455234867e 100644 --- a/letta/schemas/letta_base.py +++ b/letta/schemas/letta_base.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime from logging import getLogger from typing import Optional from uuid import UUID @@ -80,3 +81,11 @@ def allow_bare_uuids(cls, v, values): logger.warning(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!") return f"{cls.__id_prefix__}-{v}" return v + + +class OrmMetadataBase(LettaBase): + # metadata fields + created_by_id: Optional[str] = Field(None, description="The id of the user that made this object.") + last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this object.") + created_at: Optional[datetime] = Field(None, description="The timestamp when the object was created.") + updated_at: Optional[datetime] = Field(None, description="The timestamp when the object was last updated.") diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 82aae73878..1833805568 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -5,7 +5,7 @@ # Forward referencing to avoid circular import with Agent -> Memory -> Agent if TYPE_CHECKING: - from letta.agent import Agent + pass from letta.schemas.block import Block from letta.schemas.message import Message @@ -229,7 +229,7 @@ def __init__(self, blocks: 
List[Block] = []): assert block.label is not None and block.label != "", "each existing chat block must have a name" self.link_block(block=block) - def core_memory_append(self: "Agent", label: str, content: str) -> Optional[str]: # type: ignore + def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore """ Append to the contents of core memory. @@ -240,12 +240,12 @@ def core_memory_append(self: "Agent", label: str, content: str) -> Optional[str] Returns: Optional[str]: None is always returned as this function does not produce a response. """ - current_value = str(self.memory.get_block(label).value) + current_value = str(agent_state.memory.get_block(label).value) new_value = current_value + "\n" + str(content) - self.memory.update_block_value(label=label, value=new_value) + agent_state.memory.update_block_value(label=label, value=new_value) return None - def core_memory_replace(self: "Agent", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore + def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore """ Replace the contents of core memory. To delete memories, use an empty string for new_content. @@ -257,11 +257,11 @@ def core_memory_replace(self: "Agent", label: str, old_content: str, new_content Returns: Optional[str]: None is always returned as this function does not produce a response. """ - current_value = str(self.memory.get_block(label).value) + current_value = str(agent_state.memory.get_block(label).value) if old_content not in current_value: raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'") new_value = current_value.replace(str(old_content), str(new_content)) - self.memory.update_block_value(label=label, value=new_value) + agent_state.memory.update_block_value(label=label, value=new_value) return None diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py new file mode 100644 index 0000000000..74340ebeb8 --- /dev/null +++ b/letta/schemas/sandbox_config.py @@ -0,0 +1,114 @@ +import hashlib +import json +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field + +from letta.schemas.agent import AgentState +from letta.schemas.letta_base import LettaBase, OrmMetadataBase + + +# Sandbox Config +class SandboxType(str, Enum): + E2B = "e2b" + LOCAL = "local" + + +class SandboxRunResult(BaseModel): + func_return: Optional[Any] = Field(None, description="The function return object") + agent_state: Optional[AgentState] = Field(None, description="The agent state") + stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. 
prints, logs) from the function invocation") + sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox") + + +class LocalSandboxConfig(BaseModel): + sandbox_dir: str = Field(..., description="Directory for the sandbox environment.") + + @property + def type(self) -> "SandboxType": + return SandboxType.LOCAL + + +class E2BSandboxConfig(BaseModel): + timeout: int = Field(5 * 60, description="Time limit for the sandbox (in seconds).") + template: Optional[str] = Field(None, description="The E2B template id (docker image).") + pip_requirements: Optional[List[str]] = Field(None, description="A list of pip packages to install on the E2B Sandbox") + + @property + def type(self) -> "SandboxType": + return SandboxType.E2B + + +class SandboxConfigBase(OrmMetadataBase): + __id_prefix__ = "sandbox" + + +class SandboxConfig(SandboxConfigBase): + id: str = SandboxConfigBase.generate_id_field() + type: SandboxType = Field(None, description="The type of sandbox.") + organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the sandbox.") + config: Dict = Field(default_factory=lambda: {}, description="The JSON sandbox settings data.") + + def get_e2b_config(self) -> E2BSandboxConfig: + return E2BSandboxConfig(**self.config) + + def get_local_config(self) -> LocalSandboxConfig: + return LocalSandboxConfig(**self.config) + + def fingerprint(self) -> str: + # Only take into account type, org_id, and the config items + # Canonicalize input data into JSON with sorted keys + hash_input = json.dumps( + { + "type": self.type.value, + "organization_id": self.organization_id, + "config": self.config, + }, + sort_keys=True, # Ensure stable ordering + separators=(",", ":"), # Minimize serialization differences + ) + + # Compute SHA-256 hash + hash_digest = hashlib.sha256(hash_input.encode("utf-8")).digest() + + # Convert the digest to an integer for compatibility with Python's hash requirements + return str(int.from_bytes(hash_digest, byteorder="big")) + + +class SandboxConfigCreate(LettaBase): + config: Union[LocalSandboxConfig, E2BSandboxConfig] = Field(..., description="The configuration for the sandbox.") + + +class SandboxConfigUpdate(LettaBase): + """Pydantic model for updating SandboxConfig fields.""" + + config: Union[LocalSandboxConfig, E2BSandboxConfig] = Field(None, description="The JSON configuration data for the sandbox.") + + +# Environment Variable +class SandboxEnvironmentVariableBase(OrmMetadataBase): + __id_prefix__ = "sandbox-env" + + +class SandboxEnvironmentVariable(SandboxEnvironmentVariableBase): + id: str = SandboxEnvironmentVariableBase.generate_id_field() + key: str = Field(..., description="The name of the environment variable.") + value: str = Field(..., description="The value of the environment variable.") + description: Optional[str] = Field(None, description="An optional description of the environment variable.") + sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.") + organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.") + + +class SandboxEnvironmentVariableCreate(LettaBase): + key: str = Field(..., description="The name of the environment variable.") + value: str = Field(..., description="The value of the environment variable.") + description: Optional[str] = Field(None, description="An optional description of the environment variable.") + + 
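`SandboxConfig.fingerprint` above canonicalizes the type, organization, and config into sorted-key JSON before hashing, and `SandboxRunResult` records that fingerprint, presumably so a cached sandbox can be detected as stale when its config changes. A minimal check of the property being relied on, assuming the schema exactly as defined above:

```python
from letta.schemas.sandbox_config import SandboxConfig, SandboxType

a = SandboxConfig(type=SandboxType.E2B, organization_id="org-123",
                  config={"timeout": 300, "template": "t1"})
b = SandboxConfig(type=SandboxType.E2B, organization_id="org-123",
                  config={"template": "t1", "timeout": 300})

# sort_keys plus fixed separators canonicalize the JSON, so dict insertion
# order is irrelevant; ids and timestamps are deliberately excluded from the hash.
assert a.fingerprint() == b.fingerprint()
```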
+class SandboxEnvironmentVariableUpdate(LettaBase): + """Pydantic model for updating SandboxEnvironmentVariable fields.""" + + key: Optional[str] = Field(None, description="The name of the environment variable.") + value: Optional[str] = Field(None, description="The value of the environment variable.") + description: Optional[str] = Field(None, description="An optional description of the environment variable.") diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index 6589c5f41d..764a78a3e2 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -3,15 +3,10 @@ from letta.server.rest_api.routers.v1.health import router as health_router from letta.server.rest_api.routers.v1.jobs import router as jobs_router from letta.server.rest_api.routers.v1.llms import router as llm_router +from letta.server.rest_api.routers.v1.sandbox_configs import ( + router as sandbox_configs_router, +) from letta.server.rest_api.routers.v1.sources import router as sources_router from letta.server.rest_api.routers.v1.tools import router as tools_router -ROUTERS = [ - tools_router, - sources_router, - agents_router, - llm_router, - blocks_router, - jobs_router, - health_router, -] +ROUTERS = [tools_router, sources_router, agents_router, llm_router, blocks_router, jobs_router, health_router, sandbox_configs_router] diff --git a/letta/server/rest_api/routers/v1/sandbox_configs.py b/letta/server/rest_api/routers/v1/sandbox_configs.py new file mode 100644 index 0000000000..80640929a0 --- /dev/null +++ b/letta/server/rest_api/routers/v1/sandbox_configs.py @@ -0,0 +1,108 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query + +from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig +from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate +from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar +from letta.schemas.sandbox_config import ( + SandboxEnvironmentVariableCreate, + SandboxEnvironmentVariableUpdate, +) +from letta.server.rest_api.utils import get_letta_server, get_user_id +from letta.server.server import SyncServer + +router = APIRouter(prefix="/sandbox-config", tags=["sandbox-config"]) + + +### Sandbox Config Routes + + +@router.post("/", response_model=PydanticSandboxConfig) +def create_sandbox_config( + config_create: SandboxConfigCreate, + server: SyncServer = Depends(get_letta_server), + user_id: str = Depends(get_user_id), +): + actor = server.get_user_or_default(user_id=user_id) + + return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor) + + +@router.patch("/{sandbox_config_id}", response_model=PydanticSandboxConfig) +def update_sandbox_config( + sandbox_config_id: str, + config_update: SandboxConfigUpdate, + server: SyncServer = Depends(get_letta_server), + user_id: str = Depends(get_user_id), +): + actor = server.get_user_or_default(user_id=user_id) + return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor) + + +@router.delete("/{sandbox_config_id}", status_code=204) +def delete_sandbox_config( + sandbox_config_id: str, + server: SyncServer = Depends(get_letta_server), + user_id: str = Depends(get_user_id), +): + actor = server.get_user_or_default(user_id=user_id) + server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor) + + +@router.get("/", response_model=List[PydanticSandboxConfig]) +def 
list_sandbox_configs( + limit: int = Query(1000, description="Number of results to return"), + cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"), + server: SyncServer = Depends(get_letta_server), + user_id: str = Depends(get_user_id), +): + actor = server.get_user_or_default(user_id=user_id) + return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, cursor=cursor) + + +### Sandbox Environment Variable Routes + + +@router.post("/{sandbox_config_id}/environment-variable", response_model=PydanticEnvVar) +def create_sandbox_env_var( + sandbox_config_id: str, + env_var_create: SandboxEnvironmentVariableCreate, + server: SyncServer = Depends(get_letta_server), + user_id: str = Depends(get_user_id), +): + actor = server.get_user_or_default(user_id=user_id) + return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor) + + +@router.patch("/environment-variable/{env_var_id}", response_model=PydanticEnvVar) +def update_sandbox_env_var( + env_var_id: str, + env_var_update: SandboxEnvironmentVariableUpdate, + server: SyncServer = Depends(get_letta_server), + user_id: str = Depends(get_user_id), +): + actor = server.get_user_or_default(user_id=user_id) + return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor) + + +@router.delete("/environment-variable/{env_var_id}", status_code=204) +def delete_sandbox_env_var( + env_var_id: str, + server: SyncServer = Depends(get_letta_server), + user_id: str = Depends(get_user_id), +): + actor = server.get_user_or_default(user_id=user_id) + server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor) + + +@router.get("/{sandbox_config_id}/environment-variable", response_model=List[PydanticEnvVar]) +def list_sandbox_env_vars( + sandbox_config_id: str, + limit: int = Query(1000, description="Number of results to return"), + cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"), + server: SyncServer = Depends(get_letta_server), + user_id: str = Depends(get_user_id), +): + actor = server.get_user_or_default(user_id=user_id) + return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, cursor=cursor) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 117ce38cab..d51c4661d9 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -2,7 +2,6 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException -from letta.orm.errors import NoResultFound from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer @@ -49,11 +48,10 @@ def get_tool_id( Get a tool ID by name """ actor = server.get_user_or_default(user_id=user_id) - - try: - tool = server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) + tool = server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) + if tool: return tool.id - except NoResultFound: + else: raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} and organization id {actor.organization_id} not found.") diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 27eda76c63..4fd92b5aad 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -5,6 +5,7 @@ from enum import Enum from typing import AsyncGenerator, 
Optional, Union +from fastapi import Header from pydantic import BaseModel from letta.schemas.usage import LettaUsageStatistics @@ -84,5 +85,10 @@ def get_letta_server() -> SyncServer: return server +# Dependency to get user_id from headers +def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optional[str]: + return user_id + + def get_current_interface() -> StreamingServerInterface: return StreamingServerInterface diff --git a/letta/server/server.py b/letta/server/server.py index 2e6438a804..99267176dd 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -78,6 +78,7 @@ from letta.services.agents_tags_manager import AgentsTagsManager from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager +from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager @@ -247,6 +248,7 @@ def __init__( self.block_manager = BlockManager() self.source_manager = SourceManager() self.agents_tags_manager = AgentsTagsManager() + self.sandbox_config_manager = SandboxConfigManager(tool_settings) # Make default user and org if init_with_default_org_and_user: @@ -381,10 +383,11 @@ def _load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterfac tool_objs = [] for name in agent_state.tools: # TODO: This should be a hard failure, but for migration reasons, we patch it for now - try: + tool_obj = self.tool_manager.get_tool_by_name(tool_name=name, actor=actor) + if tool_obj: tool_obj = self.tool_manager.get_tool_by_name(tool_name=name, actor=actor) tool_objs.append(tool_obj) - except NoResultFound: + else: warnings.warn(f"Tried to retrieve a tool with name {name} from the agent_state, but does not exist in tool db.") # set agent_state tools to only the names of the available tools @@ -837,10 +840,10 @@ def create_agent( tool_objs = [] if request.tools: for tool_name in request.tools: - try: - tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) + tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) + if tool_obj: tool_objs.append(tool_obj) - except NoResultFound: + else: warnings.warn(f"Attempted to add a nonexistent tool {tool_name} to agent {request.name}, skipping.") # reset the request.tools to only valid tools request.tools = [t.name for t in tool_objs] diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py new file mode 100644 index 0000000000..48e53f9f5d --- /dev/null +++ b/letta/services/sandbox_config_manager.py @@ -0,0 +1,256 @@ +from pathlib import Path +from typing import Dict, List, Optional + +from letta.log import get_logger +from letta.orm.errors import NoResultFound +from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel +from letta.orm.sandbox_config import SandboxEnvironmentVariable as SandboxEnvVarModel +from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig +from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig +from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate +from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar +from letta.schemas.sandbox_config import ( + SandboxEnvironmentVariableCreate, + SandboxEnvironmentVariableUpdate, + SandboxType, +) +from letta.schemas.user import User as 
PydanticUser +from letta.utils import enforce_types, printd + +logger = get_logger(__name__) + + +class SandboxConfigManager: + """Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable.""" + + def __init__(self, settings): + from letta.server.server import db_context + + self.session_maker = db_context + self.e2b_template_id = settings.e2b_sandbox_template_id + + @enforce_types + def get_or_create_default_sandbox_config(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig: + sandbox_config = self.get_sandbox_config_by_type(sandbox_type, actor=actor) + if not sandbox_config: + logger.info(f"Creating new sandbox config of type {sandbox_type}, none found for organization {actor.organization_id}.") + + # TODO: Add more sandbox types later + if sandbox_type == SandboxType.E2B: + default_config = E2BSandboxConfig(template=self.e2b_template_id).model_dump(exclude_none=True) + else: + default_local_sandbox_path = str(Path(__file__).parent / "tool_sandbox_env") + default_config = LocalSandboxConfig(sandbox_dir=default_local_sandbox_path).model_dump(exclude_none=True) + + sandbox_config = self.create_or_update_sandbox_config(SandboxConfigCreate(config=default_config), actor=actor) + return sandbox_config + + @enforce_types + def create_or_update_sandbox_config(self, sandbox_config_create: SandboxConfigCreate, actor: PydanticUser) -> PydanticSandboxConfig: + """Create or update a sandbox configuration based on the PydanticSandboxConfig schema.""" + config = sandbox_config_create.config + sandbox_type = config.type + sandbox_config = PydanticSandboxConfig( + type=sandbox_type, config=config.model_dump(exclude_none=True), organization_id=actor.organization_id + ) + + # Attempt to retrieve the existing sandbox configuration by type within the organization + db_sandbox = self.get_sandbox_config_by_type(sandbox_config.type, actor=actor) + if db_sandbox: + # Prepare the update data, excluding fields that should not be reset + update_data = sandbox_config.model_dump(exclude_unset=True, exclude_none=True) + update_data = {key: value for key, value in update_data.items() if getattr(db_sandbox, key) != value} + + # If there are changes, update the sandbox configuration + if update_data: + db_sandbox = self.update_sandbox_config(db_sandbox.id, SandboxConfigUpdate(**update_data), actor) + else: + printd( + f"`create_or_update_sandbox_config` was called with user_id={actor.id}, organization_id={actor.organization_id}, " + f"type={sandbox_config.type}, but found existing configuration with nothing to update." 
+ ) + + return db_sandbox + else: + # If the sandbox configuration doesn't exist, create a new one + with self.session_maker() as session: + db_sandbox = SandboxConfigModel(**sandbox_config.model_dump(exclude_none=True)) + db_sandbox.create(session, actor=actor) + return db_sandbox.to_pydantic() + + @enforce_types + def update_sandbox_config( + self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser + ) -> PydanticSandboxConfig: + """Update an existing sandbox configuration.""" + with self.session_maker() as session: + sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor) + # We need to check that the sandbox_update provided is the same type as the original sandbox + if sandbox.type != sandbox_update.config.type: + raise ValueError( + f"Mismatched type for sandbox config update: tried to update sandbox_config of type {sandbox.type} with config of type {sandbox_update.config.type}" + ) + + update_data = sandbox_update.model_dump(exclude_unset=True, exclude_none=True) + update_data = {key: value for key, value in update_data.items() if getattr(sandbox, key) != value} + + if update_data: + for key, value in update_data.items(): + setattr(sandbox, key, value) + sandbox.update(db_session=session, actor=actor) + else: + printd( + f"`update_sandbox_config` called with user_id={actor.id}, organization_id={actor.organization_id}, " + f"name={sandbox.type}, but nothing to update." + ) + return sandbox.to_pydantic() + + @enforce_types + def delete_sandbox_config(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig: + """Delete a sandbox configuration by its ID.""" + with self.session_maker() as session: + sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor) + sandbox.hard_delete(db_session=session, actor=actor) + return sandbox.to_pydantic() + + @enforce_types + def list_sandbox_configs( + self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50 + ) -> List[PydanticSandboxConfig]: + """List all sandbox configurations with optional pagination.""" + with self.session_maker() as session: + sandboxes = SandboxConfigModel.list( + db_session=session, + cursor=cursor, + limit=limit, + organization_id=actor.organization_id, + ) + return [sandbox.to_pydantic() for sandbox in sandboxes] + + @enforce_types + def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]: + """Retrieve a sandbox configuration by its ID.""" + with self.session_maker() as session: + try: + sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor) + return sandbox.to_pydantic() + except NoResultFound: + return None + + @enforce_types + def get_sandbox_config_by_type(self, type: SandboxType, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]: + """Retrieve a sandbox config by its type.""" + with self.session_maker() as session: + try: + sandboxes = SandboxConfigModel.list( + db_session=session, + type=type, + organization_id=actor.organization_id, + limit=1, + ) + if sandboxes: + return sandboxes[0].to_pydantic() + return None + except NoResultFound: + return None + + @enforce_types + def create_sandbox_env_var( + self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser + ) -> PydanticEnvVar: + """Create a new sandbox environment variable.""" + env_var = PydanticEnvVar(**env_var_create.model_dump(), 
sandbox_config_id=sandbox_config_id, organization_id=actor.organization_id) + + db_env_var = self.get_sandbox_env_var_by_key_and_sandbox_config_id(env_var.key, env_var.sandbox_config_id, actor=actor) + if db_env_var: + update_data = env_var.model_dump(exclude_unset=True, exclude_none=True) + update_data = {key: value for key, value in update_data.items() if getattr(db_env_var, key) != value} + # If there are changes, update the environment variable + if update_data: + db_env_var = self.update_sandbox_env_var(db_env_var.id, SandboxEnvironmentVariableUpdate(**update_data), actor) + else: + printd( + f"`create_or_update_sandbox_env_var` was called with user_id={actor.id}, organization_id={actor.organization_id}, " + f"key={env_var.key}, but found existing variable with nothing to update." + ) + + return db_env_var + else: + with self.session_maker() as session: + env_var = SandboxEnvVarModel(**env_var.model_dump(exclude_none=True)) + env_var.create(session, actor=actor) + return env_var.to_pydantic() + + @enforce_types + def update_sandbox_env_var( + self, env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, actor: PydanticUser + ) -> PydanticEnvVar: + """Update an existing sandbox environment variable.""" + with self.session_maker() as session: + env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor) + update_data = env_var_update.model_dump(exclude_unset=True, exclude_none=True) + update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value} + + if update_data: + for key, value in update_data.items(): + setattr(env_var, key, value) + env_var.update(db_session=session, actor=actor) + else: + printd( + f"`update_sandbox_env_var` called with user_id={actor.id}, organization_id={actor.organization_id}, " + f"key={env_var.key}, but nothing to update." 
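+ # (no-op) every field in the requested update already matches the stored variable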
+ ) + return env_var.to_pydantic() + + @enforce_types + def delete_sandbox_env_var(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar: + """Delete a sandbox environment variable by its ID.""" + with self.session_maker() as session: + env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor) + env_var.hard_delete(db_session=session, actor=actor) + return env_var.to_pydantic() + + @enforce_types + def list_sandbox_env_vars( + self, sandbox_config_id: str, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50 + ) -> List[PydanticEnvVar]: + """List all sandbox environment variables with optional pagination.""" + with self.session_maker() as session: + env_vars = SandboxEnvVarModel.list( + db_session=session, + cursor=cursor, + limit=limit, + organization_id=actor.organization_id, + sandbox_config_id=sandbox_config_id, + ) + return [env_var.to_pydantic() for env_var in env_vars] + + @enforce_types + def get_sandbox_env_vars_as_dict( + self, sandbox_config_id: str, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50 + ) -> Dict[str, str]: + """Fetch a sandbox config's environment variables as a key-to-value dict.""" + env_vars = self.list_sandbox_env_vars(sandbox_config_id, actor, cursor, limit) + result = {} + for env_var in env_vars: + result[env_var.key] = env_var.value + return result + + @enforce_types + def get_sandbox_env_var_by_key_and_sandbox_config_id( + self, key: str, sandbox_config_id: str, actor: PydanticUser + ) -> Optional[PydanticEnvVar]: + """Retrieve a sandbox environment variable by its key and sandbox_config_id, scoped to the actor's organization.""" + with self.session_maker() as session: + try: + env_vars = SandboxEnvVarModel.list( + db_session=session, + key=key, + sandbox_config_id=sandbox_config_id, + organization_id=actor.organization_id, + limit=1, + ) + if env_vars: + return env_vars[0].to_pydantic() + return None + except NoResultFound: + return None diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py new file mode 100644 index 0000000000..a58d6dabd2 --- /dev/null +++ b/letta/services/tool_execution_sandbox.py @@ -0,0 +1,334 @@ +import ast +import base64 +import io +import os +import pickle +import runpy +import sys +import tempfile +import uuid +from typing import Any, Optional + +from letta.log import get_logger +from letta.schemas.agent import AgentState +from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType +from letta.services.sandbox_config_manager import SandboxConfigManager +from letta.services.tool_manager import ToolManager +from letta.services.user_manager import UserManager +from letta.settings import tool_settings + +logger = get_logger(__name__) + + +class ToolExecutionSandbox: + METADATA_CONFIG_STATE_KEY = "config_state" + REQUIREMENT_TXT_NAME = "requirements.txt" + + # For generating long, random marker hashes + NAMESPACE = uuid.NAMESPACE_DNS + LOCAL_SANDBOX_RESULT_START_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-start-marker")) + LOCAL_SANDBOX_RESULT_END_MARKER = str(uuid.uuid5(NAMESPACE, "local-sandbox-result-end-marker")) + + def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False): + self.tool_name = tool_name + self.args = args + + # Get the user + # This user corresponds to the agent_state's user_id field + # agent_state is the state of the agent that invoked this run + self.user = UserManager().get_user_by_id(user_id=user_id) + + # Get the tool + # TODO: So in theory, it's possible this retrieves a tool not provisioned to the agent + # TODO: That would probably imply that agent_state is incorrectly configured + self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user) + if not self.tool: + raise ValueError( + f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}" + ) + + self.sandbox_config_manager = SandboxConfigManager(tool_settings) + self.force_recreate = force_recreate + + def run(self, agent_state: Optional[AgentState] = None) -> Optional[SandboxRunResult]: + """ + Run the tool in a sandbox environment. + + Args: + agent_state (Optional[AgentState]): The state of the agent invoking the tool + + Returns: + Optional[SandboxRunResult]: The run result (func_return, agent_state, stdout, and the config fingerprint), or None if the sandbox produced no result + """ + if tool_settings.e2b_api_key: + logger.info(f"Using e2b sandbox to execute {self.tool_name}") + code = self.generate_execution_script(wrap_print_with_markers=False, agent_state=agent_state) + result = self.run_e2b_sandbox(code=code) + else: + logger.info(f"Using local sandbox to execute {self.tool_name}") + code = self.generate_execution_script(wrap_print_with_markers=True, agent_state=agent_state) + result = self.run_local_dir_sandbox(code=code) + + # Log out any stdout from the tool run + logger.info(f"Executed tool '{self.tool_name}', logging stdout from tool run: \n") + for log_line in result.stdout: + logger.info(f"{log_line}\n") + logger.info("Ending stdout log from tool run.") + + # Return result + return result + + # local sandbox specific functions + from contextlib import contextmanager + + @contextmanager + def temporary_env_vars(self, env_vars: dict): + original_env = os.environ.copy() # Backup original environment variables + os.environ.update(env_vars) # Update with the new variables + try: + yield + finally: + os.environ.clear() + os.environ.update(original_env) # Restore original environment variables + + def run_local_dir_sandbox(self, code: str) -> Optional[SandboxRunResult]: + sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user) + local_configs = sbx_config.get_local_config() + + # Get environment variables for the sandbox + env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) + + # Safety checks + if not os.path.isdir(local_configs.sandbox_dir): + raise FileNotFoundError(f"Sandbox directory does not exist: {local_configs.sandbox_dir}") + + # Write the code to a temp file in the sandbox_dir + with tempfile.NamedTemporaryFile(mode="w", dir=local_configs.sandbox_dir, suffix=".py", delete=False) as temp_file: + temp_file.write(code) + temp_file.flush() + temp_file_path = temp_file.name + + try: + # Redirect stdout to capture script output + captured_stdout = io.StringIO() + old_stdout = sys.stdout + sys.stdout = captured_stdout + + # Execute the temp file + with self.temporary_env_vars(env_vars): + result = runpy.run_path(temp_file_path, init_globals=env_vars) + + # Fetch the result + func_result = result.get("result") + func_return, agent_state = self.parse_best_effort(func_result) + + # Restore stdout and collect captured output + sys.stdout = old_stdout + stdout_output = captured_stdout.getvalue() + + return SandboxRunResult( + func_return=func_return, + agent_state=agent_state, + stdout=[stdout_output], + sandbox_config_fingerprint=sbx_config.fingerprint(), + ) + except Exception as e: + raise RuntimeError(f"Executing tool {self.tool_name} raised an unexpected error: {e}")
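 + # NOTE: the local "sandbox" executes the tool in-process via runpy, so it does not fully isolate the tool from this interpreter + 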
finally: + # Clean up the temp file and restore stdout + sys.stdout = old_stdout + os.remove(temp_file_path) + + def parse_out_function_results_markers(self, text: str): + """Split the marker-delimited result out of captured stdout; returns (result_text, remaining_stdout).""" + marker_len = len(self.LOCAL_SANDBOX_RESULT_START_MARKER) + start_index = text.index(self.LOCAL_SANDBOX_RESULT_START_MARKER) + marker_len + end_index = text.index(self.LOCAL_SANDBOX_RESULT_END_MARKER) + return text[start_index:end_index], text[: start_index - marker_len] + text[end_index + marker_len :] + + # e2b sandbox specific functions + + def run_e2b_sandbox(self, code: str) -> Optional[SandboxRunResult]: + sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user) + sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config) + if not sbx or self.force_recreate: + sbx = self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config) + + # Since this sandbox was used, we extend its lifecycle by the timeout + sbx.set_timeout(sbx_config.get_e2b_config().timeout) + + # Get environment variables for the sandbox + # TODO: We set limit to 100 here, but maybe we want it uncapped? Realistically this should be fine. + env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) + execution = sbx.run_code(code, envs=env_vars) + if execution.error is not None: + raise Exception(f"Executing tool {self.tool_name} failed with {execution.error}. Generated code: \n\n{code}") + elif len(execution.results) == 0: + return None + else: + func_return, agent_state = self.parse_best_effort(execution.results[0].text) + return SandboxRunResult( + func_return=func_return, + agent_state=agent_state, + stdout=execution.logs.stdout, + sandbox_config_fingerprint=sbx_config.fingerprint(), + ) + + def get_running_e2b_sandbox_with_same_state(self, sandbox_config: SandboxConfig) -> Optional["Sandbox"]: + from e2b_code_interpreter import Sandbox + + # List running sandboxes and access metadata. + running_sandboxes = self.list_running_e2b_sandboxes() + + # Hash the config to check the state + state_hash = sandbox_config.fingerprint() + for sandbox in running_sandboxes: + if self.METADATA_CONFIG_STATE_KEY in sandbox.metadata and sandbox.metadata[self.METADATA_CONFIG_STATE_KEY] == state_hash: + return Sandbox.connect(sandbox.sandbox_id) + + return None + + def create_e2b_sandbox_with_metadata_hash(self, sandbox_config: SandboxConfig) -> "Sandbox": + from e2b_code_interpreter import Sandbox + + state_hash = sandbox_config.fingerprint() + e2b_config = sandbox_config.get_e2b_config() + if e2b_config.template: + sbx = Sandbox(e2b_config.template, metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}) + else: + # no template + sbx = Sandbox(metadata={self.METADATA_CONFIG_STATE_KEY: state_hash}, **e2b_config.model_dump(exclude={"pip_requirements"})) + + # install pip requirements + if e2b_config.pip_requirements: + for package in e2b_config.pip_requirements: + sbx.commands.run(f"pip install {package}") + return sbx + + def list_running_e2b_sandboxes(self): + from e2b_code_interpreter import Sandbox + + # List running sandboxes and access metadata.
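+ # The listed entries expose sandbox_id and metadata (matched against config fingerprints above); use Sandbox.connect() for a live handle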
+ return Sandbox.list() + + # general utility functions + + def parse_best_effort(self, text: str) -> Any: + result = pickle.loads(base64.b64decode(text)) + agent_state = None + if result["agent_state"] is not None: + agent_state = result["agent_state"] + return result["results"], agent_state + + def parse_function_arguments(self, source_code: str, tool_name: str): + """Get arguments of a function from its source code""" + tree = ast.parse(source_code) + args = [] + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == tool_name: + for arg in node.args.args: + args.append(arg.arg) + return args + + def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str: + """ + Generate code to run inside of the execution sandbox. + Passes a serialized agent state into the generated code so the tool can access it. + + Args: + agent_state (Optional[AgentState]): The agent state, or None + wrap_print_with_markers (bool): Whether to write the pickled result to stdout wrapped in start/end markers (used by the local sandbox to find the result in captured output) + + Returns: + code (str): The generated code string + """ + # dump a pickled representation of the agent state so the generated script can re-load it + code = "from typing import *\n" + code += "import pickle\n" + code += "import sys\n" + code += "import base64\n" + + # Load the agent state data into the program + if agent_state: + code += "import letta\n" + code += "from letta import * \n" + + agent_state_pickle = pickle.dumps(agent_state) + code += f"agent_state = pickle.loads({agent_state_pickle})\n" + else: + # agent state is None + code += "agent_state = None\n" + + for param in self.args: + code += self.initialize_param(param, self.args[param]) + + inject_agent_state = "agent_state" in self.parse_function_arguments(self.tool.source_code, self.tool.name) + + code += "\n" + self.tool.source_code + "\n" + + # TODO: handle wrapped print + + code += ( + 'result = {"results": ' + self.invoke_function_call(inject_agent_state=inject_agent_state) + ', "agent_state": agent_state}\n' + ) + code += "result = base64.b64encode(pickle.dumps(result)).decode('utf-8')\n" + if wrap_print_with_markers: + code += f"sys.stdout.write('{self.LOCAL_SANDBOX_RESULT_START_MARKER}')\n" + code += "sys.stdout.write(str(result))\n" + code += f"sys.stdout.write('{self.LOCAL_SANDBOX_RESULT_END_MARKER}')\n" + else: + code += "result\n" + + return code + + def initialize_param(self, name: str, raw_value: str) -> str: + params = self.tool.json_schema["parameters"]["properties"] + spec = params.get(name) + if spec is None: + # ignore extra params (like 'self') for now + return "" + + param_type = spec.get("type") + if param_type is None and spec.get("parameters"): + param_type = spec["parameters"].get("type") + + if param_type == "string": + value = '"' + raw_value + '"' + elif param_type == "integer" or param_type == "boolean": + value = raw_value + else: + raise TypeError(f"unsupported type: {param_type}") + + return name + " = " + str(value) + "\n" + + def invoke_function_call(self, inject_agent_state: bool) -> str: + """ + Generate the code string to call the function.
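+ Only parameters present in the tool's JSON schema are forwarded as keyword arguments; + agent_state is appended when inject_agent_state is set. (Hypothetical example: a tool + def send_email(to: str, agent_state) with args {"to": "a@b.co"} yields the call string + "send_email(to=to, agent_state=agent_state)".)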
+ + Args: + inject_agent_state (bool): Whether to inject the agent's state as an input into the tool + + Returns: + str: Generated code string for calling the tool + """ + kwargs = [] + for name in self.args: + if name in self.tool.json_schema["parameters"]["properties"]: + kwargs.append(name) + + param_list = [f"{arg}={arg}" for arg in kwargs] + if inject_agent_state: + param_list.append("agent_state=agent_state") + params = ", ".join(param_list) + # if "agent_state" in kwargs: + # params += ", agent_state=agent_state" + # TODO: fix to figure out when to insert agent state or not + # params += "agent_state=agent_state" + + func_call_str = self.tool.name + "(" + params + ")" + return func_call_str + + # diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 4a705e8df3..f506744539 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -37,11 +37,8 @@ def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser # Derive json_schema derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema(source_code=pydantic_tool.source_code) derived_name = pydantic_tool.name or derived_json_schema["name"] - - try: - # NOTE: We use the organization id here - # This is important, because even if it's a different user, adding the same tool to the org should not happen - tool = self.get_tool_by_name(tool_name=derived_name, actor=actor) + tool = self.get_tool_by_name(tool_name=derived_name, actor=actor) + if tool: # Put to dict and remove fields that should not be reset update_data = pydantic_tool.model_dump(exclude={"module"}, exclude_unset=True, exclude_none=True) # Remove redundant update fields @@ -54,7 +51,7 @@ def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser printd( f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update." ) - except NoResultFound: + else: pydantic_tool.json_schema = derived_json_schema pydantic_tool.name = derived_name tool = self.create_tool(pydantic_tool, actor=actor) @@ -88,11 +85,14 @@ def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool: return tool.to_pydantic() @enforce_types - def get_tool_by_name(self, tool_name: str, actor: PydanticUser): + def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]: """Retrieve a tool by its name and a user. 
We derive the organization from the user, and retrieve that tool.""" - with self.session_maker() as session: - tool = ToolModel.read(db_session=session, name=tool_name, actor=actor) - return tool.to_pydantic() + try: + with self.session_maker() as session: + tool = ToolModel.read(db_session=session, name=tool_name, actor=actor) + return tool.to_pydantic() + except NoResultFound: + return None @enforce_types def list_tools(self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: diff --git a/letta/services/tool_sandbox_env/.gitkeep b/letta/services/tool_sandbox_env/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/letta/settings.py b/letta/settings.py index 2f7e82f99b..1b443a4e6f 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -10,6 +10,10 @@ class ToolSettings(BaseSettings): composio_api_key: Optional[str] = None + # Sandbox configurations + e2b_api_key: Optional[str] = None + e2b_sandbox_template_id: Optional[str] = "ngtrcfmr9wyzs9yjd8l2" # Updated manually + class ModelSettings(BaseSettings): diff --git a/poetry.lock b/poetry.lock index 5e88ba7b18..d2f0d6f4ad 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1248,6 +1248,41 @@ files = [ {file = "durationpy-0.9.tar.gz", hash = "sha256:fd3feb0a69a0057d582ef643c355c40d2fa1c942191f914d12203b1a01ac722a"}, ] +[[package]] +name = "e2b" +version = "1.0.3" +description = "E2B SDK that give agents cloud environments" +optional = true +python-versions = "<4.0,>=3.8" +files = [ + {file = "e2b-1.0.3-py3-none-any.whl", hash = "sha256:7e087a94e0b6bc86fd330815dfb3312f0cc365d4f080f5ff1a6335f3f65426de"}, + {file = "e2b-1.0.3.tar.gz", hash = "sha256:767233663eadf78462b02eb3fa75d0993800309a43e65eaff40686843d26ecf8"}, +] + +[package.dependencies] +attrs = ">=23.2.0" +httpcore = ">=1.0.5,<2.0.0" +httpx = ">=0.27.0,<0.28.0" +packaging = ">=24.1" +protobuf = ">=3.20.0,<6.0.0" +python-dateutil = ">=2.8.2" + +[[package]] +name = "e2b-code-interpreter" +version = "1.0.1" +description = "E2B Code Interpreter - Stateful code execution" +optional = true +python-versions = "<4.0,>=3.8" +files = [ + {file = "e2b_code_interpreter-1.0.1-py3-none-any.whl", hash = "sha256:e27c40174ba7daac4942388611a73e1ac58300227f0ba6c0555ee54507d4944c"}, + {file = "e2b_code_interpreter-1.0.1.tar.gz", hash = "sha256:b0c061e41315d21514affe78f80052be335b687204e669dd7ca852b59eeaaea2"}, +] + +[package.dependencies] +attrs = ">=21.3.0" +e2b = ">=1.0.0,<2.0.0" +httpx = ">=0.20.0,<0.28.0" + [[package]] name = "environs" version = "9.5.0" @@ -7547,6 +7582,7 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [extras] all = ["autoflake", "black", "composio-core", "composio-langchain", "datasets", "docker", "fastapi", "isort", "langchain", "langchain-community", "llama-index-embeddings-ollama", "locust", "pexpect", "pg8000", "pgvector", "pre-commit", "psycopg2", "psycopg2-binary", "pyright", "pytest-asyncio", "pytest-order", "uvicorn", "websockets", "wikipedia"] autogen = ["pyautogen"] +cloud-tool-sandbox = ["e2b-code-interpreter"] dev = ["autoflake", "black", "datasets", "isort", "locust", "pexpect", "pre-commit", "pyright", "pytest-asyncio", "pytest-order"] external-tools = ["composio-core", "composio-langchain", "docker", "langchain", "langchain-community", "wikipedia"] milvus = ["pymilvus"] @@ -7559,4 +7595,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "608f7d3ec8dd286bf1bacaf0b05647d6020e5c0e7a97fc0e255325e3a72ecd17" +content-hash = 
"28cd26c6573ca0a07173262bc0e819e19b661157fa757efca0590262f9b9f35c" diff --git a/pyproject.toml b/pyproject.toml index 55f439eaa4..9a07491227 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ alembic = "^1.13.3" pyhumps = "^3.8.0" psycopg2 = "^2.9.10" psycopg2-binary = "^2.9.10" +e2b-code-interpreter = {version = "^1.0.1", optional = true} pathvalidate = "^3.2.1" langchain-community = {version = "^0.3.7", optional = true} langchain = {version = "^0.3.7", optional = true} @@ -85,7 +86,7 @@ server = ["websockets", "fastapi", "uvicorn"] autogen = ["pyautogen"] qdrant = ["qdrant-client"] ollama = ["llama-index-embeddings-ollama"] -#external-tools = ["crewai", "docker", "crewai-tools", "langchain", "wikipedia", "langchain-community", "composio-core", "composio-langchain"] +cloud-tool-sandbox = ["e2b-code-interpreter"] external-tools = ["docker", "langchain", "wikipedia", "langchain-community", "composio-core", "composio-langchain"] tests = ["wikipedia"] all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "llama-index-embeddings-ollama", "docker", "langchain", "wikipedia", "langchain-community", "composio-core", "composio-langchain", "locust"] diff --git a/tests/pytest.ini b/tests/pytest.ini index daeb36ba14..7ffe833c77 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -4,3 +4,6 @@ testpaths = /tests asyncio_mode = auto filterwarnings = ignore::pytest.PytestRemovedIn9Warning +markers = + local_sandbox: mark test as part of local sandbox tests + e2b_sandbox: mark test as part of E2B sandbox tests diff --git a/tests/test_agent_tool_graph.py b/tests/test_agent_tool_graph.py index 049e997804..227fd76134 100644 --- a/tests/test_agent_tool_graph.py +++ b/tests/test_agent_tool_graph.py @@ -5,6 +5,7 @@ from letta import create_client from letta.schemas.letta_message import FunctionCallMessage from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule +from letta.settings import tool_settings from tests.helpers.endpoints_helper import ( assert_invoked_function_call, assert_invoked_send_message_with_keyword, @@ -18,17 +19,33 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph")) config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json" + +@pytest.fixture +def mock_e2b_api_key_none(): + # Store the original value of e2b_api_key + original_api_key = tool_settings.e2b_api_key + + # Set e2b_api_key to None + tool_settings.e2b_api_key = None + + # Yield control to the test + yield + + # Restore the original value of e2b_api_key + tool_settings.e2b_api_key = original_api_key + + """Contrived tools for this test case""" -def first_secret_word(self: "Agent"): +def first_secret_word(): """ Call this to retrieve the first secret word, which you will need for the second_secret_word function. """ return "v0iq020i0g" -def second_secret_word(self: "Agent", prev_secret_word: str): +def second_secret_word(prev_secret_word: str): """ Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error. @@ -41,7 +58,7 @@ def second_secret_word(self: "Agent", prev_secret_word: str): return "4rwp2b4gxq" -def third_secret_word(self: "Agent", prev_secret_word: str): +def third_secret_word(prev_secret_word: str): """ Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. 
If you get the word wrong, this function will error. @@ -54,7 +71,7 @@ def third_secret_word(self: "Agent", prev_secret_word: str): return "hj2hwibbqm" -def fourth_secret_word(self: "Agent", prev_secret_word: str): +def fourth_secret_word(prev_secret_word: str): """ Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error. @@ -67,7 +84,7 @@ def fourth_secret_word(self: "Agent", prev_secret_word: str): return "banana" -def auto_error(self: "Agent"): +def auto_error(): """ If you call this function, it will throw an error automatically. """ @@ -75,7 +92,7 @@ def auto_error(self: "Agent"): @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely -def test_single_path_agent_tool_call_graph(): +def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none): client = create_client() cleanup(client=client, agent_uuid=agent_uuid) diff --git a/tests/test_client.py b/tests/test_client.py index 5a84c5a381..b23fd85faf 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,4 @@ import os -import re import threading import time import uuid @@ -9,77 +8,52 @@ from dotenv import load_dotenv from sqlalchemy import delete -from letta import create_client -from letta.agent import initialize_message_sequence -from letta.client.client import LocalClient, RESTClient -from letta.constants import DEFAULT_PRESET -from letta.orm import FileMetadata, Source +from letta import LocalClient, RESTClient, create_client +from letta.orm import SandboxConfig, SandboxEnvironmentVariable from letta.schemas.agent import AgentState from letta.schemas.block import BlockCreate from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageRole, MessageStreamStatus -from letta.schemas.letta_message import ( - AssistantMessage, - FunctionCallMessage, - FunctionReturn, - InternalMonologue, - LettaMessage, - SystemMessage, - UserMessage, -) -from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message -from letta.schemas.usage import LettaUsageStatistics -from letta.services.tool_manager import ToolManager -from letta.settings import model_settings -from letta.utils import create_random_username, get_utc_time -from tests.helpers.client_helper import upload_file_using_client - -# from tests.utils import create_config +from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType +from letta.settings import tool_settings +from letta.utils import create_random_username -test_agent_name = f"test_client_{str(uuid.uuid4())}" -# test_preset_name = "test_preset" -test_preset_name = DEFAULT_PRESET -test_agent_state = None -client = None - -test_agent_state_post_message = None +# Constants +SERVER_PORT = 8283 +SANDBOX_DIR = "/tmp/sandbox" +UPDATED_SANDBOX_DIR = "/tmp/updated_sandbox" +ENV_VAR_KEY = "TEST_VAR" +UPDATED_ENV_VAR_KEY = "UPDATED_VAR" +ENV_VAR_VALUE = "test_value" +UPDATED_ENV_VAR_VALUE = "updated_value" +ENV_VAR_DESCRIPTION = "A test environment variable" def run_server(): load_dotenv() - # _reset_config() - from letta.server.rest_api.app import start_server print("Starting server...") start_server(debug=True) -# Fixture to create clients with different configurations @pytest.fixture( - # params=[{"server": True}, {"server": False}], # whether to use REST API server - params=[{"server": True}], # whether to use REST API server + 
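# Exercise the sandbox-config suite against both the REST server and the local client + 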
params=[{"server": True}, {"server": False}], # whether to use REST API server scope="module", ) def client(request): if request.param["server"]: - # get URL from enviornment - server_url = os.getenv("LETTA_SERVER_URL") - if server_url is None: - # run server in thread - server_url = "http://localhost:8283" + # Get URL from environment or start server + server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}") + if not os.getenv("LETTA_SERVER_URL"): print("Starting server thread") thread = threading.Thread(target=run_server, daemon=True) thread.start() time.sleep(5) print("Running client tests with server:", server_url) - # create user via admin client - client = create_client(base_url=server_url, token=None) # This yields control back to the test function + client = create_client(base_url=server_url, token=None) else: - # use local client (no server) client = create_client() client.set_default_llm_config(LLMConfig.default_config("gpt-4")) @@ -87,611 +61,91 @@ def client(request): yield client -@pytest.fixture(autouse=True) -def clear_tables(): - """Fixture to clear the organization table before each test.""" - from letta.server.server import db_context - - with db_context() as session: - session.execute(delete(FileMetadata)) - session.execute(delete(Source)) - session.commit() - - # Fixture for test agent @pytest.fixture(scope="module") def agent(client: Union[LocalClient, RESTClient]): - agent_state = client.create_agent(name=test_agent_name) + agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}") yield agent_state # delete agent client.delete_agent(agent_state.id) -def test_agent(client: Union[LocalClient, RESTClient], agent: AgentState): - - # test client.rename_agent - new_name = "RenamedTestAgent" - client.rename_agent(agent_id=agent.id, new_name=new_name) - renamed_agent = client.get_agent(agent_id=agent.id) - assert renamed_agent.name == new_name, "Agent renaming failed" - - # get agent id - agent_id = client.get_agent_id(agent_name=new_name) - assert agent_id == agent.id, "Agent ID retrieval failed" - - # test client.delete_agent and client.agent_exists - delete_agent = client.create_agent(name="DeleteTestAgent") - assert client.agent_exists(agent_id=delete_agent.id), "Agent creation failed" - client.delete_agent(agent_id=delete_agent.id) - assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed" - - -def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - - memory_response = client.get_in_context_memory(agent_id=agent.id) - print("MEMORY", memory_response.compile()) - - updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"} - client.update_in_context_memory(agent_id=agent.id, section="human", value=updated_memory["human"]) - client.update_in_context_memory(agent_id=agent.id, section="persona", value=updated_memory["persona"]) - updated_memory_response = client.get_in_context_memory(agent_id=agent.id) - assert ( - updated_memory_response.get_block("human").value == updated_memory["human"] - and updated_memory_response.get_block("persona").value == updated_memory["persona"] - ), "Memory update failed" - - -def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - - message = "Hello, agent!" 
- print("Sending message", message) - response = client.user_message(agent_id=agent.id, message=message, include_full_message=True) - # Check the types coming back - assert all([isinstance(m, Message) for m in response.messages]), "All messages should be Message" - - print("Response", response) - assert isinstance(response.usage, LettaUsageStatistics) - assert response.usage.step_count == 1 - assert response.usage.total_tokens > 0 - assert response.usage.completion_tokens > 0 - assert isinstance(response.messages[0], Message) - print(response.messages) - - # test that it also works with LettaMessage - message = "Hello again, agent!" - print("Sending message", message) - response = client.user_message(agent_id=agent.id, message=message, include_full_message=False) - assert all([isinstance(m, LettaMessage) for m in response.messages]), "All messages should be LettaMessages" - - # We should also check that the types were cast properly - print("RESPONSE MESSAGES, client type:", type(client)) - print(response.messages) - for letta_message in response.messages: - assert type(letta_message) in [ - SystemMessage, - UserMessage, - InternalMonologue, - FunctionCallMessage, - FunctionReturn, - AssistantMessage, - ], f"Unexpected message type: {type(letta_message)}" - - # TODO: add streaming tests - - -def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - - memory_content = "Archival memory content" - insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)[0] - print("Inserted memory", insert_response.text, insert_response.id) - assert insert_response, "Inserting archival memory failed" - - archival_memory_response = client.get_archival_memory(agent_id=agent.id, limit=1) - archival_memories = [memory.text for memory in archival_memory_response] - assert memory_content in archival_memories, f"Retrieving archival memory failed: {archival_memories}" - - memory_id_to_delete = archival_memory_response[0].id - client.delete_archival_memory(agent_id=agent.id, memory_id=memory_id_to_delete) - - # add archival memory - memory_str = "I love chats" - passage = client.insert_archival_memory(agent.id, memory=memory_str)[0] - - # list archival memory - passages = client.get_archival_memory(agent.id) - assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}" - - # get archival memory summary - archival_summary = client.get_archival_memory_summary(agent.id) - assert archival_summary.size == 1, f"Archival memory summary size is {archival_summary.size}" - - # delete archival memory - client.delete_archival_memory(agent.id, passage.id) - - # TODO: check deletion - client.get_archival_memory(agent.id) - - -def test_core_memory(client: Union[LocalClient, RESTClient], agent: AgentState): - response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user") - print("Response", response) - - memory = client.get_in_context_memory(agent_id=agent.id) - assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}" - - -def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - - send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user") - assert send_message_response, "Sending message failed" - - messages_response = client.get_messages(agent_id=agent.id, limit=1) - assert len(messages_response) > 0, 
"Retrieving messages failed" - - -def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: AgentState): - if isinstance(client, LocalClient): - pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming") - assert isinstance(client, RESTClient), client - - # First, try streaming just steps - - # Next, try streaming both steps and tokens - response = client.send_message( - agent_id=agent.id, - message="This is a test. Repeat after me: 'banana'", - role="user", - stream_steps=True, - stream_tokens=True, - ) - - # Some manual checks to run - # 1. Check that there were inner thoughts - inner_thoughts_exist = False - # 2. Check that the agent runs `send_message` - send_message_ran = False - # 3. Check that we get all the start/stop/end tokens we want - # This includes all of the MessageStreamStatus enums - done_gen = False - done_step = False - done = False - - # print(response) - assert response, "Sending message failed" - for chunk in response: - assert isinstance(chunk, LettaStreamingResponse) - if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "": - inner_thoughts_exist = True - if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message": - send_message_ran = True - if isinstance(chunk, MessageStreamStatus): - if chunk == MessageStreamStatus.done: - assert not done, "Message stream already done" - done = True - elif chunk == MessageStreamStatus.done_step: - assert not done_step, "Message stream already done step" - done_step = True - elif chunk == MessageStreamStatus.done_generation: - assert not done_gen, "Message stream already done generation" - done_gen = True - if isinstance(chunk, LettaUsageStatistics): - # Some rough metrics for a reasonable usage pattern - assert chunk.step_count == 1 - assert chunk.completion_tokens > 10 - assert chunk.prompt_tokens > 1000 - assert chunk.total_tokens > 1000 - - assert inner_thoughts_exist, "No inner thoughts found" - assert send_message_ran, "send_message function call not found" - assert done, "Message stream not done" - assert done_step, "Message stream not done step" - assert done_gen, "Message stream not done generation" - - -def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - - humans_response = client.list_humans() - print("HUMANS", humans_response) - - personas_response = client.list_personas() - print("PERSONAS", personas_response) - - persona_name = "TestPersona" - persona_id = client.get_persona_id(persona_name) - if persona_id: - client.delete_persona(persona_id) - persona = client.create_persona(name=persona_name, text="Persona text") - assert persona.template_name == persona_name - assert persona.value == "Persona text", "Creating persona failed" - - human_name = "TestHuman" - human_id = client.get_human_id(human_name) - if human_id: - client.delete_human(human_id) - human = client.create_human(name=human_name, text="Human text") - assert human.template_name == human_name - assert human.value == "Human text", "Creating human failed" - - -def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): - tools = client.list_tools() - visited_ids = {t.id: False for t in tools} - - cursor = None - # Choose 3 for uneven buckets (only 7 default tools) - num_tools = 3 - # Construct a complete pagination test to see if we can return all the tools eventually - for _ in range(0, len(tools), num_tools): - curr_tools = 
client.list_tools(cursor, num_tools) - assert len(curr_tools) <= num_tools - - for curr_tool in curr_tools: - assert curr_tool.id in visited_ids - visited_ids[curr_tool.id] = True - - cursor = curr_tools[-1].id - - # Assert that everything has been visited - assert all(visited_ids.values()) - - -def test_list_tools(client: Union[LocalClient, RESTClient]): - tools = client.add_base_tools() - tool_names = [t.name for t in tools] - expected = ToolManager.BASE_TOOL_NAMES - assert sorted(tool_names) == sorted(expected) - - -def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState): - # clear sources - for source in client.list_sources(): - client.delete_source(source.id) - - # clear jobs - for job in client.list_jobs(): - client.delete_job(job.id) - - # create a source - source = client.create_source(name="test_source") - - # load files into sources - file_a = "tests/data/memgpt_paper.pdf" - file_b = "tests/data/test.txt" - upload_file_using_client(client, source, file_a) - upload_file_using_client(client, source, file_b) - - # Get the first file - files_a = client.list_files_from_source(source.id, limit=1) - assert len(files_a) == 1 - assert files_a[0].source_id == source.id - - # Use the cursor from response_a to get the remaining file - files_b = client.list_files_from_source(source.id, limit=1, cursor=files_a[-1].id) - assert len(files_b) == 1 - assert files_b[0].source_id == source.id - - # Check files are different to ensure the cursor works - assert files_a[0].file_name != files_b[0].file_name - - # Use the cursor from response_b to list files, should be empty - files = client.list_files_from_source(source.id, limit=1, cursor=files_b[-1].id) - assert len(files) == 0 # Should be empty - - -def test_delete_file_from_source(client: Union[LocalClient, RESTClient], agent: AgentState): - # clear sources - for source in client.list_sources(): - client.delete_source(source.id) - - # clear jobs - for job in client.list_jobs(): - client.delete_job(job.id) - - # create a source - source = client.create_source(name="test_source") - - # load files into sources - file_a = "tests/data/test.txt" - upload_file_using_client(client, source, file_a) - - # Get the first file - files_a = client.list_files_from_source(source.id, limit=1) - assert len(files_a) == 1 - assert files_a[0].source_id == source.id - - # Delete the file - client.delete_file_from_source(source.id, files_a[0].id) - - # Check that no files are attached to the source - empty_files = client.list_files_from_source(source.id, limit=1) - assert len(empty_files) == 0 - - -def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - - # clear sources - for source in client.list_sources(): - client.delete_source(source.id) - - # clear jobs - for job in client.list_jobs(): - client.delete_job(job.id) - - # create a source - source = client.create_source(name="test_source") - - # load a file into a source (non-blocking job) - filename = "tests/data/memgpt_paper.pdf" - upload_file_using_client(client, source, filename) - - # Get the files - files = client.list_files_from_source(source.id) - assert len(files) == 1 # Should be condensed to one document - - # Get the memgpt paper - file = files[0] - # Assert the filename matches the pattern - pattern = re.compile(r"^memgpt_paper_[a-f0-9]{32}\.pdf$") - assert pattern.match(file.file_name), f"Filename '{file.file_name}' does not match expected pattern." 
- - assert file.source_id == source.id - - -def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - - # clear sources - for source in client.list_sources(): - client.delete_source(source.id) - - # clear jobs - for job in client.list_jobs(): - client.delete_job(job.id) - - # list sources - sources = client.list_sources() - print("listed sources", sources) - assert len(sources) == 0 - - # create a source - source = client.create_source(name="test_source") - - # list sources - sources = client.list_sources() - print("listed sources", sources) - assert len(sources) == 1 - - # TODO: add back? - assert sources[0].metadata_["num_passages"] == 0 - assert sources[0].metadata_["num_documents"] == 0 - - # update the source - original_id = source.id - original_name = source.name - new_name = original_name + "_new" - client.update_source(source_id=source.id, name=new_name) - - # get the source name (check that it's been updated) - source = client.get_source(source_id=source.id) - assert source.name == new_name - assert source.id == original_id - - # get the source id (make sure that it's the same) - assert str(original_id) == client.get_source_id(source_name=new_name) - - # check agent archival memory size - archival_memories = client.get_archival_memory(agent_id=agent.id) - print(archival_memories) - assert len(archival_memories) == 0 - - # load a file into a source (non-blocking job) - filename = "tests/data/memgpt_paper.pdf" - upload_job = upload_file_using_client(client, source, filename) - job = client.get_job(upload_job.id) - created_passages = job.metadata_["num_passages"] - - # TODO: add test for blocking job - - # TODO: make sure things run in the right order - archival_memories = client.get_archival_memory(agent_id=agent.id) - assert len(archival_memories) == 0 - - # attach a source - client.attach_source_to_agent(source_id=source.id, agent_id=agent.id) - - # list attached sources - attached_sources = client.list_attached_sources(agent_id=agent.id) - print("attached sources", attached_sources) - assert source.id in [s.id for s in attached_sources], f"Attached sources: {attached_sources}" - - # list archival memory - archival_memories = client.get_archival_memory(agent_id=agent.id) - # print(archival_memories) - assert len(archival_memories) == created_passages, f"Mismatched length {len(archival_memories)} vs. {created_passages}" - - # check number of passages - sources = client.list_sources() - # TODO: add back? 
- # assert sources.sources[0].metadata_["num_passages"] > 0 - # assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added - print(sources) - - # detach the source - assert len(client.get_archival_memory(agent_id=agent.id)) > 0, "No archival memory" - deleted_source = client.detach_source(source_id=source.id, agent_id=agent.id) - assert deleted_source.id == source.id - archival_memories = client.get_archival_memory(agent_id=agent.id) - assert len(archival_memories) == 0, f"Failed to detach source: {len(archival_memories)}" - assert source.id not in [s.id for s in client.list_attached_sources(agent.id)] - - # delete the source - client.delete_source(source.id) - - -def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState): - """Test that we can update the details of a message""" - - # create a message - message_response = client.send_message(agent_id=agent.id, message="Test message", role="user", include_full_message=True) - print("Messages=", message_response) - assert isinstance(message_response, LettaResponse) - assert isinstance(message_response.messages[-1], Message) - message = message_response.messages[-1] - - new_text = "This exact string would never show up in the message???" - new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id) - assert new_message.text == new_text - - -def test_organization(client: RESTClient): - if isinstance(client, LocalClient): - pytest.skip("Skipping test_organization because LocalClient does not support organizations") - - # create an organization - org_name = "test-org" - org = client.create_org(org_name) - - # assert the id appears - orgs = client.list_orgs() - assert org.id in [o.id for o in orgs] - - org = client.delete_org(org.id) - assert org.name == org_name - - # assert the id is gone - orgs = client.list_orgs() - assert not (org.id in [o.id for o in orgs]) - - -def test_list_llm_models(client: RESTClient): - """Test that if the user's env has the right api keys set, at least one model appears in the model list""" - - def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool: - return any(model.model_endpoint_type == target_type for model in models) - - models = client.list_llm_configs() - if model_settings.groq_api_key: - assert has_model_endpoint_type(models, "groq") - if model_settings.azure_api_key: - assert has_model_endpoint_type(models, "azure") - if model_settings.openai_api_key: - assert has_model_endpoint_type(models, "openai") - if model_settings.gemini_api_key: - assert has_model_endpoint_type(models, "google_ai") - if model_settings.anthropic_api_key: - assert has_model_endpoint_type(models, "anthropic") +@pytest.fixture(autouse=True) +def clear_tables(): + """Clear the sandbox tables before each test.""" + from letta.server.server import db_context + with db_context() as session: + session.execute(delete(SandboxEnvironmentVariable)) + session.execute(delete(SandboxConfig)) + session.commit() -def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - # create a block - block = client.create_block(label="human", value="username: sarah") +@pytest.fixture +def mock_e2b_api_key_none(): + # Store the original value of e2b_api_key + original_api_key = tool_settings.e2b_api_key - # create agents with shared block - from letta.schemas.memory import BasicBlockMemory + # Set e2b_api_key to None + tool_settings.e2b_api_key = None - persona1_block = 
client.create_block(label="persona", value="you are agent 1") - persona2_block = client.create_block(label="persona", value="you are agent 2") + # Yield control to the test + yield - # create agnets - agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory(blocks=[block, persona1_block])) - agent_state2 = client.create_agent(name="agent2", memory=BasicBlockMemory(blocks=[block, persona2_block])) + # Restore the original value of e2b_api_key + tool_settings.e2b_api_key = original_api_key - # update memory - response = client.user_message(agent_id=agent_state1.id, message="my name is actually charles") - # check agent 2 memory - assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}" +def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]): + """ + Test sandbox config and environment variable functions for both LocalClient and RESTClient. + """ - response = client.user_message(agent_id=agent_state2.id, message="whats my name?") - assert ( - "charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower() - ), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}" - # assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}" + # 1. Create a sandbox config + local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR) + sandbox_config = client.create_sandbox_config(config=local_config) - # cleanup - client.delete_agent(agent_state1.id) - client.delete_agent(agent_state2.id) + # Assert the created sandbox config + assert sandbox_config.id is not None + assert sandbox_config.type == SandboxType.LOCAL + # 2. Update the sandbox config + updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR) + sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config) + assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR -@pytest.fixture -def cleanup_agents(): - created_agents = [] - yield created_agents - # Cleanup will run even if test fails - for agent_id in created_agents: - try: - client.delete_agent(agent_id) - except Exception as e: - print(f"Failed to delete agent {agent_id}: {e}") - - -def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]): - """Test that we can set an initial message sequence - - If we pass in None, we should get a "default" message sequence - If we pass in a non-empty list, we should get that sequence - If we pass in an empty list, we should get an empty sequence - """ + # 3. List all sandbox configs + sandbox_configs = client.list_sandbox_configs(limit=10) + assert isinstance(sandbox_configs, List) + assert len(sandbox_configs) == 1 + assert sandbox_configs[0].id == sandbox_config.id - # The reference initial message sequence: - reference_init_messages = initialize_message_sequence( - model=agent.llm_config.model, - system=agent.system, - memory=agent.memory, - archival_memory=None, - recall_memory=None, - memory_edit_timestamp=get_utc_time(), - include_initial_boot_message=True, + # 4. 
Create an environment variable + env_var = client.create_sandbox_env_var( + sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION ) - - # system, login message, send_message test, send_message receipt - assert len(reference_init_messages) > 0 - assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}" - - # Test with default sequence - default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None) - cleanup_agents.append(default_agent_state.id) - assert default_agent_state.message_ids is not None - assert len(default_agent_state.message_ids) > 0 - assert len(default_agent_state.message_ids) == len( - reference_init_messages - ), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}" - - # Test with empty sequence - empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[]) - cleanup_agents.append(empty_agent_state.id) - assert empty_agent_state.message_ids is not None - assert len(empty_agent_state.message_ids) == 1, f"Expected 0 messages, got {len(empty_agent_state.message_ids)}" - - # Test with custom sequence - custom_sequence = [ - Message( - role=MessageRole.user, - text="Hello, how are you?", - user_id=agent.user_id, - agent_id=agent.id, - model=agent.llm_config.model, - name=None, - tool_calls=None, - tool_call_id=None, - ), - ] - custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence) - cleanup_agents.append(custom_agent_state.id) - assert custom_agent_state.message_ids is not None - assert ( - len(custom_agent_state.message_ids) == len(custom_sequence) + 1 - ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}" - assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence] + assert env_var.id is not None + assert env_var.key == ENV_VAR_KEY + assert env_var.value == ENV_VAR_VALUE + assert env_var.description == ENV_VAR_DESCRIPTION + + # 5. Update the environment variable + updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE) + assert updated_env_var.key == UPDATED_ENV_VAR_KEY + assert updated_env_var.value == UPDATED_ENV_VAR_VALUE + + # 6. List environment variables + env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id) + assert isinstance(env_vars, List) + assert len(env_vars) == 1 + assert env_vars[0].key == UPDATED_ENV_VAR_KEY + + # 7. Delete the environment variable + client.delete_sandbox_env_var(env_var_id=env_var.id) + + # 8. 
Delete the sandbox config + client.delete_sandbox_config(sandbox_config_id=sandbox_config.id) def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState): diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py new file mode 100644 index 0000000000..56bbf9a632 --- /dev/null +++ b/tests/test_client_legacy.py @@ -0,0 +1,728 @@ +import os +import re +import threading +import time +import uuid +from typing import List, Union + +import pytest +from dotenv import load_dotenv +from sqlalchemy import delete + +from letta import create_client +from letta.agent import initialize_message_sequence +from letta.client.client import LocalClient, RESTClient +from letta.constants import DEFAULT_PRESET +from letta.orm import FileMetadata, Source +from letta.schemas.agent import AgentState +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import MessageRole, MessageStreamStatus +from letta.schemas.letta_message import ( + AssistantMessage, + FunctionCallMessage, + FunctionReturn, + InternalMonologue, + LettaMessage, + SystemMessage, + UserMessage, +) +from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message +from letta.schemas.usage import LettaUsageStatistics +from letta.services.tool_manager import ToolManager +from letta.settings import model_settings +from letta.utils import get_utc_time +from tests.helpers.client_helper import upload_file_using_client + +# from tests.utils import create_config + +test_agent_name = f"test_client_{str(uuid.uuid4())}" +# test_preset_name = "test_preset" +test_preset_name = DEFAULT_PRESET +test_agent_state = None +client = None + +test_agent_state_post_message = None + + +def run_server(): + load_dotenv() + + # _reset_config() + + from letta.server.rest_api.app import start_server + + print("Starting server...") + start_server(debug=True) + + +# Fixture to create clients with different configurations +@pytest.fixture( + # params=[{"server": True}, {"server": False}], # whether to use REST API server + params=[{"server": True}], # whether to use REST API server + scope="module", +) +def client(request): + if request.param["server"]: + # get URL from environment + server_url = os.getenv("LETTA_SERVER_URL") + if server_url is None: + # run server in thread + server_url = "http://localhost:8283" + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + time.sleep(5) + print("Running client tests with server:", server_url) + # create user via admin client + client = create_client(base_url=server_url, token=None) + else: + # use local client (no server) + client = create_client() + + client.set_default_llm_config(LLMConfig.default_config("gpt-4")) + client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) + yield client + + +@pytest.fixture(autouse=True) +def clear_tables(): + """Fixture to clear the file and source tables before each test.""" + from letta.server.server import db_context + + with db_context() as session: + session.execute(delete(FileMetadata)) + session.execute(delete(Source)) + session.commit() + + +# Fixture for test agent +@pytest.fixture(scope="module") +def agent(client: Union[LocalClient, RESTClient]): + agent_state = client.create_agent(name=test_agent_name) + yield agent_state + + # delete agent + 
client.delete_agent(agent_state.id) + + +def test_agent(client: Union[LocalClient, RESTClient], agent: AgentState): + + # test client.rename_agent + new_name = "RenamedTestAgent" + client.rename_agent(agent_id=agent.id, new_name=new_name) + renamed_agent = client.get_agent(agent_id=agent.id) + assert renamed_agent.name == new_name, "Agent renaming failed" + + # get agent id + agent_id = client.get_agent_id(agent_name=new_name) + assert agent_id == agent.id, "Agent ID retrieval failed" + + # test client.delete_agent and client.agent_exists + delete_agent = client.create_agent(name="DeleteTestAgent") + assert client.agent_exists(agent_id=delete_agent.id), "Agent creation failed" + client.delete_agent(agent_id=delete_agent.id) + assert not client.agent_exists(agent_id=delete_agent.id), "Agent deletion failed" + + +def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState): + # _reset_config() + + memory_response = client.get_in_context_memory(agent_id=agent.id) + print("MEMORY", memory_response.compile()) + + updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"} + client.update_in_context_memory(agent_id=agent.id, section="human", value=updated_memory["human"]) + client.update_in_context_memory(agent_id=agent.id, section="persona", value=updated_memory["persona"]) + updated_memory_response = client.get_in_context_memory(agent_id=agent.id) + assert ( + updated_memory_response.get_block("human").value == updated_memory["human"] + and updated_memory_response.get_block("persona").value == updated_memory["persona"] + ), "Memory update failed" + + +def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: AgentState): + # _reset_config() + + message = "Hello, agent!" + print("Sending message", message) + response = client.user_message(agent_id=agent.id, message=message, include_full_message=True) + # Check the types coming back + assert all([isinstance(m, Message) for m in response.messages]), "All messages should be Message" + + print("Response", response) + assert isinstance(response.usage, LettaUsageStatistics) + assert response.usage.step_count == 1 + assert response.usage.total_tokens > 0 + assert response.usage.completion_tokens > 0 + assert isinstance(response.messages[0], Message) + print(response.messages) + + # test that it also works with LettaMessage + message = "Hello again, agent!"
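+ # with include_full_message=False the response should come back as LettaMessage subtypes rather than raw Message objects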
+ print("Sending message", message) + response = client.user_message(agent_id=agent.id, message=message, include_full_message=False) + assert all([isinstance(m, LettaMessage) for m in response.messages]), "All messages should be LettaMessages" + + # We should also check that the types were cast properly + print("RESPONSE MESSAGES, client type:", type(client)) + print(response.messages) + for letta_message in response.messages: + assert type(letta_message) in [ + SystemMessage, + UserMessage, + InternalMonologue, + FunctionCallMessage, + FunctionReturn, + AssistantMessage, + ], f"Unexpected message type: {type(letta_message)}" + + # TODO: add streaming tests + + +def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState): + # _reset_config() + + memory_content = "Archival memory content" + insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)[0] + print("Inserted memory", insert_response.text, insert_response.id) + assert insert_response, "Inserting archival memory failed" + + archival_memory_response = client.get_archival_memory(agent_id=agent.id, limit=1) + archival_memories = [memory.text for memory in archival_memory_response] + assert memory_content in archival_memories, f"Retrieving archival memory failed: {archival_memories}" + + memory_id_to_delete = archival_memory_response[0].id + client.delete_archival_memory(agent_id=agent.id, memory_id=memory_id_to_delete) + + # add archival memory + memory_str = "I love chats" + passage = client.insert_archival_memory(agent.id, memory=memory_str)[0] + + # list archival memory + passages = client.get_archival_memory(agent.id) + assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}" + + # get archival memory summary + archival_summary = client.get_archival_memory_summary(agent.id) + assert archival_summary.size == 1, f"Archival memory summary size is {archival_summary.size}" + + # delete archival memory + client.delete_archival_memory(agent.id, passage.id) + + # TODO: check deletion + client.get_archival_memory(agent.id) + + +def test_core_memory(client: Union[LocalClient, RESTClient], agent: AgentState): + response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user") + print("Response", response) + + memory = client.get_in_context_memory(agent_id=agent.id) + assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}" + + +def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState): + # _reset_config() + + send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user") + assert send_message_response, "Sending message failed" + + messages_response = client.get_messages(agent_id=agent.id, limit=1) + assert len(messages_response) > 0, "Retrieving messages failed" + + +def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: AgentState): + if isinstance(client, LocalClient): + pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming") + assert isinstance(client, RESTClient), client + + # First, try streaming just steps + + # Next, try streaming both steps and tokens + response = client.send_message( + agent_id=agent.id, + message="This is a test. Repeat after me: 'banana'", + role="user", + stream_steps=True, + stream_tokens=True, + ) + + # Some manual checks to run + # 1. 
Check that there were inner thoughts + inner_thoughts_exist = False + # 2. Check that the agent runs `send_message` + send_message_ran = False + # 3. Check that we get all the start/stop/end tokens we want + # This includes all of the MessageStreamStatus enums + done_gen = False + done_step = False + done = False + + # print(response) + assert response, "Sending message failed" + for chunk in response: + assert isinstance(chunk, LettaStreamingResponse) + if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "": + inner_thoughts_exist = True + if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message": + send_message_ran = True + if isinstance(chunk, MessageStreamStatus): + if chunk == MessageStreamStatus.done: + assert not done, "Message stream already done" + done = True + elif chunk == MessageStreamStatus.done_step: + assert not done_step, "Message stream already done step" + done_step = True + elif chunk == MessageStreamStatus.done_generation: + assert not done_gen, "Message stream already done generation" + done_gen = True + if isinstance(chunk, LettaUsageStatistics): + # Some rough metrics for a reasonable usage pattern + assert chunk.step_count == 1 + assert chunk.completion_tokens > 10 + assert chunk.prompt_tokens > 1000 + assert chunk.total_tokens > 1000 + + assert inner_thoughts_exist, "No inner thoughts found" + assert send_message_ran, "send_message function call not found" + assert done, "Message stream not done" + assert done_step, "Message stream not done step" + assert done_gen, "Message stream not done generation" + + +def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState): + # _reset_config() + + humans_response = client.list_humans() + print("HUMANS", humans_response) + + personas_response = client.list_personas() + print("PERSONAS", personas_response) + + persona_name = "TestPersona" + persona_id = client.get_persona_id(persona_name) + if persona_id: + client.delete_persona(persona_id) + persona = client.create_persona(name=persona_name, text="Persona text") + assert persona.template_name == persona_name + assert persona.value == "Persona text", "Creating persona failed" + + human_name = "TestHuman" + human_id = client.get_human_id(human_name) + if human_id: + client.delete_human(human_id) + human = client.create_human(name=human_name, text="Human text") + assert human.template_name == human_name + assert human.value == "Human text", "Creating human failed" + + +def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): + tools = client.list_tools() + visited_ids = {t.id: False for t in tools} + + cursor = None + # Choose 3 for uneven buckets (only 7 default tools) + num_tools = 3 + # Construct a complete pagination test to see if we can return all the tools eventually + for _ in range(0, len(tools), num_tools): + curr_tools = client.list_tools(cursor, num_tools) + assert len(curr_tools) <= num_tools + + for curr_tool in curr_tools: + assert curr_tool.id in visited_ids + visited_ids[curr_tool.id] = True + + cursor = curr_tools[-1].id + + # Assert that everything has been visited + assert all(visited_ids.values()) + + +def test_list_tools(client: Union[LocalClient, RESTClient]): + tools = client.add_base_tools() + tool_names = [t.name for t in tools] + expected = ToolManager.BASE_TOOL_NAMES + assert sorted(tool_names) == sorted(expected) + + +def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState): 
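+ # Page through the source's files one at a time, using the id of the last file returned as the cursor for the next page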
+ # clear sources + for source in client.list_sources(): + client.delete_source(source.id) + + # clear jobs + for job in client.list_jobs(): + client.delete_job(job.id) + + # create a source + source = client.create_source(name="test_source") + + # load files into sources + file_a = "tests/data/memgpt_paper.pdf" + file_b = "tests/data/test.txt" + upload_file_using_client(client, source, file_a) + upload_file_using_client(client, source, file_b) + + # Get the first file + files_a = client.list_files_from_source(source.id, limit=1) + assert len(files_a) == 1 + assert files_a[0].source_id == source.id + + # Use the cursor from files_a to get the remaining file + files_b = client.list_files_from_source(source.id, limit=1, cursor=files_a[-1].id) + assert len(files_b) == 1 + assert files_b[0].source_id == source.id + + # Check files are different to ensure the cursor works + assert files_a[0].file_name != files_b[0].file_name + + # Use the cursor from files_b to list files, should be empty + files = client.list_files_from_source(source.id, limit=1, cursor=files_b[-1].id) + assert len(files) == 0 # Should be empty + + +def test_delete_file_from_source(client: Union[LocalClient, RESTClient], agent: AgentState): + # clear sources + for source in client.list_sources(): + client.delete_source(source.id) + + # clear jobs + for job in client.list_jobs(): + client.delete_job(job.id) + + # create a source + source = client.create_source(name="test_source") + + # load files into sources + file_a = "tests/data/test.txt" + upload_file_using_client(client, source, file_a) + + # Get the first file + files_a = client.list_files_from_source(source.id, limit=1) + assert len(files_a) == 1 + assert files_a[0].source_id == source.id + + # Delete the file + client.delete_file_from_source(source.id, files_a[0].id) + + # Check that no files are attached to the source + empty_files = client.list_files_from_source(source.id, limit=1) + assert len(empty_files) == 0 + + +def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): + # _reset_config() + + # clear sources + for source in client.list_sources(): + client.delete_source(source.id) + + # clear jobs + for job in client.list_jobs(): + client.delete_job(job.id) + + # create a source + source = client.create_source(name="test_source") + + # load a file into a source (non-blocking job) + filename = "tests/data/memgpt_paper.pdf" + upload_file_using_client(client, source, filename) + + # Get the files + files = client.list_files_from_source(source.id) + assert len(files) == 1 # Should be condensed to one document + + # Get the memgpt paper + file = files[0] + # Assert the filename matches the pattern + pattern = re.compile(r"^memgpt_paper_[a-f0-9]{32}\.pdf$") + assert pattern.match(file.file_name), f"Filename '{file.file_name}' does not match expected pattern." + + assert file.source_id == source.id + + +def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): + # _reset_config() + + # clear sources + for source in client.list_sources(): + client.delete_source(source.id) + + # clear jobs + for job in client.list_jobs(): + client.delete_job(job.id) + + # list sources + sources = client.list_sources() + print("listed sources", sources) + assert len(sources) == 0 + + # create a source + source = client.create_source(name="test_source") + + # list sources + sources = client.list_sources() + print("listed sources", sources) + assert len(sources) == 1 + + # TODO: add back?
+ assert sources[0].metadata_["num_passages"] == 0 + assert sources[0].metadata_["num_documents"] == 0 + + # update the source + original_id = source.id + original_name = source.name + new_name = original_name + "_new" + client.update_source(source_id=source.id, name=new_name) + + # get the source name (check that it's been updated) + source = client.get_source(source_id=source.id) + assert source.name == new_name + assert source.id == original_id + + # get the source id (make sure that it's the same) + assert str(original_id) == client.get_source_id(source_name=new_name) + + # check agent archival memory size + archival_memories = client.get_archival_memory(agent_id=agent.id) + print(archival_memories) + assert len(archival_memories) == 0 + + # load a file into a source (non-blocking job) + filename = "tests/data/memgpt_paper.pdf" + upload_job = upload_file_using_client(client, source, filename) + job = client.get_job(upload_job.id) + created_passages = job.metadata_["num_passages"] + + # TODO: add test for blocking job + + # TODO: make sure things run in the right order + archival_memories = client.get_archival_memory(agent_id=agent.id) + assert len(archival_memories) == 0 + + # attach a source + client.attach_source_to_agent(source_id=source.id, agent_id=agent.id) + + # list attached sources + attached_sources = client.list_attached_sources(agent_id=agent.id) + print("attached sources", attached_sources) + assert source.id in [s.id for s in attached_sources], f"Attached sources: {attached_sources}" + + # list archival memory + archival_memories = client.get_archival_memory(agent_id=agent.id) + # print(archival_memories) + assert len(archival_memories) == created_passages, f"Mismatched length {len(archival_memories)} vs. {created_passages}" + + # check number of passages + sources = client.list_sources() + # TODO: add back? + # assert sources.sources[0].metadata_["num_passages"] > 0 + # assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added + print(sources) + + # detach the source + assert len(client.get_archival_memory(agent_id=agent.id)) > 0, "No archival memory" + deleted_source = client.detach_source(source_id=source.id, agent_id=agent.id) + assert deleted_source.id == source.id + archival_memories = client.get_archival_memory(agent_id=agent.id) + assert len(archival_memories) == 0, f"Failed to detach source: {len(archival_memories)}" + assert source.id not in [s.id for s in client.list_attached_sources(agent.id)] + + # delete the source + client.delete_source(source.id) + + +def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState): + """Test that we can update the details of a message""" + + # create a message + message_response = client.send_message(agent_id=agent.id, message="Test message", role="user", include_full_message=True) + print("Messages=", message_response) + assert isinstance(message_response, LettaResponse) + assert isinstance(message_response.messages[-1], Message) + message = message_response.messages[-1] + + new_text = "This exact string would never show up in the message???" 
+ new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id) + assert new_message.text == new_text + + +def test_organization(client: RESTClient): + if isinstance(client, LocalClient): + pytest.skip("Skipping test_organization because LocalClient does not support organizations") + + # create an organization + org_name = "test-org" + org = client.create_org(org_name) + + # assert the id appears + orgs = client.list_orgs() + assert org.id in [o.id for o in orgs] + + org = client.delete_org(org.id) + assert org.name == org_name + + # assert the id is gone + orgs = client.list_orgs() + assert org.id not in [o.id for o in orgs] + + +def test_list_llm_models(client: RESTClient): + """Test that if the user's env has the right api keys set, at least one model appears in the model list""" + + def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool: + return any(model.model_endpoint_type == target_type for model in models) + + models = client.list_llm_configs() + if model_settings.groq_api_key: + assert has_model_endpoint_type(models, "groq") + if model_settings.azure_api_key: + assert has_model_endpoint_type(models, "azure") + if model_settings.openai_api_key: + assert has_model_endpoint_type(models, "openai") + if model_settings.gemini_api_key: + assert has_model_endpoint_type(models, "google_ai") + if model_settings.anthropic_api_key: + assert has_model_endpoint_type(models, "anthropic") + + +def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState): + # _reset_config() + + # create a block + block = client.create_block(label="human", value="username: sarah") + + # create agents with shared block + from letta.schemas.memory import BasicBlockMemory + + persona1_block = client.create_block(label="persona", value="you are agent 1") + persona2_block = client.create_block(label="persona", value="you are agent 2") + + # create agents + agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory(blocks=[block, persona1_block])) + agent_state2 = client.create_agent(name="agent2", memory=BasicBlockMemory(blocks=[block, persona2_block])) + + # update memory + response = client.user_message(agent_id=agent_state1.id, message="my name is actually charles") + + # check agent 2 memory + assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}" + + response = client.user_message(agent_id=agent_state2.id, message="whats my name?") + assert ( + "charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower() + ), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}" + # assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}" + + # cleanup + client.delete_agent(agent_state1.id) + client.delete_agent(agent_state2.id) + + +@pytest.fixture +def cleanup_agents(): + created_agents = [] + yield created_agents + # Cleanup will run even if test fails + for agent_id in created_agents: + try: + client.delete_agent(agent_id) + except Exception as e: + print(f"Failed to delete agent {agent_id}: {e}") + + +def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]): + """Test that we can set an initial message sequence + + If we pass in None, we should get a "default" message sequence + If we pass in a non-empty list, we should get that sequence + If we pass in an empty list, we should
get just the system message + """ + + # The reference initial message sequence: + reference_init_messages = initialize_message_sequence( + model=agent.llm_config.model, + system=agent.system, + memory=agent.memory, + archival_memory=None, + recall_memory=None, + memory_edit_timestamp=get_utc_time(), + include_initial_boot_message=True, + ) + + # system, login message, send_message test, send_message receipt + assert len(reference_init_messages) > 0 + assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}" + + # Test with default sequence + default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None) + cleanup_agents.append(default_agent_state.id) + assert default_agent_state.message_ids is not None + assert len(default_agent_state.message_ids) > 0 + assert len(default_agent_state.message_ids) == len( + reference_init_messages + ), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}" + + # Test with empty sequence + empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[]) + cleanup_agents.append(empty_agent_state.id) + assert empty_agent_state.message_ids is not None + assert len(empty_agent_state.message_ids) == 1, f"Expected 1 message (system only), got {len(empty_agent_state.message_ids)}" + + # Test with custom sequence + custom_sequence = [ + Message( + role=MessageRole.user, + text="Hello, how are you?", + user_id=agent.user_id, + agent_id=agent.id, + model=agent.llm_config.model, + name=None, + tool_calls=None, + tool_call_id=None, + ), + ] + custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence) + cleanup_agents.append(custom_agent_state.id) + assert custom_agent_state.message_ids is not None + assert ( + len(custom_agent_state.message_ids) == len(custom_sequence) + 1 + ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}" + assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence] + + +def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState): + """ + Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
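+ Tags are written via update_agent(tags=...) and read back through get_agent and list_agents(tags=[...]).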
+ """ + + # Step 1: Add multiple tags to the agent + tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"] + client.update_agent(agent_id=agent.id, tags=tags_to_add) + + # Step 2: Retrieve tags for the agent and verify they match the added tags + retrieved_tags = client.get_agent(agent_id=agent.id).tags + assert set(retrieved_tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {retrieved_tags}" + + # Step 3: Retrieve agents by each tag to ensure the agent is associated correctly + for tag in tags_to_add: + agents_with_tag = client.list_agents(tags=[tag]) + assert agent.id in [a.id for a in agents_with_tag], f"Expected agent {agent.id} to be associated with tag '{tag}'" + + # Step 4: Delete a specific tag from the agent and verify its removal + tag_to_delete = tags_to_add.pop() + client.update_agent(agent_id=agent.id, tags=tags_to_add) + + # Verify the tag is removed from the agent's tags + remaining_tags = client.get_agent(agent_id=agent.id).tags + assert tag_to_delete not in remaining_tags, f"Tag '{tag_to_delete}' was not removed as expected" + assert set(remaining_tags) == set(tags_to_add), f"Expected remaining tags to be {tags_to_add[1:]}, but got {remaining_tags}" + + # Step 5: Delete all remaining tags from the agent + client.update_agent(agent_id=agent.id, tags=[]) + + # Verify all tags are removed + final_tags = client.get_agent(agent_id=agent.id).tags + assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}" diff --git a/tests/test_managers.py b/tests/test_managers.py index 9d566e84a5..946d4edbd3 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3,7 +3,16 @@ import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code -from letta.orm import Block, FileMetadata, Organization, Source, Tool, User +from letta.orm import ( + Block, + FileMetadata, + Organization, + SandboxConfig, + SandboxEnvironmentVariable, + Source, + Tool, + User, +) from letta.schemas.agent import CreateAgent from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate @@ -12,12 +21,22 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization as PydanticOrganization +from letta.schemas.sandbox_config import ( + E2BSandboxConfig, + LocalSandboxConfig, + SandboxConfigCreate, + SandboxConfigUpdate, + SandboxEnvironmentVariableCreate, + SandboxEnvironmentVariableUpdate, + SandboxType, +) from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolUpdate from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager +from letta.settings import tool_settings utils.DEBUG = True from letta.config import LettaConfig @@ -41,6 +60,8 @@ def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: + session.execute(delete(SandboxEnvironmentVariable)) + session.execute(delete(SandboxConfig)) session.execute(delete(Block)) session.execute(delete(FileMetadata)) session.execute(delete(Source)) @@ -157,6 +178,28 @@ def print_tool(message: str): yield {"tool": tool} +@pytest.fixture +def sandbox_config_fixture(server: SyncServer, default_user): + sandbox_config_create = SandboxConfigCreate( + config=E2BSandboxConfig(), + ) 
+ created_config = server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create, actor=default_user) + yield created_config + + +@pytest.fixture +def sandbox_env_var_fixture(server: SyncServer, sandbox_config_fixture, default_user): + env_var_create = SandboxEnvironmentVariableCreate( + key="SAMPLE_VAR", + value="sample_value", + description="A sample environment variable for testing.", + ) + created_env_var = server.sandbox_config_manager.create_sandbox_env_var( + env_var_create, sandbox_config_id=sandbox_config_fixture.id, actor=default_user + ) + yield created_env_var + + @pytest.fixture(scope="module") def server(): config = LettaConfig.load() @@ -829,3 +872,149 @@ def test_get_agents_by_tag(server: SyncServer, sarah_agent, charles_agent, defau assert sarah_agent.id not in agent_ids assert charles_agent.id in agent_ids assert len(agent_ids) == 1 + + +# ====================================================================================================================== +# SandboxConfigManager Tests - Sandbox Configs +# ====================================================================================================================== +def test_create_or_update_sandbox_config(server: SyncServer, default_user): + sandbox_config_create = SandboxConfigCreate( + config=E2BSandboxConfig(), + ) + created_config = server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create, actor=default_user) + + # Assertions + assert created_config.type == SandboxType.E2B + assert created_config.get_e2b_config() == sandbox_config_create.config + assert created_config.organization_id == default_user.organization_id + + +def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user): + created_config = server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=default_user) + e2b_config = created_config.get_e2b_config() + + # Assertions + assert e2b_config.timeout == 5 * 60 + assert e2b_config.template + assert e2b_config.template == tool_settings.e2b_sandbox_template_id + + +def test_update_existing_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user): + update_data = SandboxConfigUpdate(config=E2BSandboxConfig(template="template_2", timeout=120)) + updated_config = server.sandbox_config_manager.update_sandbox_config(sandbox_config_fixture.id, update_data, actor=default_user) + + # Assertions + assert updated_config.config["template"] == "template_2" + assert updated_config.config["timeout"] == 120 + + +def test_delete_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user): + deleted_config = server.sandbox_config_manager.delete_sandbox_config(sandbox_config_fixture.id, actor=default_user) + + # Assertions to verify deletion + assert deleted_config.id == sandbox_config_fixture.id + + # Verify it no longer exists + config_list = server.sandbox_config_manager.list_sandbox_configs(actor=default_user) + assert sandbox_config_fixture.id not in [config.id for config in config_list] + + +def test_get_sandbox_config_by_type(server: SyncServer, sandbox_config_fixture, default_user): + retrieved_config = server.sandbox_config_manager.get_sandbox_config_by_type(sandbox_config_fixture.type, actor=default_user) + + # Assertions to verify correct retrieval + assert retrieved_config.id == sandbox_config_fixture.id + assert retrieved_config.type == sandbox_config_fixture.type + + +def test_list_sandbox_configs(server: SyncServer, default_user): + # Creating multiple sandbox 
configs + config_a = SandboxConfigCreate( + config=E2BSandboxConfig(), + ) + config_b = SandboxConfigCreate( + config=LocalSandboxConfig(sandbox_dir=""), + ) + server.sandbox_config_manager.create_or_update_sandbox_config(config_a, actor=default_user) + server.sandbox_config_manager.create_or_update_sandbox_config(config_b, actor=default_user) + + # List configs without pagination + configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user) + assert len(configs) >= 2 + + # List configs with pagination + paginated_configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, limit=1) + assert len(paginated_configs) == 1 + + next_page = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, cursor=paginated_configs[-1].id, limit=1) + assert len(next_page) == 1 + assert next_page[0].id != paginated_configs[0].id + + +# ====================================================================================================================== +# SandboxConfigManager Tests - Environment Variables +# ====================================================================================================================== +def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user): + env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.") + created_env_var = server.sandbox_config_manager.create_sandbox_env_var( + env_var_create, sandbox_config_id=sandbox_config_fixture.id, actor=default_user + ) + + # Assertions + assert created_env_var.key == env_var_create.key + assert created_env_var.value == env_var_create.value + assert created_env_var.organization_id == default_user.organization_id + + +def test_update_sandbox_env_var(server: SyncServer, sandbox_env_var_fixture, default_user): + update_data = SandboxEnvironmentVariableUpdate(value="updated_value") + updated_env_var = server.sandbox_config_manager.update_sandbox_env_var(sandbox_env_var_fixture.id, update_data, actor=default_user) + + # Assertions + assert updated_env_var.value == "updated_value" + assert updated_env_var.id == sandbox_env_var_fixture.id + + +def test_delete_sandbox_env_var(server: SyncServer, sandbox_config_fixture, sandbox_env_var_fixture, default_user): + deleted_env_var = server.sandbox_config_manager.delete_sandbox_env_var(sandbox_env_var_fixture.id, actor=default_user) + + # Assertions to verify deletion + assert deleted_env_var.id == sandbox_env_var_fixture.id + + # Verify it no longer exists + env_vars = server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id=sandbox_config_fixture.id, actor=default_user) + assert sandbox_env_var_fixture.id not in [env_var.id for env_var in env_vars] + + +def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, default_user): + # Creating multiple environment variables + env_var_create_a = SandboxEnvironmentVariableCreate(key="VAR1", value="value1") + env_var_create_b = SandboxEnvironmentVariableCreate(key="VAR2", value="value2") + server.sandbox_config_manager.create_sandbox_env_var(env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user) + server.sandbox_config_manager.create_sandbox_env_var(env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user) + + # List env vars without pagination + env_vars = server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id=sandbox_config_fixture.id, actor=default_user) + assert len(env_vars) >= 2 + + # List 
env vars with pagination + paginated_env_vars = server.sandbox_config_manager.list_sandbox_env_vars( + sandbox_config_id=sandbox_config_fixture.id, actor=default_user, limit=1 + ) + assert len(paginated_env_vars) == 1 + + next_page = server.sandbox_config_manager.list_sandbox_env_vars( + sandbox_config_id=sandbox_config_fixture.id, actor=default_user, cursor=paginated_env_vars[-1].id, limit=1 + ) + assert len(next_page) == 1 + assert next_page[0].id != paginated_env_vars[0].id + + +def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture, default_user): + retrieved_env_var = server.sandbox_config_manager.get_sandbox_env_var_by_key_and_sandbox_config_id( + sandbox_env_var_fixture.key, sandbox_env_var_fixture.sandbox_config_id, actor=default_user + ) + + # Assertions to verify correct retrieval + assert retrieved_env_var.id == sandbox_env_var_fixture.id + assert retrieved_env_var.key == sandbox_env_var_fixture.key diff --git a/tests/test_summarize.py b/tests/test_summarize.py index dd58311ebf..31a8592912 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -99,6 +99,7 @@ def test_auto_summarize(): def summarize_message_exists(messages: List[Message]) -> bool: for message in messages: if message.text and "have been hidden from view due to conversation memory constraints" in message.text: + print(f"Summarize message found after {message_count} messages: \n {message.text}") return True return False @@ -113,12 +114,12 @@ def summarize_message_exists(messages: List[Message]) -> bool: ) message_count += 1 + print(f"Message {message_count}: \n\n{response.messages}") + # check if the summarize message is inside the messages assert isinstance(client, LocalClient), "Test only works with LocalClient" agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id) if summarize_message_exists(agent_obj._messages): - # We found a summarize message - print(f"Summarize message found after {message_count} messages") break if message_count > MAX_ATTEMPTS: diff --git a/tests/test_tool_execution_sandbox.py b/tests/test_tool_execution_sandbox.py new file mode 100644 index 0000000000..f1d82b61a3 --- /dev/null +++ b/tests/test_tool_execution_sandbox.py @@ -0,0 +1,425 @@ +import secrets +import string +import uuid +from pathlib import Path +from unittest.mock import patch + +import pytest +from composio import Action +from sqlalchemy import delete + +from letta import create_client +from letta.functions.functions import parse_source_code +from letta.functions.schema_generator import generate_schema +from letta.orm import SandboxConfig, SandboxEnvironmentVariable +from letta.schemas.agent import AgentState +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig +from letta.schemas.memory import ChatMemory +from letta.schemas.organization import Organization +from letta.schemas.sandbox_config import ( + E2BSandboxConfig, + LocalSandboxConfig, + SandboxConfigCreate, + SandboxConfigUpdate, + SandboxEnvironmentVariableCreate, +) +from letta.schemas.tool import Tool, ToolCreate +from letta.schemas.user import User +from letta.services.organization_manager import OrganizationManager +from letta.services.sandbox_config_manager import SandboxConfigManager +from letta.services.tool_execution_sandbox import ToolExecutionSandbox +from letta.services.tool_manager import ToolManager +from letta.services.user_manager import UserManager +from letta.settings import tool_settings + +# Constants +namespace = uuid.NAMESPACE_DNS +org_name 
= str(uuid.uuid5(namespace, "test-tool-execution-sandbox-org")) +user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user")) + + +# Fixtures +@pytest.fixture(autouse=True) +def clear_tables(): + """Fixture to clear the sandbox tables before each test.""" + from letta.server.server import db_context + + with db_context() as session: + session.execute(delete(SandboxEnvironmentVariable)) + session.execute(delete(SandboxConfig)) + session.commit() # Commit the deletion + + # Kill all sandboxes + from e2b_code_interpreter import Sandbox + + for sandbox in Sandbox.list(): + Sandbox.connect(sandbox.sandbox_id).kill() + + +@pytest.fixture +def mock_e2b_api_key_none(): + # Store the original value of e2b_api_key + original_api_key = tool_settings.e2b_api_key + + # Set e2b_api_key to None + tool_settings.e2b_api_key = None + + # Yield control to the test + yield + + # Restore the original value of e2b_api_key + tool_settings.e2b_api_key = original_api_key + + +@pytest.fixture +def check_e2b_key_is_set(): + original_api_key = tool_settings.e2b_api_key + assert original_api_key is not None, "Missing e2b key! Cannot execute these tests." + yield + + +@pytest.fixture +def check_composio_key_set(): + original_api_key = tool_settings.composio_api_key + assert original_api_key is not None, "Missing composio key! Cannot execute this test." + yield + + +@pytest.fixture +def test_organization(): + """Fixture to create and return the default organization.""" + org = OrganizationManager().create_organization(Organization(name=org_name)) + yield org + + +@pytest.fixture +def test_user(test_organization): + """Fixture to create and return the default user within the default organization.""" + user = UserManager().create_user(User(name=user_name, organization_id=test_organization.id)) + yield user + + +@pytest.fixture +def add_integers_tool(test_user): + def add(x: int, y: int) -> int: + """ + Simple function that adds two integers. + + Parameters: + x (int): The first integer to add. + y (int): The second integer to add. + + Returns: + int: The result of adding x and y. + """ + return x + y + + tool = create_tool_from_func(add) + tool = ToolManager().create_or_update_tool(tool, test_user) + yield tool + + +@pytest.fixture +def cowsay_tool(test_user): + # This defines a tool for a package we definitely do NOT have in letta + # If this test passes, that means the tool was correctly executed in a separate Python environment + def cowsay() -> str: + """ + Simple function that uses the cowsay package to print out the secret word env variable. + + Returns: + str: The cowsay ASCII art. + """ + import os + + import cowsay + + cowsay.cow(os.getenv("secret_word")) + + tool = create_tool_from_func(cowsay) + tool = ToolManager().create_or_update_tool(tool, test_user) + yield tool + + +@pytest.fixture +def get_env_tool(test_user): + def get_env() -> str: + """ + Simple function that returns the secret word env variable.
+ + Returns: + str: The secret word + """ + import os + + secret_word = os.getenv("secret_word") + print(secret_word) + return secret_word + + tool = create_tool_from_func(get_env) + tool = ToolManager().create_or_update_tool(tool, test_user) + yield tool + + +@pytest.fixture +def list_tool(test_user): + def create_list(): + """Simple function that returns a list""" + + return [1] * 5 + + tool = create_tool_from_func(create_list) + tool = ToolManager().create_or_update_tool(tool, test_user) + yield tool + + +@pytest.fixture +def composio_github_star_tool(test_user): + tool_manager = ToolManager() + tool_create = ToolCreate.from_composio(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) + tool = tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=test_user) + yield tool + + +@pytest.fixture +def clear_core_memory(test_user): + def clear_memory(agent_state: AgentState): + """Clear the core memory""" + agent_state.memory.get_block("human").value = "" + agent_state.memory.get_block("persona").value = "" + + tool = create_tool_from_func(clear_memory) + tool = ToolManager().create_or_update_tool(tool, test_user) + yield tool + + +# Utility functions +def create_tool_from_func(func: callable): + return Tool( + name=func.__name__, + description="", + source_type="python", + tags=[], + source_code=parse_source_code(func), + json_schema=generate_schema(func, None), + ) + + +# Local sandbox tests +@pytest.mark.local_sandbox +def test_local_sandbox_default(mock_e2b_api_key_none, add_integers_tool, test_user): + args = {"x": 10, "y": 5} + + # Mock and assert correct pathway was invoked + with patch.object(ToolExecutionSandbox, "run_local_dir_sandbox") as mock_run_local_dir_sandbox: + sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id) + sandbox.run() + mock_run_local_dir_sandbox.assert_called_once() + + # Run again to get actual response + sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id) + result = sandbox.run() + assert result.func_return == args["x"] + args["y"] + + +@pytest.mark.local_sandbox +def test_local_sandbox_stateful_tool(mock_e2b_api_key_none, clear_core_memory, test_user): + args = {} + + client = create_client() + agent_state = client.create_agent( + memory=ChatMemory(persona="This is the persona", human="This is the human"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + llm_config=LLMConfig.default_config(model_name="gpt-4"), + ) + + # Run the sandbox and verify the agent's core memory was cleared + sandbox = ToolExecutionSandbox(clear_core_memory.name, args, user_id=test_user.id) + result = sandbox.run(agent_state=agent_state) + assert result.agent_state.memory.get_block("human").value == "" + assert result.agent_state.memory.get_block("persona").value == "" + assert result.func_return is None + + +@pytest.mark.local_sandbox +def test_local_sandbox_with_list_rv(mock_e2b_api_key_none, list_tool, test_user): + sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id) + result = sandbox.run() + assert len(result.func_return) == 5 + + +@pytest.mark.local_sandbox +def test_local_sandbox_env(mock_e2b_api_key_none, get_env_tool, test_user): + manager = SandboxConfigManager(tool_settings) + + # Make a custom local sandbox config + sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox") + config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir).model_dump()) + config = manager.create_or_update_sandbox_config(config_create,
test_user) + + # Make an environment variable with a long random string + key = "secret_word" + long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20)) + manager.create_sandbox_env_var( + SandboxEnvironmentVariableCreate(key=key, value=long_random_string), sandbox_config_id=config.id, actor=test_user + ) + + # Create tool and args + args = {} + + # Run the custom sandbox + sandbox = ToolExecutionSandbox(get_env_tool.name, args, user_id=test_user.id) + result = sandbox.run() + + assert long_random_string in result.func_return + + +# E2B sandbox tests + + +@pytest.mark.e2b_sandbox +def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user): + args = {"x": 10, "y": 5} + + # Mock and assert correct pathway was invoked + with patch.object(ToolExecutionSandbox, "run_e2b_sandbox") as mock_run_e2b_sandbox: + sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id) + sandbox.run() + mock_run_e2b_sandbox.assert_called_once() + + # Run again to get actual response + sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id) + result = sandbox.run() + assert int(result.func_return) == args["x"] + args["y"] + + +@pytest.mark.e2b_sandbox +def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user): + manager = SandboxConfigManager(tool_settings) + config_create = SandboxConfigCreate(config=E2BSandboxConfig(pip_requirements=["cowsay"]).model_dump()) + config = manager.create_or_update_sandbox_config(config_create, test_user) + + # Add an environment variable + key = "secret_word" + long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20)) + manager.create_sandbox_env_var( + SandboxEnvironmentVariableCreate(key=key, value=long_random_string), sandbox_config_id=config.id, actor=test_user + ) + + sandbox = ToolExecutionSandbox(cowsay_tool.name, {}, user_id=test_user.id) + result = sandbox.run() + assert long_random_string in result.stdout[0] + + +@pytest.mark.e2b_sandbox +def test_e2b_sandbox_reuses_same_sandbox(check_e2b_key_is_set, list_tool, test_user): + sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id) + + # Run the function once + result = sandbox.run() + old_config_fingerprint = result.sandbox_config_fingerprint + + # Run it again to ensure that there is still only one running sandbox + result = sandbox.run() + new_config_fingerprint = result.sandbox_config_fingerprint + + assert old_config_fingerprint == new_config_fingerprint + + +@pytest.mark.e2b_sandbox +def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory, test_user): + sandbox = ToolExecutionSandbox(clear_core_memory.name, {}, user_id=test_user.id) + + # create an agent + client = create_client() + agent_state = client.create_agent( + memory=ChatMemory(persona="This is the persona", human="This is the human"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + llm_config=LLMConfig.default_config(model_name="gpt-4"), + ) + + # run the sandbox + result = sandbox.run(agent_state=agent_state) + assert result.agent_state.memory.get_block("human").value == "" + assert result.agent_state.memory.get_block("persona").value == "" + assert result.func_return is None + + +@pytest.mark.e2b_sandbox +def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user): + manager = SandboxConfigManager(tool_settings) + config_create =
SandboxConfigCreate(config=E2BSandboxConfig().model_dump()) + config = manager.create_or_update_sandbox_config(config_create, test_user) + + # Run the custom sandbox once, assert nothing returns because missing env variable + sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user_id=test_user.id, force_recreate=True) + result = sandbox.run() + # response should be None + assert result.func_return is None + + # Add an environment variable + key = "secret_word" + long_random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20)) + manager.create_sandbox_env_var( + SandboxEnvironmentVariableCreate(key=key, value=long_random_string), sandbox_config_id=config.id, actor=test_user + ) + + # Assert that the environment variable gets injected correctly, even when the sandbox is NOT refreshed + sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user_id=test_user.id) + result = sandbox.run() + assert long_random_string in result.func_return + + +@pytest.mark.e2b_sandbox +def test_e2b_sandbox_config_change_force_recreates_sandbox(check_e2b_key_is_set, list_tool, test_user): + manager = SandboxConfigManager(tool_settings) + old_timeout = 5 * 60 + new_timeout = 10 * 60 + + # Make the config + config_create = SandboxConfigCreate(config=E2BSandboxConfig(timeout=old_timeout)) + config = manager.create_or_update_sandbox_config(config_create, test_user) + + # Run the sandbox once and record the config fingerprint + sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id) + result = sandbox.run() + assert len(result.func_return) == 5 + old_config_fingerprint = result.sandbox_config_fingerprint + + # Change the config + config_update = SandboxConfigUpdate(config=E2BSandboxConfig(timeout=new_timeout)) + config = manager.update_sandbox_config(config.id, config_update, test_user) + + # Run again + result = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id).run() + new_config_fingerprint = result.sandbox_config_fingerprint + assert config.fingerprint() == new_config_fingerprint + + # Assert the fingerprints are different + assert old_config_fingerprint != new_config_fingerprint + + +@pytest.mark.e2b_sandbox +def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user): + sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id) + result = sandbox.run() + assert len(result.func_return) == 5 + + +# TODO: Add tests for composio +# def test_e2b_e2e_composio_star_github(check_e2b_key_is_set, check_composio_key_set, composio_github_star_tool, test_user): +# # Add the composio key +# manager = SandboxConfigManager(tool_settings) +# config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=test_user) +# +# manager.create_sandbox_env_var( +# SandboxEnvironmentVariableCreate(key="COMPOSIO_API_KEY", value=tool_settings.composio_api_key), +# sandbox_config_id=config.id, +# actor=test_user, +# ) +# +# result = ToolExecutionSandbox(composio_github_star_tool.name, {}, user_id=test_user.id).run() +# import ipdb +# +# ipdb.set_trace() diff --git a/tests/test_tool_sandbox/.gitkeep b/tests/test_tool_sandbox/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_tools.py b/tests/test_tools.py deleted file mode 100644 index 3515eca4a6..0000000000 --- a/tests/test_tools.py +++ /dev/null @@ -1,150 +0,0 @@ -import uuid -from typing import Union - -import pytest -from dotenv import load_dotenv - -from letta import create_client -from
letta.agent import Agent -from letta.client.client import LocalClient, RESTClient -from letta.constants import DEFAULT_PRESET -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory -from letta.services.tool_manager import ToolManager - -test_agent_name = f"test_client_{str(uuid.uuid4())}" -# test_preset_name = "test_preset" -test_preset_name = DEFAULT_PRESET -test_agent_state = None -client = None - -test_agent_state_post_message = None -test_user_id = uuid.uuid4() - - -def run_server(): - load_dotenv() - - # _reset_config() - - from letta.server.rest_api.app import start_server - - print("Starting server...") - start_server(debug=True) - - -@pytest.fixture(scope="module") -def client(): - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - yield client - - -# Fixture for test agent -@pytest.fixture(scope="module") -def agent(client): - agent_state = client.create_agent(name=test_agent_name) - print("AGENT ID", agent_state.id) - yield agent_state - - # delete agent - client.delete_agent(agent_state.id) - - -def test_create_tool(client: Union[LocalClient, RESTClient]): - """Test creation of a simple tool""" - - def print_tool(message: str): - """ - Example tool that prints a message - - Args: - message (str): The message to print. - - Returns: - str: The message that was printed. - - """ - print(message) - return message - - tools = client.list_tools() - tool_names = [t.name for t in tools] - for tool in ToolManager.BASE_TOOL_NAMES: - assert tool in tool_names - - tool = client.create_tool(print_tool, name="my_name", tags=["extras"]) - - tools = client.list_tools() - assert tool in tools, f"Expected {tool.name} in {[t.name for t in tools]}" - print(f"Updated tools {[t.name for t in tools]}") - - # check tool id - tool = client.get_tool(tool.id) - assert tool is not None, "Expected tool to be created" - assert tool.id == tool.id, f"Expected {tool.id} to be {tool.id}" - - # create agent with tool - assert tool.name is not None, "Expected tool name to be set" - agent_state = client.create_agent(tools=[tool.name]) - - # Send message without error - client.user_message(agent_id=agent_state.id, message="hi") - - -def test_create_agent_tool(client): - """Test creation of a agent tool""" - - def core_memory_clear(self: "Agent"): - """ - Clear the core memory of the agent - - Args: - agent (Agent): The agent to delete from memory. - - Returns: - str: The agent that was deleted. 
- - """ - self.memory.update_block_value(label="human", value="") - self.memory.update_block_value(label="persona", value="") - print("UPDATED MEMORY", self.memory.memory) - return None - - # TODO: test attaching and using function on agent - tool = client.create_tool(core_memory_clear, tags=["extras"]) - print(f"Created tool", tool.name) - - # create agent with tool - memory = ChatMemory(human="I am a human", persona="You must clear your memory if the human instructs you") - agent = client.create_agent(name=test_agent_name, tools=[tool.name], memory=memory) - assert str(tool.created_by_id) == str(agent.user_id), f"Expected {tool.created_by_id} to be {agent.user_id}" - - # initial memory - initial_memory = client.get_in_context_memory(agent.id) - print("initial memory", initial_memory.compile()) - human = initial_memory.get_block("human") - persona = initial_memory.get_block("persona") - print("Initial memory:", human, persona) - assert len(human.value) > 0, "Expected human memory to be non-empty" - assert len(persona.value) > 0, "Expected persona memory to be non-empty" - - # test agent tool - response = client.send_message(role="user", agent_id=agent.id, message="clear your memory with the core_memory_clear tool") - print(response) - - # updated memory - print("Query agent memory") - updated_memory = client.get_in_context_memory(agent.id) - human = updated_memory.get_block("human") - persona = updated_memory.get_block("persona") - print("Updated memory:", human, persona) - assert len(human.value) == 0, "Expected human memory to be empty" - assert len(persona.value) == 0, "Expected persona memory to be empty" - - -def test_custom_import_tool(client): - pass