-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 60ff2ea
Showing
12 changed files
with
1,174 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# macOS-specific files/folders | ||
.DS_Store | ||
.AppleDouble | ||
.LSOverride | ||
Icon? | ||
._* | ||
|
||
# Common Python artifacts | ||
__pycache__/ | ||
*.pyc | ||
*.pyo | ||
*.pyd | ||
.Python | ||
env/ | ||
venv/ | ||
*.egg-info/ | ||
.tox/ | ||
build/ | ||
dist/ | ||
|
||
# IDE/Editor files | ||
.idea/ | ||
*.iml | ||
.vscode/ | ||
|
||
# Notebook checkpoints | ||
.ipynb_checkpoints | ||
|
||
# macOS Trash folder | ||
.Trash/ |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
""" | ||
Agentic sampling loop that calls the Anthropic API and the local implementation of Anthropic-defined computer use tools. | |
""" | ||
|
||
import platform | ||
from collections.abc import Callable | ||
from datetime import datetime | ||
from enum import StrEnum | ||
from typing import Any, cast | ||
|
||
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse | ||
from anthropic.types import ( | ||
ToolResultBlockParam, | ||
) | ||
from anthropic.types.beta import ( | ||
BetaContentBlock, | ||
BetaContentBlockParam, | ||
BetaImageBlockParam, | ||
BetaMessage, | ||
BetaMessageParam, | ||
BetaTextBlockParam, | ||
BetaToolResultBlockParam, | ||
) | ||
|
||
from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult | ||
|
||
# Beta feature identifier for the computer-use tools (value of the API's `betas` entry).
BETA_FLAG = "computer-use-2024-10-22"
|
||
|
||
class APIProvider(StrEnum): | ||
ANTHROPIC = "anthropic" | ||
BEDROCK = "bedrock" | ||
VERTEX = "vertex" | ||
|
||
|
||
# Default model identifier for each provider; each backend spells the same
# model name in its own format.
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
    APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
    APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
    APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
}
|
||
|
||
# This system prompt is optimized for the Docker environment in this repository and | ||
# specific tool combinations enabled. | ||
# We encourage modifying this system prompt to ensure the model has context for the | ||
# environment it is running in, and to provide any additional information that may be | ||
# helpful for the task at hand. | ||
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
* You are utilizing a MacOS computer using {platform.machine()} architecture with internet access.
* You can use the bash tool to execute commands in the terminal.
* To open applications, you can use the `open` command in the bash tool. For example, `open -a Safari` to open the Safari browser.
* When using your bash tool with commands that are expected to output very large quantities of text, redirect the output into a temporary file and use `str_replace_editor` or `grep -n -B <lines before> -A <lines after> <query> <filename>` to inspect the output.
* When viewing a page, it can be helpful to zoom out so that you can see everything on the page. Alternatively, ensure you scroll down to see everything before deciding something isn't available.
* When using your computer function calls, they may take a while to run and send back to you. Where possible and feasible, try to chain multiple of these calls into one function call request.
* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* When using Safari or other applications, if any startup wizards or prompts appear, **IGNORE THEM**. Do not interact with them. Instead, click on the address bar or the area where you can enter commands or URLs, and proceed with your task.
* If the item you are looking at is a PDF, and after taking a single screenshot of the PDF it seems you want to read the entire document, instead of trying to continue to read the PDF from your screenshots and navigation, determine the URL, use `curl` to download the PDF, install and use `pdftotext` (you may need to install it via `brew install poppler`) to convert it to a text file, and then read that text file directly with your `str_replace_editor` tool.
</IMPORTANT>"""
|
||
|
||
async def sampling_loop(
    *,
    model: str,
    provider: APIProvider,
    system_prompt_suffix: str,
    messages: list[BetaMessageParam],
    output_callback: Callable[[BetaContentBlock], None],
    tool_output_callback: Callable[[ToolResult, str], None],
    api_response_callback: Callable[[APIResponse[BetaMessage]], None],
    api_key: str,
    only_n_most_recent_images: int | None = None,
    max_tokens: int = 4096,
):
    """
    Agentic sampling loop for the assistant/tool interaction of computer use.

    Repeatedly sends `messages` to the model, runs every `tool_use` block of the
    reply through the local tool collection, and appends the tool results as a
    new user message. The loop ends on the first reply that requests no tools.

    Args:
        model: Model name passed to the API.
        provider: Which client to construct (Anthropic, Vertex or Bedrock).
        system_prompt_suffix: Extra text appended to SYSTEM_PROMPT (after a
            single space) when non-empty.
        messages: Conversation so far; mutated in place with each assistant
            reply and each batch of tool results.
        output_callback: Called with every content block of each reply.
        tool_output_callback: Called with each tool result and the id of the
            tool_use block that produced it.
        api_response_callback: Called with the raw API response of each request.
        api_key: Key used for the direct Anthropic client only.
        only_n_most_recent_images: When truthy, older tool-result images are
            pruned from `messages` before each request.
        max_tokens: `max_tokens` value sent with each request.

    Returns:
        The same `messages` list, once a reply contains no tool_use blocks.

    Raises:
        ValueError: If `provider` is not one of the supported providers.
    """
    tool_collection = ToolCollection(
        ComputerTool(),
        BashTool(),
        EditTool(),
    )
    system = (
        f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}"
    )

    # The provider cannot change while the loop runs, so build the client once
    # and fail loudly on an unknown provider instead of hitting an unbound name.
    if provider == APIProvider.ANTHROPIC:
        client = Anthropic(api_key=api_key)
    elif provider == APIProvider.VERTEX:
        client = AnthropicVertex()
    elif provider == APIProvider.BEDROCK:
        client = AnthropicBedrock()
    else:
        raise ValueError(f"Unsupported API provider: {provider!r}")

    while True:
        if only_n_most_recent_images:
            _maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images)

        # Call the API.
        # We use raw_response to provide debug information to streamlit. Your
        # implementation may be able to call the SDK directly with:
        # `response = client.messages.create(...)` instead.
        raw_response = client.beta.messages.with_raw_response.create(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            system=system,
            tools=tool_collection.to_params(),
            betas=[BETA_FLAG],
        )

        api_response_callback(cast(APIResponse[BetaMessage], raw_response))

        response = raw_response.parse()

        messages.append(
            {
                "role": "assistant",
                "content": cast(list[BetaContentBlockParam], response.content),
            }
        )

        tool_result_content: list[BetaToolResultBlockParam] = []
        for content_block in cast(list[BetaContentBlock], response.content):
            output_callback(content_block)
            if content_block.type == "tool_use":
                result = await tool_collection.run(
                    name=content_block.name,
                    tool_input=cast(dict[str, Any], content_block.input),
                )
                tool_result_content.append(
                    _make_api_tool_result(result, content_block.id)
                )
                tool_output_callback(result, content_block.id)

        # No tool was requested: the model is done with this turn.
        if not tool_result_content:
            return messages

        messages.append({"content": tool_result_content, "role": "user"})
|
||
|
||
def _maybe_filter_to_n_most_recent_images( | ||
messages: list[BetaMessageParam], | ||
images_to_keep: int, | ||
min_removal_threshold: int = 10, | ||
): | ||
""" | ||
With the assumption that images are screenshots that are of diminishing value as | ||
the conversation progresses, remove all but the final `images_to_keep` tool_result | ||
images in place, with a chunk of min_removal_threshold to reduce the amount we | ||
break the implicit prompt cache. | ||
""" | ||
if images_to_keep is None: | ||
return messages | ||
|
||
tool_result_blocks = cast( | ||
list[ToolResultBlockParam], | ||
[ | ||
item | ||
for message in messages | ||
for item in ( | ||
message["content"] if isinstance(message["content"], list) else [] | ||
) | ||
if isinstance(item, dict) and item.get("type") == "tool_result" | ||
], | ||
) | ||
|
||
total_images = sum( | ||
1 | ||
for tool_result in tool_result_blocks | ||
for content in tool_result.get("content", []) | ||
if isinstance(content, dict) and content.get("type") == "image" | ||
) | ||
|
||
images_to_remove = total_images - images_to_keep | ||
# for better cache behavior, we want to remove in chunks | ||
images_to_remove -= images_to_remove % min_removal_threshold | ||
|
||
for tool_result in tool_result_blocks: | ||
if isinstance(tool_result.get("content"), list): | ||
new_content = [] | ||
for content in tool_result.get("content", []): | ||
if isinstance(content, dict) and content.get("type") == "image": | ||
if images_to_remove > 0: | ||
images_to_remove -= 1 | ||
continue | ||
new_content.append(content) | ||
tool_result["content"] = new_content | ||
|
||
|
||
def _make_api_tool_result(
    result: ToolResult, tool_use_id: str
) -> BetaToolResultBlockParam:
    """Convert an agent ToolResult to an API ToolResultBlockParam.

    An error result becomes a plain string payload flagged with `is_error`;
    otherwise the payload is a list holding an optional text block followed by
    an optional base64 PNG image block.
    """
    failed = bool(result.error)
    payload: list[BetaTextBlockParam | BetaImageBlockParam] | str
    if failed:
        payload = _maybe_prepend_system_tool_result(result, result.error)
    else:
        blocks = []
        if result.output:
            text_block = {
                "type": "text",
                "text": _maybe_prepend_system_tool_result(result, result.output),
            }
            blocks.append(text_block)
        if result.base64_image:
            image_source = {
                "type": "base64",
                "media_type": "image/png",
                "data": result.base64_image,
            }
            blocks.append({"type": "image", "source": image_source})
        payload = blocks
    return {
        "type": "tool_result",
        "content": payload,
        "tool_use_id": tool_use_id,
        "is_error": failed,
    }
|
||
|
||
def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str): | ||
if result.system: | ||
result_text = f"<system>{result.system}</system>\n{result_text}" | ||
return result_text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .base import CLIResult, ToolResult | ||
from .bash import BashTool | ||
from .collection import ToolCollection | ||
from .computer import ComputerTool | ||
from .edit import EditTool | ||
|
||
# Public API of the package. Python only recognises the lowercase name
# `__all__`, and its entries must be the exported names as strings.
__all__ = [
    "BashTool",
    "CLIResult",
    "ComputerTool",
    "EditTool",
    "ToolCollection",
    "ToolResult",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from abc import ABCMeta, abstractmethod | ||
from dataclasses import dataclass, fields, replace | ||
from typing import Any | ||
|
||
from anthropic.types.beta import BetaToolUnionParam | ||
|
||
|
||
class BaseAnthropicTool(metaclass=ABCMeta):
    """Abstract base class for Anthropic-defined tools.

    Concrete tools must implement both `__call__` and `to_params`.
    """

    @abstractmethod
    def __call__(self, **kwargs) -> Any:
        """Executes the tool with the given arguments."""
        ...

    @abstractmethod
    def to_params(self) -> BetaToolUnionParam:
        """Return this tool's parameter description for the API."""
        raise NotImplementedError
|
||
|
||
@dataclass(kw_only=True, frozen=True) | ||
class ToolResult: | ||
"""Represents the result of a tool execution.""" | ||
|
||
output: str | None = None | ||
error: str | None = None | ||
base64_image: str | None = None | ||
system: str | None = None | ||
|
||
def __bool__(self): | ||
return any(getattr(self, field.name) for field in fields(self)) | ||
|
||
def __add__(self, other: "ToolResult"): | ||
def combine_fields( | ||
field: str | None, other_field: str | None, concatenate: bool = True | ||
): | ||
if field and other_field: | ||
if concatenate: | ||
return field + other_field | ||
raise ValueError("Cannot combine tool results") | ||
return field or other_field | ||
|
||
return ToolResult( | ||
output=combine_fields(self.output, other.output), | ||
error=combine_fields(self.error, other.error), | ||
base64_image=combine_fields(self.base64_image, other.base64_image, False), | ||
system=combine_fields(self.system, other.system), | ||
) | ||
|
||
def replace(self, **kwargs): | ||
"""Returns a new ToolResult with the given fields replaced.""" | ||
return replace(self, **kwargs) | ||
|
||
|
||
class CLIResult(ToolResult):
    """ToolResult variant that can be rendered as CLI output."""
|
||
|
||
class ToolFailure(ToolResult):
    """ToolResult variant that represents a failure."""
|
||
|
||
class ToolError(Exception):
    """Raised when a tool encounters an error.

    Attributes are documented inline: `message` holds the text given at
    construction time.
    """

    def __init__(self, message):
        # Forward to Exception so `args` and `str(exc)` carry the message even
        # when it is supplied by keyword.
        super().__init__(message)
        # Human-readable description of what went wrong.
        self.message = message
Oops, something went wrong.