Compare commits

...

12 Commits

13 changed files with 1975 additions and 4 deletions

View File

@ -19,6 +19,8 @@ Metric = Callable[[Any, Any], float]
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
MAX_RETRIES = 5 # Maximum retries for environment setup
class DesktopEnv(gym.Env):
"""
@ -104,7 +106,7 @@ class DesktopEnv(gym.Env):
# mode: human or machine
self.instruction = None
assert action_space in ["computer_13", "pyautogui"]
assert action_space in ["computer_13", "pyautogui", "claude_computer_use"]
self.action_space = action_space # todo: refactor it to the ActType
# episodic stuffs, like counters, will be updated or reset
@ -307,7 +309,7 @@ class DesktopEnv(gym.Env):
reward = 0 # todo: Define reward calculation for each example
done = False # todo: Define episode termination condition for each example
info = {}
logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
# handle the special actions
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
if action == 'WAIT':
@ -322,12 +324,15 @@ class DesktopEnv(gym.Env):
if self.action_space == "computer_13":
# the set of all possible actions defined in the action representation
self.controller.execute_action(action)
elif self.action_space == "pyautogui":
elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
if action in ['WAIT', 'FAIL', 'DONE']:
self.controller.execute_action(action)
else:
# the set of all possible python commands insides `pyautogui`
self.controller.execute_python_command(action)
if type(action) == str:
self.controller.execute_python_command(action)
elif type(action) == dict:
self.controller.execute_python_command(action['command'])
time.sleep(pause)
observation = self._get_obs()
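For reference, a hedged sketch of the dict this new branch handles: in the `claude_computer_use` action space the agent (see `mm_agents/anthropic/main.py` below) emits action dicts whose `command` field already holds the rescaled pyautogui code, and `step()` forwards that string to `execute_python_command`. The id and coordinates below are made up:

```python
# Hypothetical action dict produced by AnthropicAgent.predict() for a left click.
action = {
    "name": "computer",
    "input": {"action": "left_click", "coordinate": [100, 200]},   # model-space (1280x720) coords
    "id": "toolu_example_01",                                      # made-up tool_use id
    "action_type": "tool_use",
    "command": "pyautogui.click(150, 300)\n",                      # rescaled to the real screen
}
# In DesktopEnv.step(), the dict branch above runs:
#     self.controller.execute_python_command(action["command"])
```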

View File

@ -0,0 +1,18 @@
# Anthropic Agent Integration
> Notice: Because the Anthropic API only accepts images whose long edge is under 1568 pixels and that stay under ~1,600 tokens, we resize the screenshot to 1280x720.
## Setup
To run with the Anthropic API, you need to set up your environment with the necessary API keys and configurations. Follow these steps:
1. **Install Dependencies**: Ensure you have the required Python packages installed. You can do this by running:
```bash
pip install anthropic
```
2. **Set Environment Variables**: Set the required credentials as environment variables in `.env`:
For AWS Bedrock:
```.env
AWS_ACCESS_KEY_ID=your_access_key_id
AWS_SECRET_ACCESS_KEY=your_secret_access_key
```
For the Anthropic API, set the provider (`APIProvider`) to `anthropic` and provide the API key:
```.env
ANTHROPIC_API_KEY=your_anthropic_api_key
```
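3. **Run the Evaluation**: a rough example invocation (the flags below come from `run_multienv_claude.py` added in this PR; adjust the model and provider to your setup, and keep `.env` in the working directory so `python-dotenv` can load the credentials):
```bash
python run_multienv_claude.py \
    --model claude-4-sonnet-20250514 \
    --action_space claude_computer_use \
    --observation_type screenshot \
    --num_envs 1 \
    --max_steps 15
```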

View File

@ -0,0 +1,23 @@
"""
Anthropic agent implementation
"""
from .main import AnthropicAgent
from .tools import (
BashTool,
CLIResult,
ComputerTool,
EditTool,
ToolCollection,
ToolResult
)
__all__ = [
'AnthropicAgent',
'BashTool',
'CLIResult',
'ComputerTool',
'EditTool',
'ToolCollection',
'ToolResult'
]

mm_agents/anthropic/main.py Normal file
View File

@ -0,0 +1,442 @@
import base64
import os
import time
from typing import Any, cast, Optional, Dict
from PIL import Image
import io
from anthropic import (
Anthropic,
AnthropicBedrock,
AnthropicVertex,
APIError,
APIResponseValidationError,
APIStatusError,
)
from anthropic.types.beta import (
BetaMessageParam,
BetaTextBlockParam,
)
from .utils import COMPUTER_USE_BETA_FLAG, PROMPT_CACHING_BETA_FLAG, SYSTEM_PROMPT, SYSTEM_PROMPT_WINDOWS, APIProvider, PROVIDER_TO_DEFAULT_MODEL_NAME
from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to_n_most_recent_images
import logging
logger = logging.getLogger("desktopenv.agent")
class AnthropicAgent:
def __init__(self,
platform: str = "Ubuntu",
model: str = "claude-3-5-sonnet-20241022",
provider: APIProvider = APIProvider.BEDROCK,
max_tokens: int = 4096,
api_key: str = os.environ.get("ANTHROPIC_API_KEY", None),
system_prompt_suffix: str = "",
only_n_most_recent_images: Optional[int] = 10,
action_space: str = "claude_computer_use",
screen_size: tuple[int, int] = (1920, 1080),
*args, **kwargs
):
self.platform = platform
self.action_space = action_space
self.logger = logger
self.class_name = self.__class__.__name__
self.model_name = model
self.provider = provider
self.max_tokens = max_tokens
self.api_key = api_key
self.system_prompt_suffix = system_prompt_suffix
self.only_n_most_recent_images = only_n_most_recent_images
self.messages: list[BetaMessageParam] = []
self.screen_size = screen_size
self.resize_factor = (
screen_size[0] / 1280, # Assuming 1280 is the base width
screen_size[1] / 720 # Assuming 720 is the base height
)
def add_tool_result(self, tool_call_id: str, result: str, screenshot: bytes = None):
"""Add tool result to message history"""
tool_result_content = [
{
"type": "tool_result",
"tool_use_id": tool_call_id,
"content": [{"type": "text", "text": result}]
}
]
# Add screenshot if provided
if screenshot is not None:
screenshot_base64 = base64.b64encode(screenshot).decode('utf-8')
tool_result_content[0]["content"].append({
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": screenshot_base64
}
})
self.messages.append({
"role": "user",
"content": tool_result_content
})
def parse_actions_from_tool_call(self, tool_call: Dict) -> str:
result = ""
function_args = (
tool_call["input"]
)
action = function_args.get("action")
if not action:
action = tool_call["name"]  # tool_call is a dict content block from the API response, not an object
action_conversion = {
"left click": "click",
"right click": "right_click"
}
action = action_conversion.get(action, action)
text = function_args.get("text")
coordinate = function_args.get("coordinate")
scroll_direction = function_args.get("scroll_direction")
scroll_amount = function_args.get("scroll_amount")
duration = function_args.get("duration")
# resize coordinates if resize_factor is set
if coordinate and self.resize_factor:
coordinate = (
int(coordinate[0] * self.resize_factor[0]),
int(coordinate[1] * self.resize_factor[1])
)
# Handle mouse move and drag actions
if action in ("mouse_move", "left_click_drag"):
if coordinate is None:
raise ValueError(f"coordinate is required for {action}")
if text is not None:
raise ValueError(f"text is not accepted for {action}")
if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
raise ValueError(f"{coordinate} must be a tuple of length 2")
if not all(isinstance(i, int) for i in coordinate):
raise ValueError(f"{coordinate} must be a tuple of ints")
x, y = coordinate[0], coordinate[1]
if action == "mouse_move":
result += (
f"pyautogui.moveTo({x}, {y}, duration={duration or 0.5})\n"
)
expected_outcome = f"Mouse moved to ({x},{y})."
elif action == "left_click_drag":
result += (
f"pyautogui.dragTo({x}, {y}, duration={duration or 0.5})\n"
)
expected_outcome = f"Cursor dragged to ({x},{y})."
# Handle keyboard actions
elif action in ("key", "type"):
if text is None:
raise ValueError(f"text is required for {action}")
if coordinate is not None:
raise ValueError(f"coordinate is not accepted for {action}")
if not isinstance(text, str):
raise ValueError(f"{text} must be a string")
if action == "key":
key_conversion = {
"page_down": "pagedown",
"page_up": "pageup",
"super_l": "win",
"super": "command",
"escape": "esc"
}
keys = text.split('+')
for key in keys:
key = key.strip().lower()
key = key_conversion.get(key, key)
result += (f"pyautogui.keyDown('{key}')\n")
for key in reversed(keys):
key = key.strip().lower()
key = key_conversion.get(key, key)
result += (f"pyautogui.keyUp('{key}')\n")
expected_outcome = f"Key {key} pressed."
elif action == "type":
result += (
f"pyautogui.typewrite(\"\"\"{text}\"\"\", interval=0.01)\n"
)
expected_outcome = f"Text {text} written."
# Handle scroll actions
elif action == "scroll":
if coordinate is None:
if scroll_direction in ("up", "down"):
result += (
f"pyautogui.scroll({scroll_amount if scroll_direction == 'up' else -scroll_amount})\n"
)
elif scroll_direction in ("left", "right"):
result += (
f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount})\n"
)
else:
if scroll_direction in ("up", "down"):
x, y = coordinate[0], coordinate[1]
result += (
f"pyautogui.scroll({scroll_amount if scroll_direction == 'up' else -scroll_amount}, {x}, {y})\n"
)
elif scroll_direction in ("left", "right"):
x, y = coordinate[0], coordinate[1]
result += (
f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount}, {x}, {y})\n"
)
expected_outcome = "Scroll action finished"
# Handle click actions
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press"):
if coordinate is not None:
x, y = coordinate
if action == "left_click":
result += (f"pyautogui.click({x}, {y})\n")
elif action == "right_click":
result += (f"pyautogui.rightClick({x}, {y})\n")
elif action == "double_click":
result += (f"pyautogui.doubleClick({x}, {y})\n")
elif action == "middle_click":
result += (f"pyautogui.middleClick({x}, {y})\n")
elif action == "left_press":
result += (f"pyautogui.mouseDown({x}, {y})\n")
result += ("time.sleep(1)\n")
result += (f"pyautogui.mouseUp({x}, {y})\n")
else:
if action == "left_click":
result += ("pyautogui.click()\n")
elif action == "right_click":
result += ("pyautogui.rightClick()\n")
elif action == "double_click":
result += ("pyautogui.doubleClick()\n")
elif action == "middle_click":
result += ("pyautogui.middleClick()\n")
elif action == "left_press":
result += ("pyautogui.mouseDown()\n")
result += ("time.sleep(1)\n")
result += ("pyautogui.mouseUp()\n")
expected_outcome = "Click action finished"
elif action == "wait":
result += "pyautogui.sleep(0.5)\n"
expected_outcome = "Wait for 0.5 seconds"
elif action == "fail":
result += "FAIL"
expected_outcome = "Finished"
elif action == "done":
result += "DONE"
expected_outcome = "Finished"
elif action == "call_user":
result += "CALL_USER"
expected_outcome = "Call user"
elif action == "screenshot":
result += "pyautogui.sleep(0.1)\n"
expected_outcome = "Screenshot taken"
else:
raise ValueError(f"Invalid action: {action}")
return result
def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
system = BetaTextBlockParam(
type="text",
text=f"{SYSTEM_PROMPT_WINDOWS if self.platform == 'Windows' else SYSTEM_PROMPT}{' ' + self.system_prompt_suffix if self.system_prompt_suffix else ''}"
)
# resize screenshot if resize_factor is set
if obs and "screenshot" in obs:
# Convert bytes to PIL Image
screenshot_bytes = obs["screenshot"]
screenshot_image = Image.open(io.BytesIO(screenshot_bytes))
# Calculate new size based on resize factor
new_width, new_height = 1280, 720
# Resize the image
resized_image = screenshot_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Convert back to bytes
output_buffer = io.BytesIO()
resized_image.save(output_buffer, format='PNG')
obs["screenshot"] = output_buffer.getvalue()
if not self.messages:
init_screenshot = obs
init_screenshot_base64 = base64.b64encode(init_screenshot["screenshot"]).decode('utf-8')
self.messages.append({
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": init_screenshot_base64,
},
},
{"type": "text", "text": task_instruction},
]
})
if self.messages and "tool_use" in [content_block["type"] for content_block in self.messages[-1]["content"]]:
self.add_tool_result(
self.messages[-1]["content"][-1]["id"],
f"Success",
screenshot=obs.get("screenshot") if obs else None
)
enable_prompt_caching = False
betas = ["computer-use-2025-01-24"]
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
betas = ["computer-use-2025-01-24"]
elif self.model_name == "claude-3-5-sonnet-20241022":
betas = [COMPUTER_USE_BETA_FLAG]
image_truncation_threshold = 10
if self.provider == APIProvider.ANTHROPIC:
client = Anthropic(api_key=self.api_key, max_retries=4)
enable_prompt_caching = True
elif self.provider == APIProvider.VERTEX:
client = AnthropicVertex()
elif self.provider == APIProvider.BEDROCK:
client = AnthropicBedrock(
# Authenticate by either providing the keys below or use the default AWS credential providers, such as
# using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
aws_access_key=os.getenv('AWS_ACCESS_KEY_ID'),
aws_secret_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
# aws_region changes the aws region to which the request is made. By default, we read AWS_REGION,
# and if that's not present, we default to us-east-1. Note that we do not read ~/.aws/config for the region.
aws_region=os.getenv('AWS_DEFAULT_REGION'),
)
if enable_prompt_caching:
betas.append(PROMPT_CACHING_BETA_FLAG)
_inject_prompt_caching(self.messages)
image_truncation_threshold = 50
system["cache_control"] = {"type": "ephemeral"}
if self.only_n_most_recent_images:
_maybe_filter_to_n_most_recent_images(
self.messages,
self.only_n_most_recent_images,
min_removal_threshold=image_truncation_threshold,
)
try:
if self.model_name == "claude-3-5-sonnet-20241022":
tools = [
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
# {'type': 'bash_20241022', 'name': 'bash'},
# {'name': 'str_replace_editor', 'type': 'text_editor_20241022'}
] if self.platform == 'Ubuntu' else [
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
]
elif self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
tools = [
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
# {'type': 'bash_20250124', 'name': 'bash'},
# {'name': 'str_replace_editor', 'type': 'text_editor_20250124'}
] if self.platform == 'Ubuntu' else [
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
]
extra_body = {
"thinking": {"type": "enabled", "budget_tokens": 1024}
}
response = None
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
response = client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
system=[system],
tools=tools,
betas=betas,
extra_body=extra_body
)
elif self.model_name == "claude-3-5-sonnet-20241022":
response = client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
system=[system],
tools=tools,
betas=betas,
)
except (APIError, APIStatusError, APIResponseValidationError) as e:
self.logger.exception(f"Anthropic API error: {str(e)}")
try:
self.logger.warning("Retrying with backup API key...")
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
response = backup_client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
system=[system],
tools=tools,
betas=betas,
extra_body=extra_body
)
elif self.model_name == "claude-3-5-sonnet-20241022":
response = backup_client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
system=[system],
tools=tools,
betas=betas,
)
self.logger.info("Successfully used backup API key")
except Exception as backup_e:
self.logger.exception(f"Backup API call also failed: {str(backup_e)}")
return None, None
except Exception as e:
self.logger.exception(f"Error in Anthropic API: {str(e)}")
return None, None
response_params = _response_to_params(response)
logger.info(f"Received response params: {response_params}")
# Store response in message history
self.messages.append({
"role": "assistant",
"content": response_params
})
actions: list[Any] = []
reasonings: list[str] = []
for content_block in response_params:
if content_block["type"] == "tool_use":
actions.append({
"name": content_block["name"],
"input": cast(dict[str, Any], content_block["input"]),
"id": content_block["id"],
"action_type": content_block.get("type"),
"command": self.parse_actions_from_tool_call(content_block)
})
elif content_block["type"] == "text":
reasonings.append(content_block["text"])
if isinstance(reasonings, list) and len(reasonings) > 0:
reasonings = reasonings[0]
else:
reasonings = ""
logger.info(f"Received actions: {actions}")
logger.info(f"Received reasonings: {reasonings}")
if len(actions) == 0:
actions = ["DONE"]
return reasonings, actions
def reset(self, *args, **kwargs):
"""
Reset the agent's state.
"""
self.messages = []
self.logger.info(f"{self.class_name} reset.")

View File

@ -0,0 +1,14 @@
from .base import CLIResult, ToolResult
from .bash import BashTool
from .collection import ToolCollection
from .computer import ComputerTool
from .edit import EditTool
__all__ = [
"BashTool",
"CLIResult",
"ComputerTool",
"EditTool",
"ToolCollection",
"ToolResult",
]

View File

@ -0,0 +1,69 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import Any, Optional
from anthropic.types.beta import BetaToolUnionParam
class BaseAnthropicTool(metaclass=ABCMeta):
"""Abstract base class for Anthropic-defined tools."""
@abstractmethod
def __call__(self, **kwargs) -> Any:
"""Executes the tool with the given arguments."""
...
@abstractmethod
def to_params(
self,
) -> BetaToolUnionParam:
raise NotImplementedError
@dataclass(frozen=True) #kw_only=True,
class ToolResult:
"""Represents the result of a tool execution."""
output: Optional[str] = None
error: Optional[str] = None
base64_image: Optional[str] = None
system: Optional[str] = None
def __bool__(self):
return any(getattr(self, field.name) for field in fields(self))
def __add__(self, other: "ToolResult"):
def combine_fields(
field: Optional[str], other_field: Optional[str], concatenate: bool = True
):
if field and other_field:
if concatenate:
return field + other_field
raise ValueError("Cannot combine tool results")
return field or other_field
return ToolResult(
output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system),
)
def replace(self, **kwargs):
"""Returns a new ToolResult with the given fields replaced."""
return replace(self, **kwargs)
class CLIResult(ToolResult):
"""A ToolResult that can be rendered as a CLI output."""
class ToolFailure(ToolResult):
"""A ToolResult that represents a failure."""
class ToolError(Exception):
"""Raised when a tool encounters an error."""
def __init__(self, message):
self.message = message
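A short sketch of how `ToolResult` composes (values are made up): textual fields concatenate, an all-empty result is falsy, and two images refuse to combine:

```python
# Hypothetical results from two consecutive tool invocations.
a = ToolResult(output="listed files\n")
b = ToolResult(output="done", base64_image="iVBORw0KGgo=")   # fake, truncated PNG data

combined = a + b
assert combined.output == "listed files\ndone"
assert combined.base64_image == "iVBORw0KGgo="
assert not ToolResult()                                      # all-None result is falsy
# ToolResult(base64_image="x") + ToolResult(base64_image="y") raises ValueError
```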

View File

@ -0,0 +1,144 @@
import asyncio
import os
from typing import ClassVar, Literal, Optional
from anthropic.types.beta import BetaToolBash20241022Param
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
class _BashSession:
"""A session of a bash shell."""
_started: bool
_process: asyncio.subprocess.Process
command: str = "/bin/bash"
_output_delay: float = 0.2 # seconds
_timeout: float = 120.0 # seconds
_sentinel: str = "<<exit>>"
def __init__(self):
self._started = False
self._timed_out = False
async def start(self):
if self._started:
return
self._process = await asyncio.create_subprocess_shell(
self.command,
preexec_fn=os.setsid,
shell=True,
bufsize=0,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
self._started = True
def stop(self):
"""Terminate the bash shell."""
if not self._started:
raise ToolError("Session has not started.")
if self._process.returncode is not None:
return
self._process.terminate()
async def run(self, command: str):
"""Execute a command in the bash shell."""
if not self._started:
raise ToolError("Session has not started.")
if self._process.returncode is not None:
return ToolResult(
system="tool must be restarted",
error=f"bash has exited with returncode {self._process.returncode}",
)
if self._timed_out:
raise ToolError(
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
)
# we know these are not None because we created the process with PIPEs
assert self._process.stdin
assert self._process.stdout
assert self._process.stderr
# send command to the process
self._process.stdin.write(
command.encode() + f"; echo '{self._sentinel}'\n".encode()
)
await self._process.stdin.drain()
# read output from the process, until the sentinel is found
try:
async with asyncio.timeout(self._timeout):
while True:
await asyncio.sleep(self._output_delay)
# if we read directly from stdout/stderr, it will wait forever for
# EOF. use the StreamReader buffer directly instead.
output = self._process.stdout._buffer.decode() # pyright: ignore[reportAttributeAccessIssue]
if self._sentinel in output:
# strip the sentinel and break
output = output[: output.index(self._sentinel)]
break
except asyncio.TimeoutError:
self._timed_out = True
raise ToolError(
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
) from None
if output.endswith("\n"):
output = output[:-1]
error = self._process.stderr._buffer.decode() # pyright: ignore[reportAttributeAccessIssue]
if error.endswith("\n"):
error = error[:-1]
# clear the buffers so that the next output can be read correctly
self._process.stdout._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
self._process.stderr._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
return CLIResult(output=output, error=error)
class BashTool(BaseAnthropicTool):
"""
A tool that allows the agent to run bash commands.
The tool parameters are defined by Anthropic and are not editable.
"""
_session: Optional[_BashSession]
name: ClassVar[Literal["bash"]] = "bash"
api_type: ClassVar[Literal["bash_20241022"]] = "bash_20241022"
def __init__(self):
self._session = None
super().__init__()
async def __call__(
self, command: Optional[str] = None, restart: bool = False, **kwargs
):
if restart:
if self._session:
self._session.stop()
self._session = _BashSession()
await self._session.start()
return ToolResult(system="tool has been restarted.")
if self._session is None:
self._session = _BashSession()
await self._session.start()
if command is not None:
return await self._session.run(command)
raise ToolError("no command provided.")
def to_params(self) -> BetaToolBash20241022Param:
return {
"type": self.api_type,
"name": self.name,
}
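For orientation, a minimal (hypothetical) driver for `BashTool`; every call is awaited, and the first call lazily starts the shell session:

```python
import asyncio

async def demo():
    bash = BashTool()
    result = await bash(command="echo hello")   # starts the bash session on first use
    print(result.output)                        # -> hello
    await bash(restart=True)                    # kills and restarts the session

asyncio.run(demo())
```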

View File

@ -0,0 +1,34 @@
"""Collection classes for managing multiple tools."""
from typing import Any
from anthropic.types.beta import BetaToolUnionParam
from .base import (
BaseAnthropicTool,
ToolError,
ToolFailure,
ToolResult,
)
class ToolCollection:
"""A collection of anthropic-defined tools."""
def __init__(self, *tools: BaseAnthropicTool):
self.tools = tools
self.tool_map = {tool.to_params()["name"]: tool for tool in tools}
def to_params(
self,
) -> list[BetaToolUnionParam]:
return [tool.to_params() for tool in self.tools]
async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
tool = self.tool_map.get(name)
if not tool:
return ToolFailure(error=f"Tool {name} is invalid")
try:
return await tool(**tool_input)
except ToolError as e:
return ToolFailure(error=e.message)
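A small sketch of dispatching through the collection; `name` and `tool_input` mirror the fields of an Anthropic tool_use block (the command below is made up, and `ComputerTool` requires `WIDTH`/`HEIGHT` in the environment, see `computer.py` below):

```python
import asyncio
import os

async def demo():
    os.environ.setdefault("WIDTH", "1280")      # required by ComputerTool (see computer.py)
    os.environ.setdefault("HEIGHT", "720")
    tools = ToolCollection(ComputerTool(), BashTool())
    print(tools.to_params())                    # tool schemas passed along with the API request
    result = await tools.run(name="bash", tool_input={"command": "uname -a"})
    print(result.output or result.error)

asyncio.run(demo())
```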

View File

@ -0,0 +1,260 @@
import asyncio
import base64
import os
import shlex
import shutil
from enum import Enum
from pathlib import Path
from typing import Literal, TypedDict, Optional, Tuple
from uuid import uuid4
from anthropic.types.beta import BetaToolComputerUse20241022Param
from .base import BaseAnthropicTool, ToolError, ToolResult
from .run import run
OUTPUT_DIR = "/tmp/outputs"
TYPING_DELAY_MS = 12
TYPING_GROUP_SIZE = 50
Action = Literal[
"key",
"type",
"mouse_move",
"left_click",
"left_click_drag",
"right_click",
"middle_click",
"double_click",
"screenshot",
"cursor_position",
]
class Resolution(TypedDict):
width: int
height: int
# sizes above XGA/WXGA are not recommended (see README.md)
# scale down to one of these targets if ComputerTool._scaling_enabled is set
MAX_SCALING_TARGETS: dict[str, Resolution] = {
"XGA": Resolution(width=1024, height=768), # 4:3
"WXGA": Resolution(width=1280, height=800), # 16:10
"FWXGA": Resolution(width=1366, height=768), # ~16:9
}
class ScalingSource(Enum):
COMPUTER = "computer"
API = "api"
class ComputerToolOptions(TypedDict):
display_height_px: int
display_width_px: int
display_number: Optional[int]
def chunks(s: str, chunk_size: int) -> list[str]:
return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)]
class ComputerTool(BaseAnthropicTool):
"""
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer.
The tool parameters are defined by Anthropic and are not editable.
"""
name: Literal["computer"] = "computer"
api_type: Literal["computer_20241022"] = "computer_20241022"
width: int
height: int
display_num: Optional[int]
_screenshot_delay = 2.0
_scaling_enabled = True
@property
def options(self) -> ComputerToolOptions:
width, height = self.scale_coordinates(
ScalingSource.COMPUTER, self.width, self.height
)
return {
"display_width_px": width,
"display_height_px": height,
"display_number": self.display_num,
}
def to_params(self) -> BetaToolComputerUse20241022Param:
return {"name": self.name, "type": self.api_type, **self.options}
def __init__(self):
super().__init__()
self.width = int(os.getenv("WIDTH") or 0)
self.height = int(os.getenv("HEIGHT") or 0)
assert self.width and self.height, "WIDTH, HEIGHT must be set"
if (display_num := os.getenv("DISPLAY_NUM")) is not None:
self.display_num = int(display_num)
self._display_prefix = f"DISPLAY=:{self.display_num} "
else:
self.display_num = None
self._display_prefix = ""
self.xdotool = f"{self._display_prefix}xdotool"
async def __call__(
self,
*,
action: Action,
text: Optional[str] = None,
coordinate: Optional[Tuple[int, int]] = None,
**kwargs,
):
if action in ("mouse_move", "left_click_drag"):
if coordinate is None:
raise ToolError(f"coordinate is required for {action}")
if text is not None:
raise ToolError(f"text is not accepted for {action}")
if not isinstance(coordinate, list) or len(coordinate) != 2:
raise ToolError(f"{coordinate} must be a tuple of length 2")
if not all(isinstance(i, int) and i >= 0 for i in coordinate):
raise ToolError(f"{coordinate} must be a tuple of non-negative ints")
x, y = self.scale_coordinates(
ScalingSource.API, coordinate[0], coordinate[1]
)
if action == "mouse_move":
return await self.shell(f"{self.xdotool} mousemove --sync {x} {y}")
elif action == "left_click_drag":
return await self.shell(
f"{self.xdotool} mousedown 1 mousemove --sync {x} {y} mouseup 1"
)
if action in ("key", "type"):
if text is None:
raise ToolError(f"text is required for {action}")
if coordinate is not None:
raise ToolError(f"coordinate is not accepted for {action}")
if not isinstance(text, str):
raise ToolError(f"{text} must be a string")
if action == "key":
return await self.shell(f"{self.xdotool} key -- {text}")
elif action == "type":
results: list[ToolResult] = []
for chunk in chunks(text, TYPING_GROUP_SIZE):
cmd = f"{self.xdotool} type --delay {TYPING_DELAY_MS} -- {shlex.quote(chunk)}"
results.append(await self.shell(cmd, take_screenshot=False))
screenshot_base64 = (await self.screenshot()).base64_image
return ToolResult(
output="".join(result.output or "" for result in results),
error="".join(result.error or "" for result in results),
base64_image=screenshot_base64,
)
if action in (
"left_click",
"right_click",
"double_click",
"middle_click",
"screenshot",
"cursor_position",
):
if text is not None:
raise ToolError(f"text is not accepted for {action}")
if coordinate is not None:
raise ToolError(f"coordinate is not accepted for {action}")
if action == "screenshot":
return await self.screenshot()
elif action == "cursor_position":
result = await self.shell(
f"{self.xdotool} getmouselocation --shell",
take_screenshot=False,
)
output = result.output or ""
x, y = self.scale_coordinates(
ScalingSource.COMPUTER,
int(output.split("X=")[1].split("\n")[0]),
int(output.split("Y=")[1].split("\n")[0]),
)
return result.replace(output=f"X={x},Y={y}")
else:
click_arg = {
"left_click": "1",
"right_click": "3",
"middle_click": "2",
"double_click": "--repeat 2 --delay 500 1",
}[action]
return await self.shell(f"{self.xdotool} click {click_arg}")
raise ToolError(f"Invalid action: {action}")
async def screenshot(self):
"""Take a screenshot of the current screen and return the base64 encoded image."""
output_dir = Path(OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"screenshot_{uuid4().hex}.png"
# Try gnome-screenshot first
if shutil.which("gnome-screenshot"):
screenshot_cmd = f"{self._display_prefix}gnome-screenshot -f {path} -p"
else:
# Fall back to scrot if gnome-screenshot isn't available
screenshot_cmd = f"{self._display_prefix}scrot -p {path}"
result = await self.shell(screenshot_cmd, take_screenshot=False)
if self._scaling_enabled:
x, y = self.scale_coordinates(
ScalingSource.COMPUTER, self.width, self.height
)
await self.shell(
f"convert {path} -resize {x}x{y}! {path}", take_screenshot=False
)
if path.exists():
return result.replace(
base64_image=base64.b64encode(path.read_bytes()).decode()
)
raise ToolError(f"Failed to take screenshot: {result.error}")
async def shell(self, command: str, take_screenshot=True) -> ToolResult:
"""Run a shell command and return the output, error, and optionally a screenshot."""
_, stdout, stderr = await run(command)
base64_image = None
if take_screenshot:
# delay to let things settle before taking a screenshot
await asyncio.sleep(self._screenshot_delay)
base64_image = (await self.screenshot()).base64_image
return ToolResult(output=stdout, error=stderr, base64_image=base64_image)
def scale_coordinates(self, source: ScalingSource, x: int, y: int):
"""Scale coordinates to a target maximum resolution."""
if not self._scaling_enabled:
return x, y
ratio = self.width / self.height
target_dimension = None
for dimension in MAX_SCALING_TARGETS.values():
# allow some error in the aspect ratio - not all ratios are exactly 16:9
if abs(dimension["width"] / dimension["height"] - ratio) < 0.02:
if dimension["width"] < self.width:
target_dimension = dimension
break
if target_dimension is None:
return x, y
# should be less than 1
x_scaling_factor = target_dimension["width"] / self.width
y_scaling_factor = target_dimension["height"] / self.height
if source == ScalingSource.API:
if x > self.width or y > self.height:
raise ToolError(f"Coordinates {x}, {y} are out of bounds")
# scale up
return round(x / x_scaling_factor), round(y / y_scaling_factor)
# scale down
return round(x * x_scaling_factor), round(y * y_scaling_factor)
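As a concrete (hypothetical) example of the scaling maths: a 1920x1080 display has a ~16:9 ratio that matches the FWXGA target (1366x768), so screenshots are scaled down to 1366x768 and API coordinates are scaled back up by the inverse factor:

```python
import os

# Hypothetical 1920x1080 display; WIDTH/HEIGHT are read in ComputerTool.__init__.
os.environ["WIDTH"], os.environ["HEIGHT"] = "1920", "1080"
tool = ComputerTool()

print(tool.scale_coordinates(ScalingSource.COMPUTER, 1920, 1080))   # -> (1366, 768)
print(tool.scale_coordinates(ScalingSource.API, 683, 384))          # -> (960, 540)
```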

View File

@ -0,0 +1,290 @@
from collections import defaultdict
from pathlib import Path
from typing import Literal, get_args, Optional, List
from anthropic.types.beta import BetaToolTextEditor20241022Param
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
from .run import maybe_truncate, run
Command = Literal[
"view",
"create",
"str_replace",
"insert",
"undo_edit",
]
SNIPPET_LINES: int = 4
class EditTool(BaseAnthropicTool):
"""
A filesystem editor tool that allows the agent to view, create, and edit files.
The tool parameters are defined by Anthropic and are not editable.
"""
api_type: Literal["text_editor_20241022"] = "text_editor_20241022"
name: Literal["str_replace_editor"] = "str_replace_editor"
_file_history: dict[Path, list[str]]
def __init__(self):
self._file_history = defaultdict(list)
super().__init__()
def to_params(self) -> BetaToolTextEditor20241022Param:
return {
"name": self.name,
"type": self.api_type,
}
async def __call__(
self,
*,
command: Command,
path: str,
file_text: Optional[str] = None,
view_range: Optional[list[int]] = None,
old_str: Optional[str] = None,
new_str: Optional[str] = None,
insert_line: Optional[int] = None,
**kwargs,
):
_path = Path(path)
self.validate_path(command, _path)
if command == "view":
return await self.view(_path, view_range)
elif command == "create":
if file_text is None:
raise ToolError("Parameter `file_text` is required for command: create")
self.write_file(_path, file_text)
self._file_history[_path].append(file_text)
return ToolResult(output=f"File created successfully at: {_path}")
elif command == "str_replace":
if old_str is None:
raise ToolError(
"Parameter `old_str` is required for command: str_replace"
)
return self.str_replace(_path, old_str, new_str)
elif command == "insert":
if insert_line is None:
raise ToolError(
"Parameter `insert_line` is required for command: insert"
)
if new_str is None:
raise ToolError("Parameter `new_str` is required for command: insert")
return self.insert(_path, insert_line, new_str)
elif command == "undo_edit":
return self.undo_edit(_path)
raise ToolError(
f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
)
def validate_path(self, command: str, path: Path):
"""
Check that the path/command combination is valid.
"""
# Check if it's an absolute path
if not path.is_absolute():
suggested_path = Path("") / path
raise ToolError(
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
)
# Check if path exists
if not path.exists() and command != "create":
raise ToolError(
f"The path {path} does not exist. Please provide a valid path."
)
if path.exists() and command == "create":
raise ToolError(
f"File already exists at: {path}. Cannot overwrite files using command `create`."
)
# Check if the path points to a directory
if path.is_dir():
if command != "view":
raise ToolError(
f"The path {path} is a directory and only the `view` command can be used on directories"
)
async def view(self, path: Path, view_range: Optional[List[int]] = None):
"""Implement the view command"""
if path.is_dir():
if view_range:
raise ToolError(
"The `view_range` parameter is not allowed when `path` points to a directory."
)
_, stdout, stderr = await run(
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
)
if not stderr:
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
return CLIResult(output=stdout, error=stderr)
file_content = self.read_file(path)
init_line = 1
if view_range:
if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
raise ToolError(
"Invalid `view_range`. It should be a list of two integers."
)
file_lines = file_content.split("\n")
n_lines_file = len(file_lines)
init_line, final_line = view_range
if init_line < 1 or init_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
)
if final_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
)
if final_line != -1 and final_line < init_line:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
)
if final_line == -1:
file_content = "\n".join(file_lines[init_line - 1 :])
else:
file_content = "\n".join(file_lines[init_line - 1 : final_line])
return CLIResult(
output=self._make_output(file_content, str(path), init_line=init_line)
)
def str_replace(self, path: Path, old_str: str, new_str: Optional[str]):
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
# Read the file content
file_content = self.read_file(path).expandtabs()
old_str = old_str.expandtabs()
new_str = new_str.expandtabs() if new_str is not None else ""
# Check if old_str is unique in the file
occurrences = file_content.count(old_str)
if occurrences == 0:
raise ToolError(
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
)
elif occurrences > 1:
file_content_lines = file_content.split("\n")
lines = [
idx + 1
for idx, line in enumerate(file_content_lines)
if old_str in line
]
raise ToolError(
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
)
# Replace old_str with new_str
new_file_content = file_content.replace(old_str, new_str)
# Write the new content to the file
self.write_file(path, new_file_content)
# Save the content to history
self._file_history[path].append(file_content)
# Create a snippet of the edited section
replacement_line = file_content.split(old_str)[0].count("\n")
start_line = max(0, replacement_line - SNIPPET_LINES)
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
# Prepare the success message
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(
snippet, f"a snippet of {path}", start_line + 1
)
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
return CLIResult(output=success_msg)
def insert(self, path: Path, insert_line: int, new_str: str):
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
file_text = self.read_file(path).expandtabs()
new_str = new_str.expandtabs()
file_text_lines = file_text.split("\n")
n_lines_file = len(file_text_lines)
if insert_line < 0 or insert_line > n_lines_file:
raise ToolError(
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
)
new_str_lines = new_str.split("\n")
new_file_text_lines = (
file_text_lines[:insert_line]
+ new_str_lines
+ file_text_lines[insert_line:]
)
snippet_lines = (
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ new_str_lines
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
)
new_file_text = "\n".join(new_file_text_lines)
snippet = "\n".join(snippet_lines)
self.write_file(path, new_file_text)
self._file_history[path].append(file_text)
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(
snippet,
"a snippet of the edited file",
max(1, insert_line - SNIPPET_LINES + 1),
)
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
return CLIResult(output=success_msg)
def undo_edit(self, path: Path):
"""Implement the undo_edit command."""
if not self._file_history[path]:
raise ToolError(f"No edit history found for {path}.")
old_text = self._file_history[path].pop()
self.write_file(path, old_text)
return CLIResult(
output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
)
def read_file(self, path: Path):
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
try:
return path.read_text()
except Exception as e:
raise ToolError(f"Ran into {e} while trying to read {path}") from None
def write_file(self, path: Path, file: str):
"""Write the content of a file to a given path; raise a ToolError if an error occurs."""
try:
path.write_text(file)
except Exception as e:
raise ToolError(f"Ran into {e} while trying to write to {path}") from None
def _make_output(
self,
file_content: str,
file_descriptor: str,
init_line: int = 1,
expand_tabs: bool = True,
):
"""Generate output for the CLI based on the content of a file."""
file_content = maybe_truncate(file_content)
if expand_tabs:
file_content = file_content.expandtabs()
file_content = "\n".join(
[
f"{i + init_line:6}\t{line}"
for i, line in enumerate(file_content.split("\n"))
]
)
return (
f"Here's the result of running `cat -n` on {file_descriptor}:\n"
+ file_content
+ "\n"
)
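A brief (hypothetical) editing session showing the create / str_replace / undo_edit flow against a temp file that does not yet exist:

```python
import asyncio

async def demo():
    editor = EditTool()
    print((await editor(command="create", path="/tmp/demo.txt",
                        file_text="hello\nworld\n")).output)
    print((await editor(command="str_replace", path="/tmp/demo.txt",
                        old_str="world", new_str="there")).output)
    print((await editor(command="undo_edit", path="/tmp/demo.txt")).output)

asyncio.run(demo())
```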

View File

@ -0,0 +1,42 @@
"""Utility to run shell commands asynchronously with a timeout."""
import asyncio
from typing import Optional
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
MAX_RESPONSE_LEN: int = 16000
def maybe_truncate(content: str, truncate_after: Optional[int] = MAX_RESPONSE_LEN):
"""Truncate content and append a notice if content exceeds the specified length."""
return (
content
if not truncate_after or len(content) <= truncate_after
else content[:truncate_after] + TRUNCATED_MESSAGE
)
async def run(
cmd: str,
timeout: Optional[float] = 120.0, # seconds
truncate_after: Optional[int] = MAX_RESPONSE_LEN,
):
"""Run a shell command asynchronously with a timeout."""
process = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
try:
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
return (
process.returncode or 0,
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
)
except asyncio.TimeoutError as exc:
try:
process.kill()
except ProcessLookupError:
pass
raise TimeoutError(
f"Command '{cmd}' timed out after {timeout} seconds"
) from exc
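Usage is a one-liner (hypothetical command); the helper returns the return code plus truncated stdout/stderr:

```python
import asyncio

# Hypothetical call using the run() helper defined above.
returncode, stdout, stderr = asyncio.run(run("ls /tmp"))
print(returncode, stdout, stderr)
```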

View File

@ -0,0 +1,246 @@
"""
Utility functions for the Anthropic API.
"""
from typing import List, Union, cast
from enum import Enum
from anthropic import (
Anthropic,
AnthropicBedrock,
AnthropicVertex,
APIError,
APIResponseValidationError,
APIStatusError,
)
from anthropic.types.beta import (
BetaCacheControlEphemeralParam,
BetaContentBlockParam,
BetaImageBlockParam,
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaTextBlockParam,
BetaToolResultBlockParam,
BetaToolUseBlockParam,
)
from datetime import datetime
from .tools import ToolResult
COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
class APIProvider(Enum):
ANTHROPIC = "anthropic"
BEDROCK = "bedrock"
VERTEX = "vertex"
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[tuple[APIProvider, str], str] = {
(APIProvider.ANTHROPIC, "claude-3-5-sonnet-20241022"): "claude-3-5-sonnet-20241022",
(APIProvider.BEDROCK, "claude-3-5-sonnet-20241022"): "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
(APIProvider.VERTEX, "claude-3-5-sonnet-20241022"): "claude-3-5-sonnet-v1@20241022",
(APIProvider.ANTHROPIC, "claude-3-7-sonnet-20250219"): "claude-3-7-sonnet-20250219",
(APIProvider.BEDROCK, "claude-3-7-sonnet-20250219"): "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
(APIProvider.VERTEX, "claude-3-7-sonnet-20250219"): "claude-3-7-sonnet-v1@20250219",
(APIProvider.ANTHROPIC, "claude-4-opus-20250514"): "claude-4-opus-20250514",
(APIProvider.BEDROCK, "claude-4-opus-20250514"): "us.anthropic.claude-opus-4-20250514-v1:0",
(APIProvider.VERTEX, "claude-4-opus-20250514"): "claude-4-opus-v1@20250514",
(APIProvider.ANTHROPIC, "claude-4-sonnet-20250514"): "claude-4-sonnet-20250514",
(APIProvider.BEDROCK, "claude-4-sonnet-20250514"): "us.anthropic.claude-sonnet-4-20250514-v1:0",
(APIProvider.VERTEX, "claude-4-sonnet-20250514"): "claude-sonnet-4-v1@20250514",
}
# This system prompt is optimized for the Docker environment in this repository and
# specific tool combinations enabled.
# We encourage modifying this system prompt to ensure the model has context for the
# environment it is running in, and to provide any additional information that may be
# helpful for the task at hand.
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
* You are utilising an Ubuntu virtual machine using x86_64 architecture with internet access.
* You can feel free to install Ubuntu applications with your bash tool. Use curl instead of wget.
* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system.
* Using bash tool you can start GUI applications, but you need to set export DISPLAY=:1 and use a subshell. For example "(DISPLAY=:1 xterm &)". GUI apps run with bash tool will appear within your desktop environment, but they may take some time to appear. Take a screenshot to confirm it did.
* When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B <lines before> -A <lines after> <query> <filename>` to confirm output.
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
* Home directory of this Ubuntu system is '/home/user'.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool.
</IMPORTANT>"""
SYSTEM_PROMPT_WINDOWS = f"""<SYSTEM_CAPABILITY>
* You are utilising a Windows virtual machine using x86_64 architecture with internet access.
* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system.
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
* Home directory of this Windows system is 'C:\\Users\\user'.
* When you want to open some applications on Windows, please use Double Click on it instead of clicking once.
</SYSTEM_CAPABILITY>"""
def _make_api_tool_result(
result: ToolResult, tool_use_id: str
) -> BetaToolResultBlockParam:
"""Convert an agent ToolResult to an API ToolResultBlockParam."""
tool_result_content: Union[List[Union[BetaTextBlockParam,
BetaImageBlockParam]], str] = []
is_error = False
if not result or (result.get('error') is not None and result.get('error') != ""):
is_error = True
error_message = str(result.get('error', 'Unknown error occurred')) if result else 'No result received'
tool_result_content = [{
"type": "text",
"text": _maybe_prepend_system_tool_result(result, error_message)
}]
else:
if result.get('output'):
tool_result_content.append({
"type": "text",
"text": _maybe_prepend_system_tool_result(
result,
str(result.get('output', '')
if result else '')
),
})
if result.get('base64_image'):
tool_result_content.append({
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": result.get('base64_image', ''),
},
})
if not tool_result_content:
tool_result_content.append({
"type": "text",
"text": "Action completed successfully"
})
return {
"type": "tool_result",
"content": tool_result_content,
"tool_use_id": tool_use_id,
"is_error": is_error,
}
def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
if not result:
return result_text
if result.get('system', False):
result_text = f"<system>{result.get('system','')}</system>\n{result_text}"
return result_text
def _inject_prompt_caching(
messages: list[BetaMessageParam],
):
"""
Set cache breakpoints for the 3 most recent turns
one cache breakpoint is left for tools/system prompt, to be shared across sessions
"""
breakpoints_remaining = 3
for message in reversed(messages):
if message["role"] == "user" and isinstance(
content := message["content"], list
):
if breakpoints_remaining:
breakpoints_remaining -= 1
# Use type ignore to bypass TypedDict check until SDK types are updated
content[-1]["cache_control"] = BetaCacheControlEphemeralParam( # type: ignore
{"type": "ephemeral"}
)
else:
content[-1].pop("cache_control", None)
# we'll only ever have one extra turn per loop
break
def _maybe_filter_to_n_most_recent_images(
messages: list[BetaMessageParam],
images_to_keep: int,
min_removal_threshold: int,
):
"""
With the assumption that images are screenshots that are of diminishing value as
the conversation progresses, remove all but the final `images_to_keep` tool_result
images in place, with a chunk of min_removal_threshold to reduce the amount we
break the implicit prompt cache.
"""
if images_to_keep is None:
return messages
tool_result_blocks = cast(
list[BetaToolResultBlockParam],
[
item
for message in messages
for item in (
message["content"] if isinstance(message["content"], list) else []
)
if isinstance(item, dict) and item.get("type") == "tool_result"
],
)
total_images = sum(
1
for tool_result in tool_result_blocks
for content in tool_result.get("content", [])
if isinstance(content, dict) and content.get("type") == "image"
)
images_to_remove = total_images - images_to_keep
# for better cache behavior, we want to remove in chunks
images_to_remove -= images_to_remove % min_removal_threshold
for tool_result in tool_result_blocks:
if isinstance(tool_result.get("content"), list):
new_content = []
for content in tool_result.get("content", []):
if isinstance(content, dict) and content.get("type") == "image":
if images_to_remove > 0:
images_to_remove -= 1
continue
new_content.append(content)
tool_result["content"] = new_content
def _response_to_params(
response: BetaMessage,
) -> list[BetaContentBlockParam]:
res: list[BetaContentBlockParam] = []
if response.content:
for block in response.content:
if isinstance(block, BetaTextBlock):
if block.text:
res.append(BetaTextBlockParam(type="text", text=block.text))
elif getattr(block, "type", None) == "thinking":
# Handle thinking blocks - include signature field
thinking_block = {
"type": "thinking",
"thinking": getattr(block, "thinking", None),
}
if hasattr(block, "signature"):
thinking_block["signature"] = getattr(block, "signature", None)
res.append(cast(BetaContentBlockParam, thinking_block))
else:
# Handle tool use blocks normally
res.append(cast(BetaToolUseBlockParam, block.model_dump()))
return res
else:
return []
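The model table above is keyed by (provider, public model name) and resolves to the provider-specific identifier used when the request is built in `main.py`; for example:

```python
# Resolving the Bedrock identifier for Claude 3.5 Sonnet (taken from the table above).
PROVIDER_TO_DEFAULT_MODEL_NAME[(APIProvider.BEDROCK, "claude-3-5-sonnet-20241022")]
# -> "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
```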

run_multienv_claude.py Normal file
View File

@ -0,0 +1,384 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import argparse
import datetime
import json
import logging
import os
import sys
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.anthropic import AnthropicAgent as PromptAgent
# import fake_run_single as lib_run_single
# from test_env import DesktopEnv
# .env
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="claude_computer_use", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="a11y_tree",
help="Observation type",
)
parser.add_argument("--screen_width", type=int, default=1920)
parser.add_argument("--screen_height", type=int, default=1080)
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15)
# agent config
parser.add_argument("--max_trajectory_length", type=int, default=3)
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config
parser.add_argument("--model", type=str, default="claude-4-sonnet-20250514")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--stop_token", type=str, default=None)
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
args = parser.parse_args()
return args
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
"""Distribute tasks evenly across environments."""
# Flatten the tasks into a single list
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
# Calculate tasks per environment
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
# Distribute tasks
distributed_tasks = []
for i in range(num_envs):
env_tasks = {}
start_idx = i * tasks_per_env
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
for domain, example_id in all_tasks[start_idx:end_idx]:
if domain not in env_tasks:
env_tasks[domain] = []
env_tasks[domain].append(example_id)
distributed_tasks.append(env_tasks)
return distributed_tasks
def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
"""Run tasks for a single environment."""
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
# logger traceback
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
)
)
f.write("\n")
env.close()
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
logger.info("Args: %s", args)
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
# First, set up all environments
logger.info("Setting up all environments...")
envs = []
agents = []
for env_idx in range(args.num_envs):
logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
agent = PromptAgent(
model=args.model,
max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
screen_size=(args.screen_width, args.screen_height),
)
agents.append(agent)
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = "us-east-1"
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=agent.action_space,
provider_name="aws",
region="us-east-1",
snapshot_name=IMAGE_ID_MAP[REGION],
screen_size=(args.screen_width, args.screen_height),
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"],
)
envs.append(env)
logger.info("All environments are ready. Starting parallel task execution...")
# Create a shared list for scores across processes
with Manager() as manager:
shared_scores = manager.list()
# Create and start processes for each environment
processes = []
for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
p = Process(
target=run_env_tasks,
args=(env_idx, env, agent, env_tasks, args, shared_scores)
)
processes.append(p)
p.start()
# Wait for all processes to complete
for p in processes:
p.join()
# Convert shared list to regular list
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# empty all files under example_id
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except Exception:  # malformed result.txt counts as a failed run
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
# path_to_vm can be a list["xxx","xxx"]
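To illustrate `distribute_tasks` (hypothetical example IDs): three examples split over two environments give ceil(3/2) = 2 tasks to the first environment and the remainder to the second:

```python
# Hypothetical task metadata: one domain with three example IDs.
distribute_tasks({"chrome": ["ex-1", "ex-2", "ex-3"]}, num_envs=2)
# -> [{"chrome": ["ex-1", "ex-2"]}, {"chrome": ["ex-3"]}]
```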