Skip to content

Commit

Permalink
Implement Computer Use for MacoS
Browse files Browse the repository at this point in the history
  • Loading branch information
PallavAg committed Oct 23, 2024
0 parents commit 60ff2ea
Show file tree
Hide file tree
Showing 12 changed files with 1,174 additions and 0 deletions.
30 changes: 30 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# macOS-specific files/folders
.DS_Store
.AppleDouble
.LSOverride
Icon?
._*

# Common Python artifacts
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
*.egg-info/
.tox/
build/
dist/

# IDE/Editor files
.idea/
*.iml
.vscode/

# Notebook checkpoints
.ipynb_checkpoints

# macOS Trash folder
.Trash/
Empty file added computer_use_demo/__init__.py
Empty file.
231 changes: 231 additions & 0 deletions computer_use_demo/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""
Agentic sampling loop that calls the Anthropic API and a local implementation of Anthropic-defined computer use tools.
"""

import platform
from collections.abc import Callable
from datetime import datetime
from enum import StrEnum
from typing import Any, cast

from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
from anthropic.types import (
ToolResultBlockParam,
)
from anthropic.types.beta import (
BetaContentBlock,
BetaContentBlockParam,
BetaImageBlockParam,
BetaMessage,
BetaMessageParam,
BetaTextBlockParam,
BetaToolResultBlockParam,
)

from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult

# Beta feature identifier for the computer-use tools; the same string is sent in the
# `betas` field of the Messages API request made by `sampling_loop`.
BETA_FLAG = "computer-use-2024-10-22"


class APIProvider(StrEnum):
ANTHROPIC = "anthropic"
BEDROCK = "bedrock"
VERTEX = "vertex"


# Default model identifier to use for each supported API provider. The same model
# generation is spelled differently per provider, hence the lookup table.
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
}


# This system prompt is written for the macOS environment described in the prompt
# text below and the specific tool combinations enabled.
# We encourage modifying this system prompt to ensure the model has context for the
# environment it is running in, and to provide any additional information that may be
# helpful for the task at hand.
# NOTE: this is a module-level f-string, so `platform.machine()` and the current date
# are evaluated once, when the module is imported, not on every request.
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
* You are utilizing a MacOS computer using {platform.machine()} architecture with internet access.
* You can use the bash tool to execute commands in the terminal.
* To open applications, you can use the `open` command in the bash tool. For example, `open -a Safari` to open the Safari browser.
* When using your bash tool with commands that are expected to output very large quantities of text, redirect the output into a temporary file and use `str_replace_editor` or `grep -n -B <lines before> -A <lines after> <query> <filename>` to inspect the output.
* When viewing a page, it can be helpful to zoom out so that you can see everything on the page. Alternatively, ensure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they may take a while to run and send back to you. Where possible and feasible, try to chain multiple of these calls into one function call request.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* When using Safari or other applications, if any startup wizards or prompts appear, **IGNORE THEM**. Do not interact with them. Instead, click on the address bar or the area where you can enter commands or URLs, and proceed with your task.
* If the item you are looking at is a PDF, and after taking a single screenshot of the PDF it seems you want to read the entire document, instead of trying to continue to read the PDF from your screenshots and navigation, determine the URL, use `curl` to download the PDF, install and use `pdftotext` (you may need to install it via `brew install poppler`) to convert it to a text file, and then read that text file directly with your `str_replace_editor` tool.
</IMPORTANT>"""


def _create_client(provider: APIProvider, api_key: str):
    """Build the SDK client that corresponds to ``provider``.

    Raises:
        ValueError: if ``provider`` is not one of the supported providers.
    """
    if provider == APIProvider.ANTHROPIC:
        return Anthropic(api_key=api_key)
    if provider == APIProvider.VERTEX:
        return AnthropicVertex()
    if provider == APIProvider.BEDROCK:
        return AnthropicBedrock()
    # Previously an unknown provider fell through and crashed later with an
    # unbound ``client`` variable; fail early with a clear message instead.
    raise ValueError(f"Unsupported API provider: {provider!r}")


async def sampling_loop(
    *,
    model: str,
    provider: APIProvider,
    system_prompt_suffix: str,
    messages: list[BetaMessageParam],
    output_callback: Callable[[BetaContentBlock], None],
    tool_output_callback: Callable[[ToolResult, str], None],
    api_response_callback: Callable[[APIResponse[BetaMessage]], None],
    api_key: str,
    only_n_most_recent_images: int | None = None,
    max_tokens: int = 4096,
):
    """
    Agentic sampling loop for the assistant/tool interaction of computer use.

    Repeatedly calls the model, runs every ``tool_use`` block it returns, and feeds
    the tool results back as a user message until a response contains no tool use.

    Args:
        model: Model identifier passed to the API.
        provider: Which API provider's client to use.
        system_prompt_suffix: Extra text appended to ``SYSTEM_PROMPT`` (separated
            by a space) when non-empty.
        messages: Conversation so far; assistant and tool-result messages are
            appended to this list in place.
        output_callback: Called with every content block of each response.
        tool_output_callback: Called with each tool result and its tool-use id.
        api_response_callback: Called with the raw API response of each request.
        api_key: API key, used only for the ``ANTHROPIC`` provider.
        only_n_most_recent_images: When truthy, older tool-result images are
            dropped from ``messages`` before each request.
        max_tokens: Maximum tokens to request per response.

    Returns:
        The same ``messages`` list, once a response contains no tool use.

    Raises:
        ValueError: if ``provider`` is not a supported provider.
    """
    tool_collection = ToolCollection(
        ComputerTool(),
        BashTool(),
        EditTool(),
    )
    system = (
        f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}"
    )

    # The provider and key do not change between iterations, so build the client
    # once instead of on every pass through the loop.
    client = _create_client(provider, api_key)

    while True:
        if only_n_most_recent_images:
            _maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images)

        # Call the API
        # we use raw_response to provide debug information to streamlit. Your
        # implementation may be able call the SDK directly with:
        # `response = client.messages.create(...)` instead.
        raw_response = client.beta.messages.with_raw_response.create(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            system=system,
            tools=tool_collection.to_params(),
            betas=[BETA_FLAG],
        )

        api_response_callback(cast(APIResponse[BetaMessage], raw_response))

        response = raw_response.parse()

        messages.append(
            {
                "role": "assistant",
                "content": cast(list[BetaContentBlockParam], response.content),
            }
        )

        tool_result_content: list[BetaToolResultBlockParam] = []
        for content_block in cast(list[BetaContentBlock], response.content):
            output_callback(content_block)
            if content_block.type == "tool_use":
                result = await tool_collection.run(
                    name=content_block.name,
                    tool_input=cast(dict[str, Any], content_block.input),
                )
                tool_result_content.append(
                    _make_api_tool_result(result, content_block.id)
                )
                tool_output_callback(result, content_block.id)

        # No tool was requested: the model is done, hand the conversation back.
        if not tool_result_content:
            return messages

        messages.append({"content": tool_result_content, "role": "user"})


def _maybe_filter_to_n_most_recent_images(
messages: list[BetaMessageParam],
images_to_keep: int,
min_removal_threshold: int = 10,
):
"""
With the assumption that images are screenshots that are of diminishing value as
the conversation progresses, remove all but the final `images_to_keep` tool_result
images in place, with a chunk of min_removal_threshold to reduce the amount we
break the implicit prompt cache.
"""
if images_to_keep is None:
return messages

tool_result_blocks = cast(
list[ToolResultBlockParam],
[
item
for message in messages
for item in (
message["content"] if isinstance(message["content"], list) else []
)
if isinstance(item, dict) and item.get("type") == "tool_result"
],
)

total_images = sum(
1
for tool_result in tool_result_blocks
for content in tool_result.get("content", [])
if isinstance(content, dict) and content.get("type") == "image"
)

images_to_remove = total_images - images_to_keep
# for better cache behavior, we want to remove in chunks
images_to_remove -= images_to_remove % min_removal_threshold

for tool_result in tool_result_blocks:
if isinstance(tool_result.get("content"), list):
new_content = []
for content in tool_result.get("content", []):
if isinstance(content, dict) and content.get("type") == "image":
if images_to_remove > 0:
images_to_remove -= 1
continue
new_content.append(content)
tool_result["content"] = new_content


def _make_api_tool_result(
    result: ToolResult, tool_use_id: str
) -> BetaToolResultBlockParam:
    """Translate an agent ``ToolResult`` into an API tool_result block.

    An error result becomes a plain string payload flagged with ``is_error``;
    otherwise the payload is a list holding an optional text block followed by
    an optional PNG image block.
    """
    payload: list[BetaTextBlockParam | BetaImageBlockParam] | str
    failed = bool(result.error)

    if failed:
        # Errors are reported as bare text; any output or image is not included.
        payload = _maybe_prepend_system_tool_result(result, result.error)
    else:
        payload = []
        if result.output:
            text_block = {
                "type": "text",
                "text": _maybe_prepend_system_tool_result(result, result.output),
            }
            payload.append(text_block)
        if result.base64_image:
            image_block = {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": result.base64_image,
                },
            }
            payload.append(image_block)

    return {
        "type": "tool_result",
        "content": payload,
        "tool_use_id": tool_use_id,
        "is_error": failed,
    }


def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
if result.system:
result_text = f"<system>{result.system}</system>\n{result_text}"
return result_text
14 changes: 14 additions & 0 deletions computer_use_demo/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .base import CLIResult, ToolResult
from .bash import BashTool
from .collection import ToolCollection
from .computer import ComputerTool
from .edit import EditTool

# Public API of the package. Python only honours the lowercase name `__all__`, and
# its entries must be the exported names as strings (not the objects themselves).
__all__ = [
    "BashTool",
    "CLIResult",
    "ComputerTool",
    "EditTool",
    "ToolCollection",
    "ToolResult",
]
69 changes: 69 additions & 0 deletions computer_use_demo/tools/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import Any

from anthropic.types.beta import BetaToolUnionParam


class BaseAnthropicTool(metaclass=ABCMeta):
    """Interface that every Anthropic-defined tool has to implement."""

    @abstractmethod
    def __call__(self, **kwargs) -> Any:
        """Run the tool using the supplied keyword arguments."""

    @abstractmethod
    def to_params(self) -> BetaToolUnionParam:
        """Return the tool's parameter definition; concrete tools must override this."""
        raise NotImplementedError


@dataclass(kw_only=True, frozen=True)
class ToolResult:
"""Represents the result of a tool execution."""

output: str | None = None
error: str | None = None
base64_image: str | None = None
system: str | None = None

def __bool__(self):
return any(getattr(self, field.name) for field in fields(self))

def __add__(self, other: "ToolResult"):
def combine_fields(
field: str | None, other_field: str | None, concatenate: bool = True
):
if field and other_field:
if concatenate:
return field + other_field
raise ValueError("Cannot combine tool results")
return field or other_field

return ToolResult(
output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system),
)

def replace(self, **kwargs):
"""Returns a new ToolResult with the given fields replaced."""
return replace(self, **kwargs)


class CLIResult(ToolResult):
    """ToolResult subclass marking results meant to be rendered as CLI output.

    Adds no fields or behavior of its own.
    """


class ToolFailure(ToolResult):
    """ToolResult subclass marking a result that represents a failure.

    Adds no fields or behavior of its own.
    """


class ToolError(Exception):
    """Raised when a tool encounters an error.

    Attributes:
        message: Description of what went wrong.
    """

    def __init__(self, message):
        # Initialise the base Exception too, so `args` and `str(exc)` carry the
        # message even when it is passed by keyword.
        super().__init__(message)
        self.message = message
Loading

0 comments on commit 60ff2ea

Please sign in to comment.