Compare commits
12 Commits
main
...
feat/claud
| Author | SHA1 | Date |
|---|---|---|
|
|
f7c33ed33b | |
|
|
0c426390e8 | |
|
|
996fc9d54e | |
|
|
1493246fa6 | |
|
|
7d4bcbd9d7 | |
|
|
14fe4d9476 | |
|
|
e75ac625fc | |
|
|
03385db30e | |
|
|
88b080388d | |
|
|
6549648c0d | |
|
|
15674876a0 | |
|
|
3455c8fb73 |
|
|
@ -19,6 +19,8 @@ Metric = Callable[[Any, Any], float]
|
|||
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
|
||||
|
||||
MAX_RETRIES = 5 # Maximum retries for environment setup
|
||||
|
||||
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
"""
|
||||
|
|
@ -104,7 +106,7 @@ class DesktopEnv(gym.Env):
|
|||
|
||||
# mode: human or machine
|
||||
self.instruction = None
|
||||
assert action_space in ["computer_13", "pyautogui"]
|
||||
assert action_space in ["computer_13", "pyautogui", "claude_computer_use"]
|
||||
self.action_space = action_space # todo: refactor it to the ActType
|
||||
|
||||
# episodic stuffs, like counters, will be updated or reset
|
||||
|
|
@ -307,7 +309,7 @@ class DesktopEnv(gym.Env):
|
|||
reward = 0 # todo: Define reward calculation for each example
|
||||
done = False # todo: Define episode termination condition for each example
|
||||
info = {}
|
||||
|
||||
logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
|
||||
# handle the special actions
|
||||
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
|
||||
if action == 'WAIT':
|
||||
|
|
@ -322,12 +324,15 @@ class DesktopEnv(gym.Env):
|
|||
if self.action_space == "computer_13":
|
||||
# the set of all possible actions defined in the action representation
|
||||
self.controller.execute_action(action)
|
||||
elif self.action_space == "pyautogui":
|
||||
elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
|
||||
if action in ['WAIT', 'FAIL', 'DONE']:
|
||||
self.controller.execute_action(action)
|
||||
else:
|
||||
# the set of all possible python commands insides `pyautogui`
|
||||
self.controller.execute_python_command(action)
|
||||
if type(action) == str:
|
||||
self.controller.execute_python_command(action)
|
||||
elif type(action) == dict:
|
||||
self.controller.execute_python_command(action['command'])
|
||||
|
||||
time.sleep(pause)
|
||||
observation = self._get_obs()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
# Anthropic Agent Integration
|
||||
> Notice: As Anthropic API only supports image’s long edge is less than 1568 pixels and image is less than ~1,600 tokens, we resize the screenshot to 1280x720.
|
||||
## Setup
|
||||
To run with the Anthropic API, you need to set up your environment with the necessary API keys and configurations. Follow these steps:
|
||||
1. **Install Dependencies**: Ensure you have the required Python packages installed. You can do this by running:
|
||||
```bash
|
||||
pip install anthropic
|
||||
```
|
||||
2. **Set Environment Variables**: You need to set the environment variable with your API key. You can do this in .env:
|
||||
For aws bedrock:
|
||||
```.env
|
||||
AWS_ACCESS_KEY_ID=your_access_key_id
|
||||
AWS_SECRET_ACCESS_KEY=your_secret_access_key
|
||||
```
|
||||
For anthropic, you need set APIProvider to `anthropic` and set the API key:
|
||||
```.env
|
||||
ANTHROPIC_API_KEY=your_anthropic_api_key
|
||||
```
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""
|
||||
Anthropic agent implementation
|
||||
"""
|
||||
|
||||
from .main import AnthropicAgent
|
||||
from .tools import (
|
||||
BashTool,
|
||||
CLIResult,
|
||||
ComputerTool,
|
||||
EditTool,
|
||||
ToolCollection,
|
||||
ToolResult
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'AnthropicAgent',
|
||||
'BashTool',
|
||||
'CLIResult',
|
||||
'ComputerTool',
|
||||
'EditTool',
|
||||
'ToolCollection',
|
||||
'ToolResult'
|
||||
]
|
||||
|
|
@ -0,0 +1,442 @@
|
|||
import base64
|
||||
import os
|
||||
import time
|
||||
from typing import Any, cast, Optional, Dict
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from anthropic import (
|
||||
Anthropic,
|
||||
AnthropicBedrock,
|
||||
AnthropicVertex,
|
||||
APIError,
|
||||
APIResponseValidationError,
|
||||
APIStatusError,
|
||||
)
|
||||
from anthropic.types.beta import (
|
||||
BetaMessageParam,
|
||||
BetaTextBlockParam,
|
||||
)
|
||||
from .utils import COMPUTER_USE_BETA_FLAG, PROMPT_CACHING_BETA_FLAG,SYSTEM_PROMPT, SYSTEM_PROMPT_WINDOWS, APIProvider, PROVIDER_TO_DEFAULT_MODEL_NAME
|
||||
from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to_n_most_recent_images
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
class AnthropicAgent:
|
||||
def __init__(self,
|
||||
platform: str = "Ubuntu",
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
provider: APIProvider = APIProvider.BEDROCK,
|
||||
max_tokens: int = 4096,
|
||||
api_key: str = os.environ.get("ANTHROPIC_API_KEY", None),
|
||||
system_prompt_suffix: str = "",
|
||||
only_n_most_recent_images: Optional[int] = 10,
|
||||
action_space: str = "claude_computer_use",
|
||||
screen_size: tuple[int, int] = (1920, 1080),
|
||||
*args, **kwargs
|
||||
):
|
||||
self.platform = platform
|
||||
self.action_space = action_space
|
||||
self.logger = logger
|
||||
self.class_name = self.__class__.__name__
|
||||
self.model_name = model
|
||||
self.provider = provider
|
||||
self.max_tokens = max_tokens
|
||||
self.api_key = api_key
|
||||
self.system_prompt_suffix = system_prompt_suffix
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
self.messages: list[BetaMessageParam] = []
|
||||
self.screen_size = screen_size
|
||||
self.resize_factor = (
|
||||
screen_size[0] / 1280, # Assuming 1280 is the base width
|
||||
screen_size[1] / 720 # Assuming 720 is the base height
|
||||
)
|
||||
|
||||
def add_tool_result(self, tool_call_id: str, result: str, screenshot: bytes = None):
|
||||
"""Add tool result to message history"""
|
||||
tool_result_content = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call_id,
|
||||
"content": [{"type": "text", "text": result}]
|
||||
}
|
||||
]
|
||||
|
||||
# Add screenshot if provided
|
||||
if screenshot is not None:
|
||||
screenshot_base64 = base64.b64encode(screenshot).decode('utf-8')
|
||||
tool_result_content[0]["content"].append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": screenshot_base64
|
||||
}
|
||||
})
|
||||
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
"content": tool_result_content
|
||||
})
|
||||
|
||||
def parse_actions_from_tool_call(self, tool_call: Dict) -> str:
|
||||
result = ""
|
||||
function_args = (
|
||||
tool_call["input"]
|
||||
)
|
||||
|
||||
action = function_args.get("action")
|
||||
if not action:
|
||||
action = tool_call.function.name
|
||||
action_conversion = {
|
||||
"left click": "click",
|
||||
"right click": "right_click"
|
||||
}
|
||||
action = action_conversion.get(action, action)
|
||||
|
||||
text = function_args.get("text")
|
||||
coordinate = function_args.get("coordinate")
|
||||
scroll_direction = function_args.get("scroll_direction")
|
||||
scroll_amount = function_args.get("scroll_amount")
|
||||
duration = function_args.get("duration")
|
||||
|
||||
# resize coordinates if resize_factor is set
|
||||
if coordinate and self.resize_factor:
|
||||
coordinate = (
|
||||
int(coordinate[0] * self.resize_factor[0]),
|
||||
int(coordinate[1] * self.resize_factor[1])
|
||||
)
|
||||
|
||||
# Handle mouse move and drag actions
|
||||
if action in ("mouse_move", "left_click_drag"):
|
||||
if coordinate is None:
|
||||
raise ValueError(f"coordinate is required for {action}")
|
||||
if text is not None:
|
||||
raise ValueError(f"text is not accepted for {action}")
|
||||
if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
|
||||
raise ValueError(f"{coordinate} must be a tuple of length 2")
|
||||
if not all(isinstance(i, int) for i in coordinate):
|
||||
raise ValueError(f"{coordinate} must be a tuple of ints")
|
||||
|
||||
x, y = coordinate[0], coordinate[1]
|
||||
if action == "mouse_move":
|
||||
result += (
|
||||
f"pyautogui.moveTo({x}, {y}, duration={duration or 0.5})\n"
|
||||
)
|
||||
expected_outcome = f"Mouse moved to ({x},{y})."
|
||||
elif action == "left_click_drag":
|
||||
result += (
|
||||
f"pyautogui.dragTo({x}, {y}, duration={duration or 0.5})\n"
|
||||
)
|
||||
expected_outcome = f"Cursor dragged to ({x},{y})."
|
||||
|
||||
# Handle keyboard actions
|
||||
elif action in ("key", "type"):
|
||||
if text is None:
|
||||
raise ValueError(f"text is required for {action}")
|
||||
if coordinate is not None:
|
||||
raise ValueError(f"coordinate is not accepted for {action}")
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"{text} must be a string")
|
||||
|
||||
if action == "key":
|
||||
key_conversion = {
|
||||
"page_down": "pagedown",
|
||||
"page_up": "pageup",
|
||||
"super_l": "win",
|
||||
"super": "command",
|
||||
"escape": "esc"
|
||||
}
|
||||
keys = text.split('+')
|
||||
for key in keys:
|
||||
key = key.strip().lower()
|
||||
key = key_conversion.get(key, key)
|
||||
result += (f"pyautogui.keyDown('{key}')\n")
|
||||
for key in reversed(keys):
|
||||
key = key.strip().lower()
|
||||
key = key_conversion.get(key, key)
|
||||
result += (f"pyautogui.keyUp('{key}')\n")
|
||||
expected_outcome = f"Key {key} pressed."
|
||||
elif action == "type":
|
||||
result += (
|
||||
f"pyautogui.typewrite(\"\"\"{text}\"\"\", interval=0.01)\n"
|
||||
)
|
||||
expected_outcome = f"Text {text} written."
|
||||
|
||||
# Handle scroll actions
|
||||
elif action == "scroll":
|
||||
if coordinate is None:
|
||||
if scroll_direction in ("up", "down"):
|
||||
result += (
|
||||
f"pyautogui.scroll({scroll_amount if scroll_direction == 'up' else -scroll_amount})\n"
|
||||
)
|
||||
elif scroll_direction in ("left", "right"):
|
||||
result += (
|
||||
f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount})\n"
|
||||
)
|
||||
else:
|
||||
if scroll_direction in ("up", "down"):
|
||||
x, y = coordinate[0], coordinate[1]
|
||||
result += (
|
||||
f"pyautogui.scroll({scroll_amount if scroll_direction == 'up' else -scroll_amount}, {x}, {y})\n"
|
||||
)
|
||||
elif scroll_direction in ("left", "right"):
|
||||
x, y = coordinate[0], coordinate[1]
|
||||
result += (
|
||||
f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount}, {x}, {y})\n"
|
||||
)
|
||||
expected_outcome = "Scroll action finished"
|
||||
|
||||
# Handle click actions
|
||||
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press"):
|
||||
if coordinate is not None:
|
||||
x, y = coordinate
|
||||
if action == "left_click":
|
||||
result += (f"pyautogui.click({x}, {y})\n")
|
||||
elif action == "right_click":
|
||||
result += (f"pyautogui.rightClick({x}, {y})\n")
|
||||
elif action == "double_click":
|
||||
result += (f"pyautogui.doubleClick({x}, {y})\n")
|
||||
elif action == "middle_click":
|
||||
result += (f"pyautogui.middleClick({x}, {y})\n")
|
||||
elif action == "left_press":
|
||||
result += (f"pyautogui.mouseDown({x}, {y})\n")
|
||||
result += ("time.sleep(1)\n")
|
||||
result += (f"pyautogui.mouseUp({x}, {y})\n")
|
||||
else:
|
||||
if action == "left_click":
|
||||
result += ("pyautogui.click()\n")
|
||||
elif action == "right_click":
|
||||
result += ("pyautogui.rightClick()\n")
|
||||
elif action == "double_click":
|
||||
result += ("pyautogui.doubleClick()\n")
|
||||
elif action == "middle_click":
|
||||
result += ("pyautogui.middleClick()\n")
|
||||
elif action == "left_press":
|
||||
result += ("pyautogui.mouseDown()\n")
|
||||
result += ("time.sleep(1)\n")
|
||||
result += ("pyautogui.mouseUp()\n")
|
||||
expected_outcome = "Click action finished"
|
||||
|
||||
elif action == "wait":
|
||||
result += "pyautogui.sleep(0.5)\n"
|
||||
expected_outcome = "Wait for 0.5 seconds"
|
||||
elif action == "fail":
|
||||
result += "FAIL"
|
||||
expected_outcome = "Finished"
|
||||
elif action == "done":
|
||||
result += "DONE"
|
||||
expected_outcome = "Finished"
|
||||
elif action == "call_user":
|
||||
result += "CALL_USER"
|
||||
expected_outcome = "Call user"
|
||||
elif action == "screenshot":
|
||||
result += "pyautogui.sleep(0.1)\n"
|
||||
expected_outcome = "Screenshot taken"
|
||||
else:
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
|
||||
return result
|
||||
|
||||
def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
|
||||
system = BetaTextBlockParam(
|
||||
type="text",
|
||||
text=f"{SYSTEM_PROMPT_WINDOWS if self.platform == 'Windows' else SYSTEM_PROMPT}{' ' + self.system_prompt_suffix if self.system_prompt_suffix else ''}"
|
||||
)
|
||||
|
||||
# resize screenshot if resize_factor is set
|
||||
if obs and "screenshot" in obs:
|
||||
# Convert bytes to PIL Image
|
||||
screenshot_bytes = obs["screenshot"]
|
||||
screenshot_image = Image.open(io.BytesIO(screenshot_bytes))
|
||||
|
||||
# Calculate new size based on resize factor
|
||||
new_width, new_height = 1280, 720
|
||||
|
||||
# Resize the image
|
||||
resized_image = screenshot_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# Convert back to bytes
|
||||
output_buffer = io.BytesIO()
|
||||
resized_image.save(output_buffer, format='PNG')
|
||||
obs["screenshot"] = output_buffer.getvalue()
|
||||
|
||||
|
||||
if not self.messages:
|
||||
|
||||
init_screenshot = obs
|
||||
init_screenshot_base64 = base64.b64encode(init_screenshot["screenshot"]).decode('utf-8')
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": init_screenshot_base64,
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": task_instruction},
|
||||
]
|
||||
})
|
||||
|
||||
if self.messages and "tool_use" in [content_block["type"] for content_block in self.messages[-1]["content"]]:
|
||||
self.add_tool_result(
|
||||
self.messages[-1]["content"][-1]["id"],
|
||||
f"Success",
|
||||
screenshot=obs.get("screenshot") if obs else None
|
||||
)
|
||||
|
||||
enable_prompt_caching = False
|
||||
betas = ["computer-use-2025-01-24"]
|
||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
betas = ["computer-use-2025-01-24"]
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
betas = [COMPUTER_USE_BETA_FLAG]
|
||||
|
||||
image_truncation_threshold = 10
|
||||
if self.provider == APIProvider.ANTHROPIC:
|
||||
client = Anthropic(api_key=self.api_key, max_retries=4)
|
||||
enable_prompt_caching = True
|
||||
elif self.provider == APIProvider.VERTEX:
|
||||
client = AnthropicVertex()
|
||||
elif self.provider == APIProvider.BEDROCK:
|
||||
client = AnthropicBedrock(
|
||||
# Authenticate by either providing the keys below or use the default AWS credential providers, such as
|
||||
# using ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
|
||||
aws_access_key=os.getenv('AWS_ACCESS_KEY_ID'),
|
||||
aws_secret_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
|
||||
# aws_region changes the aws region to which the request is made. By default, we read AWS_REGION,
|
||||
# and if that's not present, we default to us-east-1. Note that we do not read ~/.aws/config for the region.
|
||||
aws_region=os.getenv('AWS_DEFAULT_REGION'),
|
||||
)
|
||||
|
||||
if enable_prompt_caching:
|
||||
betas.append(PROMPT_CACHING_BETA_FLAG)
|
||||
_inject_prompt_caching(self.messages)
|
||||
image_truncation_threshold = 50
|
||||
system["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
if self.only_n_most_recent_images:
|
||||
_maybe_filter_to_n_most_recent_images(
|
||||
self.messages,
|
||||
self.only_n_most_recent_images,
|
||||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
if self.model_name == "claude-3-5-sonnet-20241022":
|
||||
tools = [
|
||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
# {'type': 'bash_20241022', 'name': 'bash'},
|
||||
# {'name': 'str_replace_editor', 'type': 'text_editor_20241022'}
|
||||
] if self.platform == 'Ubuntu' else [
|
||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
]
|
||||
elif self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
tools = [
|
||||
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
# {'type': 'bash_20250124', 'name': 'bash'},
|
||||
# {'name': 'str_replace_editor', 'type': 'text_editor_20250124'}
|
||||
] if self.platform == 'Ubuntu' else [
|
||||
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
]
|
||||
extra_body = {
|
||||
"thinking": {"type": "enabled", "budget_tokens": 1024}
|
||||
}
|
||||
response = None
|
||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
||||
self.logger.exception(f"Anthropic API error: {str(e)}")
|
||||
try:
|
||||
self.logger.warning("Retrying with backup API key...")
|
||||
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
|
||||
|
||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
response = backup_client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = backup_client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
self.logger.info("Successfully used backup API key")
|
||||
except Exception as backup_e:
|
||||
self.logger.exception(f"Backup API call also failed: {str(backup_e)}")
|
||||
return None, None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error in Anthropic API: {str(e)}")
|
||||
return None, None
|
||||
|
||||
response_params = _response_to_params(response)
|
||||
logger.info(f"Received response params: {response_params}")
|
||||
|
||||
# Store response in message history
|
||||
self.messages.append({
|
||||
"role": "assistant",
|
||||
"content": response_params
|
||||
})
|
||||
|
||||
actions: list[Any] = []
|
||||
reasonings: list[str] = []
|
||||
for content_block in response_params:
|
||||
if content_block["type"] == "tool_use":
|
||||
actions.append({
|
||||
"name": content_block["name"],
|
||||
"input": cast(dict[str, Any], content_block["input"]),
|
||||
"id": content_block["id"],
|
||||
"action_type": content_block.get("type"),
|
||||
"command": self.parse_actions_from_tool_call(content_block)
|
||||
})
|
||||
elif content_block["type"] == "text":
|
||||
reasonings.append(content_block["text"])
|
||||
if isinstance(reasonings, list) and len(reasonings) > 0:
|
||||
reasonings = reasonings[0]
|
||||
else:
|
||||
reasonings = ""
|
||||
logger.info(f"Received actions: {actions}")
|
||||
logger.info(f"Received reasonings: {reasonings}")
|
||||
if len(actions) == 0:
|
||||
actions = ["DONE"]
|
||||
return reasonings, actions
|
||||
|
||||
def reset(self, *args, **kwargs):
|
||||
"""
|
||||
Reset the agent's state.
|
||||
"""
|
||||
self.messages = []
|
||||
self.logger.info(f"{self.class_name} reset.")
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from .base import CLIResult, ToolResult
|
||||
from .bash import BashTool
|
||||
from .collection import ToolCollection
|
||||
from .computer import ComputerTool
|
||||
from .edit import EditTool
|
||||
|
||||
__ALL__ = [
|
||||
BashTool,
|
||||
CLIResult,
|
||||
ComputerTool,
|
||||
EditTool,
|
||||
ToolCollection,
|
||||
ToolResult,
|
||||
]
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Optional
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
|
||||
class BaseAnthropicTool(metaclass=ABCMeta):
|
||||
"""Abstract base class for Anthropic-defined tools."""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, **kwargs) -> Any:
|
||||
"""Executes the tool with the given arguments."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_params(
|
||||
self,
|
||||
) -> BetaToolUnionParam:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(frozen=True) #kw_only=True,
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
base64_image: Optional[str] = None
|
||||
system: Optional[str] = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(
|
||||
field: Optional[str], other_field: Optional[str], concatenate: bool = True
|
||||
):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
"""Raised when a tool encounters an error."""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
|
@ -0,0 +1,144 @@
|
|||
import asyncio
|
||||
import os
|
||||
from typing import ClassVar, Literal, Optional
|
||||
|
||||
from anthropic.types.beta import BetaToolBash20241022Param
|
||||
|
||||
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
||||
|
||||
|
||||
class _BashSession:
|
||||
"""A session of a bash shell."""
|
||||
|
||||
_started: bool
|
||||
_process: asyncio.subprocess.Process
|
||||
|
||||
command: str = "/bin/bash"
|
||||
_output_delay: float = 0.2 # seconds
|
||||
_timeout: float = 120.0 # seconds
|
||||
_sentinel: str = "<<exit>>"
|
||||
|
||||
def __init__(self):
|
||||
self._started = False
|
||||
self._timed_out = False
|
||||
|
||||
async def start(self):
|
||||
if self._started:
|
||||
return
|
||||
|
||||
self._process = await asyncio.create_subprocess_shell(
|
||||
self.command,
|
||||
preexec_fn=os.setsid,
|
||||
shell=True,
|
||||
bufsize=0,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
self._started = True
|
||||
|
||||
def stop(self):
|
||||
"""Terminate the bash shell."""
|
||||
if not self._started:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process.returncode is not None:
|
||||
return
|
||||
self._process.terminate()
|
||||
|
||||
async def run(self, command: str):
|
||||
"""Execute a command in the bash shell."""
|
||||
if not self._started:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process.returncode is not None:
|
||||
return ToolResult(
|
||||
system="tool must be restarted",
|
||||
error=f"bash has exited with returncode {self._process.returncode}",
|
||||
)
|
||||
if self._timed_out:
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
)
|
||||
|
||||
# we know these are not None because we created the process with PIPEs
|
||||
assert self._process.stdin
|
||||
assert self._process.stdout
|
||||
assert self._process.stderr
|
||||
|
||||
# send command to the process
|
||||
self._process.stdin.write(
|
||||
command.encode() + f"; echo '{self._sentinel}'\n".encode()
|
||||
)
|
||||
await self._process.stdin.drain()
|
||||
|
||||
# read output from the process, until the sentinel is found
|
||||
try:
|
||||
async with asyncio.timeout(self._timeout):
|
||||
while True:
|
||||
await asyncio.sleep(self._output_delay)
|
||||
# if we read directly from stdout/stderr, it will wait forever for
|
||||
# EOF. use the StreamReader buffer directly instead.
|
||||
output = self._process.stdout._buffer.decode() # pyright: ignore[reportAttributeAccessIssue]
|
||||
if self._sentinel in output:
|
||||
# strip the sentinel and break
|
||||
output = output[: output.index(self._sentinel)]
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
self._timed_out = True
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
) from None
|
||||
|
||||
if output.endswith("\n"):
|
||||
output = output[:-1]
|
||||
|
||||
error = self._process.stderr._buffer.decode() # pyright: ignore[reportAttributeAccessIssue]
|
||||
if error.endswith("\n"):
|
||||
error = error[:-1]
|
||||
|
||||
# clear the buffers so that the next output can be read correctly
|
||||
self._process.stdout._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
|
||||
self._process.stderr._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return CLIResult(output=output, error=error)
|
||||
|
||||
|
||||
class BashTool(BaseAnthropicTool):
|
||||
"""
|
||||
A tool that allows the agent to run bash commands.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
_session: Optional[_BashSession]
|
||||
name: ClassVar[Literal["bash"]] = "bash"
|
||||
api_type: ClassVar[Literal["bash_20241022"]] = "bash_20241022"
|
||||
|
||||
def __init__(self):
|
||||
self._session = None
|
||||
super().__init__()
|
||||
|
||||
async def __call__(
|
||||
self, command: Optional[str] = None, restart: bool = False, **kwargs
|
||||
):
|
||||
if restart:
|
||||
if self._session:
|
||||
self._session.stop()
|
||||
self._session = _BashSession()
|
||||
await self._session.start()
|
||||
|
||||
return ToolResult(system="tool has been restarted.")
|
||||
|
||||
if self._session is None:
|
||||
self._session = _BashSession()
|
||||
await self._session.start()
|
||||
|
||||
if command is not None:
|
||||
return await self._session.run(command)
|
||||
|
||||
raise ToolError("no command provided.")
|
||||
|
||||
def to_params(self) -> BetaToolBash20241022Param:
|
||||
return {
|
||||
"type": self.api_type,
|
||||
"name": self.name,
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
from .base import (
|
||||
BaseAnthropicTool,
|
||||
ToolError,
|
||||
ToolFailure,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
|
||||
class ToolCollection:
|
||||
"""A collection of anthropic-defined tools."""
|
||||
|
||||
def __init__(self, *tools: BaseAnthropicTool):
|
||||
self.tools = tools
|
||||
self.tool_map = {tool.to_params()["name"]: tool for tool in tools}
|
||||
|
||||
def to_params(
|
||||
self,
|
||||
) -> list[BetaToolUnionParam]:
|
||||
return [tool.to_params() for tool in self.tools]
|
||||
|
||||
async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
|
||||
tool = self.tool_map.get(name)
|
||||
if not tool:
|
||||
return ToolFailure(error=f"Tool {name} is invalid")
|
||||
try:
|
||||
return await tool(**tool_input)
|
||||
except ToolError as e:
|
||||
return ToolFailure(error=e.message)
|
||||
|
|
@ -0,0 +1,260 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal, TypedDict, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from anthropic.types.beta import BetaToolComputerUse20241022Param
|
||||
|
||||
from .base import BaseAnthropicTool, ToolError, ToolResult
|
||||
from .run import run
|
||||
|
||||
OUTPUT_DIR = "/tmp/outputs"
|
||||
|
||||
TYPING_DELAY_MS = 12
|
||||
TYPING_GROUP_SIZE = 50
|
||||
|
||||
Action = Literal[
|
||||
"key",
|
||||
"type",
|
||||
"mouse_move",
|
||||
"left_click",
|
||||
"left_click_drag",
|
||||
"right_click",
|
||||
"middle_click",
|
||||
"double_click",
|
||||
"screenshot",
|
||||
"cursor_position",
|
||||
]
|
||||
|
||||
|
||||
class Resolution(TypedDict):
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
# sizes above XGA/WXGA are not recommended (see README.md)
|
||||
# scale down to one of these targets if ComputerTool._scaling_enabled is set
|
||||
MAX_SCALING_TARGETS: dict[str, Resolution] = {
|
||||
"XGA": Resolution(width=1024, height=768), # 4:3
|
||||
"WXGA": Resolution(width=1280, height=800), # 16:10
|
||||
"FWXGA": Resolution(width=1366, height=768), # ~16:9
|
||||
}
|
||||
|
||||
|
||||
class ScalingSource(Enum):
|
||||
COMPUTER = "computer"
|
||||
API = "api"
|
||||
|
||||
|
||||
class ComputerToolOptions(TypedDict):
|
||||
display_height_px: int
|
||||
display_width_px: int
|
||||
display_number: Optional[int]
|
||||
|
||||
|
||||
def chunks(s: str, chunk_size: int) -> list[str]:
|
||||
return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)]
|
||||
|
||||
|
||||
class ComputerTool(BaseAnthropicTool):
|
||||
"""
|
||||
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
name: Literal["computer"] = "computer"
|
||||
api_type: Literal["computer_20241022"] = "computer_20241022"
|
||||
width: int
|
||||
height: int
|
||||
display_num: Optional[int]
|
||||
|
||||
_screenshot_delay = 2.0
|
||||
_scaling_enabled = True
|
||||
|
||||
@property
|
||||
def options(self) -> ComputerToolOptions:
|
||||
width, height = self.scale_coordinates(
|
||||
ScalingSource.COMPUTER, self.width, self.height
|
||||
)
|
||||
return {
|
||||
"display_width_px": width,
|
||||
"display_height_px": height,
|
||||
"display_number": self.display_num,
|
||||
}
|
||||
|
||||
def to_params(self) -> BetaToolComputerUse20241022Param:
|
||||
return {"name": self.name, "type": self.api_type, **self.options}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.width = int(os.getenv("WIDTH") or 0)
|
||||
self.height = int(os.getenv("HEIGHT") or 0)
|
||||
assert self.width and self.height, "WIDTH, HEIGHT must be set"
|
||||
if (display_num := os.getenv("DISPLAY_NUM")) is not None:
|
||||
self.display_num = int(display_num)
|
||||
self._display_prefix = f"DISPLAY=:{self.display_num} "
|
||||
else:
|
||||
self.display_num = None
|
||||
self._display_prefix = ""
|
||||
|
||||
self.xdotool = f"{self._display_prefix}xdotool"
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
action: Action,
|
||||
text: Optional[str] = None,
|
||||
coordinate: Optional[Tuple[int, int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if action in ("mouse_move", "left_click_drag"):
|
||||
if coordinate is None:
|
||||
raise ToolError(f"coordinate is required for {action}")
|
||||
if text is not None:
|
||||
raise ToolError(f"text is not accepted for {action}")
|
||||
if not isinstance(coordinate, list) or len(coordinate) != 2:
|
||||
raise ToolError(f"{coordinate} must be a tuple of length 2")
|
||||
if not all(isinstance(i, int) and i >= 0 for i in coordinate):
|
||||
raise ToolError(f"{coordinate} must be a tuple of non-negative ints")
|
||||
|
||||
x, y = self.scale_coordinates(
|
||||
ScalingSource.API, coordinate[0], coordinate[1]
|
||||
)
|
||||
|
||||
if action == "mouse_move":
|
||||
return await self.shell(f"{self.xdotool} mousemove --sync {x} {y}")
|
||||
elif action == "left_click_drag":
|
||||
return await self.shell(
|
||||
f"{self.xdotool} mousedown 1 mousemove --sync {x} {y} mouseup 1"
|
||||
)
|
||||
|
||||
if action in ("key", "type"):
|
||||
if text is None:
|
||||
raise ToolError(f"text is required for {action}")
|
||||
if coordinate is not None:
|
||||
raise ToolError(f"coordinate is not accepted for {action}")
|
||||
if not isinstance(text, str):
|
||||
raise ToolError(output=f"{text} must be a string")
|
||||
|
||||
if action == "key":
|
||||
return await self.shell(f"{self.xdotool} key -- {text}")
|
||||
elif action == "type":
|
||||
results: list[ToolResult] = []
|
||||
for chunk in chunks(text, TYPING_GROUP_SIZE):
|
||||
cmd = f"{self.xdotool} type --delay {TYPING_DELAY_MS} -- {shlex.quote(chunk)}"
|
||||
results.append(await self.shell(cmd, take_screenshot=False))
|
||||
screenshot_base64 = (await self.screenshot()).base64_image
|
||||
return ToolResult(
|
||||
output="".join(result.output or "" for result in results),
|
||||
error="".join(result.error or "" for result in results),
|
||||
base64_image=screenshot_base64,
|
||||
)
|
||||
|
||||
if action in (
|
||||
"left_click",
|
||||
"right_click",
|
||||
"double_click",
|
||||
"middle_click",
|
||||
"screenshot",
|
||||
"cursor_position",
|
||||
):
|
||||
if text is not None:
|
||||
raise ToolError(f"text is not accepted for {action}")
|
||||
if coordinate is not None:
|
||||
raise ToolError(f"coordinate is not accepted for {action}")
|
||||
|
||||
if action == "screenshot":
|
||||
return await self.screenshot()
|
||||
elif action == "cursor_position":
|
||||
result = await self.shell(
|
||||
f"{self.xdotool} getmouselocation --shell",
|
||||
take_screenshot=False,
|
||||
)
|
||||
output = result.output or ""
|
||||
x, y = self.scale_coordinates(
|
||||
ScalingSource.COMPUTER,
|
||||
int(output.split("X=")[1].split("\n")[0]),
|
||||
int(output.split("Y=")[1].split("\n")[0]),
|
||||
)
|
||||
return result.replace(output=f"X={x},Y={y}")
|
||||
else:
|
||||
click_arg = {
|
||||
"left_click": "1",
|
||||
"right_click": "3",
|
||||
"middle_click": "2",
|
||||
"double_click": "--repeat 2 --delay 500 1",
|
||||
}[action]
|
||||
return await self.shell(f"{self.xdotool} click {click_arg}")
|
||||
|
||||
raise ToolError(f"Invalid action: {action}")
|
||||
|
||||
async def screenshot(self):
|
||||
"""Take a screenshot of the current screen and return the base64 encoded image."""
|
||||
output_dir = Path(OUTPUT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"screenshot_{uuid4().hex}.png"
|
||||
|
||||
# Try gnome-screenshot first
|
||||
if shutil.which("gnome-screenshot"):
|
||||
screenshot_cmd = f"{self._display_prefix}gnome-screenshot -f {path} -p"
|
||||
else:
|
||||
# Fall back to scrot if gnome-screenshot isn't available
|
||||
screenshot_cmd = f"{self._display_prefix}scrot -p {path}"
|
||||
|
||||
result = await self.shell(screenshot_cmd, take_screenshot=False)
|
||||
if self._scaling_enabled:
|
||||
x, y = self.scale_coordinates(
|
||||
ScalingSource.COMPUTER, self.width, self.height
|
||||
)
|
||||
await self.shell(
|
||||
f"convert {path} -resize {x}x{y}! {path}", take_screenshot=False
|
||||
)
|
||||
|
||||
if path.exists():
|
||||
return result.replace(
|
||||
base64_image=base64.b64encode(path.read_bytes()).decode()
|
||||
)
|
||||
raise ToolError(f"Failed to take screenshot: {result.error}")
|
||||
|
||||
async def shell(self, command: str, take_screenshot=True) -> ToolResult:
|
||||
"""Run a shell command and return the output, error, and optionally a screenshot."""
|
||||
_, stdout, stderr = await run(command)
|
||||
base64_image = None
|
||||
|
||||
if take_screenshot:
|
||||
# delay to let things settle before taking a screenshot
|
||||
await asyncio.sleep(self._screenshot_delay)
|
||||
base64_image = (await self.screenshot()).base64_image
|
||||
|
||||
return ToolResult(output=stdout, error=stderr, base64_image=base64_image)
|
||||
|
||||
def scale_coordinates(self, source: ScalingSource, x: int, y: int):
|
||||
"""Scale coordinates to a target maximum resolution."""
|
||||
if not self._scaling_enabled:
|
||||
return x, y
|
||||
ratio = self.width / self.height
|
||||
target_dimension = None
|
||||
for dimension in MAX_SCALING_TARGETS.values():
|
||||
# allow some error in the aspect ratio - not ratios are exactly 16:9
|
||||
if abs(dimension["width"] / dimension["height"] - ratio) < 0.02:
|
||||
if dimension["width"] < self.width:
|
||||
target_dimension = dimension
|
||||
break
|
||||
if target_dimension is None:
|
||||
return x, y
|
||||
# should be less than 1
|
||||
x_scaling_factor = target_dimension["width"] / self.width
|
||||
y_scaling_factor = target_dimension["height"] / self.height
|
||||
if source == ScalingSource.API:
|
||||
if x > self.width or y > self.height:
|
||||
raise ToolError(f"Coordinates {x}, {y} are out of bounds")
|
||||
# scale up
|
||||
return round(x / x_scaling_factor), round(y / y_scaling_factor)
|
||||
# scale down
|
||||
return round(x * x_scaling_factor), round(y * y_scaling_factor)
|
||||
|
|
@ -0,0 +1,290 @@
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Literal, get_args, Optional, List
|
||||
|
||||
from anthropic.types.beta import BetaToolTextEditor20241022Param
|
||||
|
||||
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
||||
from .run import maybe_truncate, run
|
||||
|
||||
Command = Literal[
|
||||
"view",
|
||||
"create",
|
||||
"str_replace",
|
||||
"insert",
|
||||
"undo_edit",
|
||||
]
|
||||
SNIPPET_LINES: int = 4
|
||||
|
||||
|
||||
class EditTool(BaseAnthropicTool):
|
||||
"""
|
||||
An filesystem editor tool that allows the agent to view, create, and edit files.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
api_type: Literal["text_editor_20241022"] = "text_editor_20241022"
|
||||
name: Literal["str_replace_editor"] = "str_replace_editor"
|
||||
|
||||
_file_history: dict[Path, list[str]]
|
||||
|
||||
def __init__(self):
|
||||
self._file_history = defaultdict(list)
|
||||
super().__init__()
|
||||
|
||||
def to_params(self) -> BetaToolTextEditor20241022Param:
|
||||
return {
|
||||
"name": self.name,
|
||||
"type": self.api_type,
|
||||
}
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
command: Command,
|
||||
path: str,
|
||||
file_text: Optional[str] = None,
|
||||
view_range: Optional[list[int]] = None,
|
||||
old_str: Optional[str] = None,
|
||||
new_str: Optional[str] = None,
|
||||
insert_line: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
_path = Path(path)
|
||||
self.validate_path(command, _path)
|
||||
if command == "view":
|
||||
return await self.view(_path, view_range)
|
||||
elif command == "create":
|
||||
if file_text is None:
|
||||
raise ToolError("Parameter `file_text` is required for command: create")
|
||||
self.write_file(_path, file_text)
|
||||
self._file_history[_path].append(file_text)
|
||||
return ToolResult(output=f"File created successfully at: {_path}")
|
||||
elif command == "str_replace":
|
||||
if old_str is None:
|
||||
raise ToolError(
|
||||
"Parameter `old_str` is required for command: str_replace"
|
||||
)
|
||||
return self.str_replace(_path, old_str, new_str)
|
||||
elif command == "insert":
|
||||
if insert_line is None:
|
||||
raise ToolError(
|
||||
"Parameter `insert_line` is required for command: insert"
|
||||
)
|
||||
if new_str is None:
|
||||
raise ToolError("Parameter `new_str` is required for command: insert")
|
||||
return self.insert(_path, insert_line, new_str)
|
||||
elif command == "undo_edit":
|
||||
return self.undo_edit(_path)
|
||||
raise ToolError(
|
||||
f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
|
||||
)
|
||||
|
||||
def validate_path(self, command: str, path: Path):
|
||||
"""
|
||||
Check that the path/command combination is valid.
|
||||
"""
|
||||
# Check if its an absolute path
|
||||
if not path.is_absolute():
|
||||
suggested_path = Path("") / path
|
||||
raise ToolError(
|
||||
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
|
||||
)
|
||||
# Check if path exists
|
||||
if not path.exists() and command != "create":
|
||||
raise ToolError(
|
||||
f"The path {path} does not exist. Please provide a valid path."
|
||||
)
|
||||
if path.exists() and command == "create":
|
||||
raise ToolError(
|
||||
f"File already exists at: {path}. Cannot overwrite files using command `create`."
|
||||
)
|
||||
# Check if the path points to a directory
|
||||
if path.is_dir():
|
||||
if command != "view":
|
||||
raise ToolError(
|
||||
f"The path {path} is a directory and only the `view` command can be used on directories"
|
||||
)
|
||||
|
||||
async def view(self, path: Path, view_range: Optional[List[int]] = None):
|
||||
"""Implement the view command"""
|
||||
if path.is_dir():
|
||||
if view_range:
|
||||
raise ToolError(
|
||||
"The `view_range` parameter is not allowed when `path` points to a directory."
|
||||
)
|
||||
|
||||
_, stdout, stderr = await run(
|
||||
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
|
||||
)
|
||||
if not stderr:
|
||||
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
|
||||
return CLIResult(output=stdout, error=stderr)
|
||||
|
||||
file_content = self.read_file(path)
|
||||
init_line = 1
|
||||
if view_range:
|
||||
if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
|
||||
raise ToolError(
|
||||
"Invalid `view_range`. It should be a list of two integers."
|
||||
)
|
||||
file_lines = file_content.split("\n")
|
||||
n_lines_file = len(file_lines)
|
||||
init_line, final_line = view_range
|
||||
if init_line < 1 or init_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
|
||||
)
|
||||
if final_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
|
||||
)
|
||||
if final_line != -1 and final_line < init_line:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
|
||||
)
|
||||
|
||||
if final_line == -1:
|
||||
file_content = "\n".join(file_lines[init_line - 1 :])
|
||||
else:
|
||||
file_content = "\n".join(file_lines[init_line - 1 : final_line])
|
||||
|
||||
return CLIResult(
|
||||
output=self._make_output(file_content, str(path), init_line=init_line)
|
||||
)
|
||||
|
||||
def str_replace(self, path: Path, old_str: str, new_str: Optional[str]):
|
||||
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
|
||||
# Read the file content
|
||||
file_content = self.read_file(path).expandtabs()
|
||||
old_str = old_str.expandtabs()
|
||||
new_str = new_str.expandtabs() if new_str is not None else ""
|
||||
|
||||
# Check if old_str is unique in the file
|
||||
occurrences = file_content.count(old_str)
|
||||
if occurrences == 0:
|
||||
raise ToolError(
|
||||
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
|
||||
)
|
||||
elif occurrences > 1:
|
||||
file_content_lines = file_content.split("\n")
|
||||
lines = [
|
||||
idx + 1
|
||||
for idx, line in enumerate(file_content_lines)
|
||||
if old_str in line
|
||||
]
|
||||
raise ToolError(
|
||||
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
|
||||
)
|
||||
|
||||
# Replace old_str with new_str
|
||||
new_file_content = file_content.replace(old_str, new_str)
|
||||
|
||||
# Write the new content to the file
|
||||
self.write_file(path, new_file_content)
|
||||
|
||||
# Save the content to history
|
||||
self._file_history[path].append(file_content)
|
||||
|
||||
# Create a snippet of the edited section
|
||||
replacement_line = file_content.split(old_str)[0].count("\n")
|
||||
start_line = max(0, replacement_line - SNIPPET_LINES)
|
||||
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
|
||||
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = f"The file {path} has been edited. "
|
||||
success_msg += self._make_output(
|
||||
snippet, f"a snippet of {path}", start_line + 1
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
|
||||
|
||||
return CLIResult(output=success_msg)
|
||||
|
||||
def insert(self, path: Path, insert_line: int, new_str: str):
|
||||
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
|
||||
file_text = self.read_file(path).expandtabs()
|
||||
new_str = new_str.expandtabs()
|
||||
file_text_lines = file_text.split("\n")
|
||||
n_lines_file = len(file_text_lines)
|
||||
|
||||
if insert_line < 0 or insert_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
|
||||
)
|
||||
|
||||
new_str_lines = new_str.split("\n")
|
||||
new_file_text_lines = (
|
||||
file_text_lines[:insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line:]
|
||||
)
|
||||
snippet_lines = (
|
||||
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||
)
|
||||
|
||||
new_file_text = "\n".join(new_file_text_lines)
|
||||
snippet = "\n".join(snippet_lines)
|
||||
|
||||
self.write_file(path, new_file_text)
|
||||
self._file_history[path].append(file_text)
|
||||
|
||||
success_msg = f"The file {path} has been edited. "
|
||||
success_msg += self._make_output(
|
||||
snippet,
|
||||
"a snippet of the edited file",
|
||||
max(1, insert_line - SNIPPET_LINES + 1),
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
|
||||
return CLIResult(output=success_msg)
|
||||
|
||||
def undo_edit(self, path: Path):
|
||||
"""Implement the undo_edit command."""
|
||||
if not self._file_history[path]:
|
||||
raise ToolError(f"No edit history found for {path}.")
|
||||
|
||||
old_text = self._file_history[path].pop()
|
||||
self.write_file(path, old_text)
|
||||
|
||||
return CLIResult(
|
||||
output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
|
||||
)
|
||||
|
||||
def read_file(self, path: Path):
|
||||
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
|
||||
try:
|
||||
return path.read_text()
|
||||
except Exception as e:
|
||||
raise ToolError(f"Ran into {e} while trying to read {path}") from None
|
||||
|
||||
def write_file(self, path: Path, file: str):
|
||||
"""Write the content of a file to a given path; raise a ToolError if an error occurs."""
|
||||
try:
|
||||
path.write_text(file)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Ran into {e} while trying to write to {path}") from None
|
||||
|
||||
def _make_output(
|
||||
self,
|
||||
file_content: str,
|
||||
file_descriptor: str,
|
||||
init_line: int = 1,
|
||||
expand_tabs: bool = True,
|
||||
):
|
||||
"""Generate output for the CLI based on the content of a file."""
|
||||
file_content = maybe_truncate(file_content)
|
||||
if expand_tabs:
|
||||
file_content = file_content.expandtabs()
|
||||
file_content = "\n".join(
|
||||
[
|
||||
f"{i + init_line:6}\t{line}"
|
||||
for i, line in enumerate(file_content.split("\n"))
|
||||
]
|
||||
)
|
||||
return (
|
||||
f"Here's the result of running `cat -n` on {file_descriptor}:\n"
|
||||
+ file_content
|
||||
+ "\n"
|
||||
)
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
"""Utility to run shell commands asynchronously with a timeout."""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
||||
MAX_RESPONSE_LEN: int = 16000
|
||||
|
||||
|
||||
def maybe_truncate(content: str, truncate_after: Optional[int] = MAX_RESPONSE_LEN):
|
||||
"""Truncate content and append a notice if content exceeds the specified length."""
|
||||
return (
|
||||
content
|
||||
if not truncate_after or len(content) <= truncate_after
|
||||
else content[:truncate_after] + TRUNCATED_MESSAGE
|
||||
)
|
||||
|
||||
|
||||
async def run(
|
||||
cmd: str,
|
||||
timeout: Optional[float] = 120.0, # seconds
|
||||
truncate_after: Optional[int] = MAX_RESPONSE_LEN,
|
||||
):
|
||||
"""Run a shell command asynchronously with a timeout."""
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
return (
|
||||
process.returncode or 0,
|
||||
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
|
||||
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
raise TimeoutError(
|
||||
f"Command '{cmd}' timed out after {timeout} seconds"
|
||||
) from exc
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
"""
|
||||
Utility functions for the Anthropic API.
|
||||
"""
|
||||
from typing import List, Union, cast
|
||||
from enum import Enum
|
||||
from anthropic import (
|
||||
Anthropic,
|
||||
AnthropicBedrock,
|
||||
AnthropicVertex,
|
||||
APIError,
|
||||
APIResponseValidationError,
|
||||
APIStatusError,
|
||||
)
|
||||
from anthropic.types.beta import (
|
||||
BetaCacheControlEphemeralParam,
|
||||
BetaContentBlockParam,
|
||||
BetaImageBlockParam,
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlock,
|
||||
BetaTextBlockParam,
|
||||
BetaToolResultBlockParam,
|
||||
BetaToolUseBlockParam,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
from .tools import ToolResult
|
||||
|
||||
|
||||
COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
|
||||
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
|
||||
|
||||
|
||||
class APIProvider(Enum):
|
||||
ANTHROPIC = "anthropic"
|
||||
BEDROCK = "bedrock"
|
||||
VERTEX = "vertex"
|
||||
|
||||
|
||||
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[(APIProvider, str), str] = {
|
||||
(APIProvider.ANTHROPIC, "claude-3-5-sonnet-20241022"): "claude-3-5-sonnet-20241022",
|
||||
(APIProvider.BEDROCK, "claude-3-5-sonnet-20241022"): "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
(APIProvider.VERTEX, "claude-3-5-sonnet-20241022"): "claude-3-5-sonnet-v1@20241022",
|
||||
(APIProvider.ANTHROPIC, "claude-3-7-sonnet-20250219"): "claude-3-7-sonnet-20250219",
|
||||
(APIProvider.BEDROCK, "claude-3-7-sonnet-20250219"): "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
(APIProvider.VERTEX, "claude-3-7-sonnet-20250219"): "claude-3-7-sonnet-v1@20250219",
|
||||
(APIProvider.ANTHROPIC, "claude-4-opus-20250514"): "claude-4-opus-20250514",
|
||||
(APIProvider.BEDROCK, "claude-4-opus-20250514"): "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||
(APIProvider.VERTEX, "claude-4-opus-20250514"): "claude-4-opus-v1@20250514",
|
||||
(APIProvider.ANTHROPIC, "claude-4-sonnet-20250514"): "claude-4-sonnet-20250514",
|
||||
(APIProvider.BEDROCK, "claude-4-sonnet-20250514"): "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
(APIProvider.VERTEX, "claude-4-sonnet-20250514"): "claude-sonnet-4-v1@20250514",
|
||||
}
|
||||
|
||||
|
||||
# This system prompt is optimized for the Docker environment in this repository and
|
||||
# specific tool combinations enabled.
|
||||
# We encourage modifying this system prompt to ensure the model has context for the
|
||||
# environment it is running in, and to provide any additional information that may be
|
||||
# helpful for the task at hand.
|
||||
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
|
||||
* You are utilising an Ubuntu virtual machine using x86_64 architecture with internet access.
|
||||
* You can feel free to install Ubuntu applications with your bash tool. Use curl instead of wget.
|
||||
* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system.
|
||||
* Using bash tool you can start GUI applications, but you need to set export DISPLAY=:1 and use a subshell. For example "(DISPLAY=:1 xterm &)". GUI apps run with bash tool will appear within your desktop environment, but they may take some time to appear. Take a screenshot to confirm it did.
|
||||
* When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B <lines before> -A <lines after> <query> <filename>` to confirm output.
|
||||
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
|
||||
* Home directory of this Ubuntu system is '/home/user'.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<IMPORTANT>
|
||||
* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool.
|
||||
</IMPORTANT>"""
|
||||
|
||||
SYSTEM_PROMPT_WINDOWS = f"""<SYSTEM_CAPABILITY>
|
||||
* You are utilising a Windows virtual machine using x86_64 architecture with internet access.
|
||||
* To open browser, please just click on the Chrome icon. Note, Chrome is what is installed on your system.
|
||||
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
|
||||
* Home directory of this Windows system is 'C:\\Users\\user'.
|
||||
* When you want to open some applications on Windows, please use Double Click on it instead of clicking once.
|
||||
</SYSTEM_CAPABILITY>"""
|
||||
|
||||
|
||||
|
||||
def _make_api_tool_result(
|
||||
result: ToolResult, tool_use_id: str
|
||||
) -> BetaToolResultBlockParam:
|
||||
"""Convert an agent ToolResult to an API ToolResultBlockParam."""
|
||||
tool_result_content: Union[List[Union[BetaTextBlockParam,
|
||||
BetaImageBlockParam]], str] = []
|
||||
is_error = False
|
||||
|
||||
if not result or (result.get('error') is not None and result.get('error') != ""):
|
||||
is_error = True
|
||||
error_message = str(result.get('error', 'Unknown error occurred')) if result else 'No result received'
|
||||
tool_result_content = [{
|
||||
"type": "text",
|
||||
"text": _maybe_prepend_system_tool_result(result, error_message)
|
||||
}]
|
||||
|
||||
else:
|
||||
if result.get('output'):
|
||||
tool_result_content.append({
|
||||
"type": "text",
|
||||
"text": _maybe_prepend_system_tool_result(
|
||||
result,
|
||||
str(result.get('output', '')
|
||||
if result else '')
|
||||
),
|
||||
})
|
||||
|
||||
if result.get('base64_image'):
|
||||
tool_result_content.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": result.get('base64_image', ''),
|
||||
},
|
||||
})
|
||||
|
||||
if not tool_result_content:
|
||||
tool_result_content.append({
|
||||
"type": "text",
|
||||
"text": "Action completed successfully"
|
||||
})
|
||||
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"content": tool_result_content,
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": is_error,
|
||||
}
|
||||
|
||||
def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
|
||||
if not result:
|
||||
return result_text
|
||||
|
||||
if result.get('system', False):
|
||||
result_text = f"<system>{result.get('system','')}</system>\n{result_text}"
|
||||
return result_text
|
||||
|
||||
|
||||
|
||||
def _inject_prompt_caching(
|
||||
messages: list[BetaMessageParam],
|
||||
):
|
||||
"""
|
||||
Set cache breakpoints for the 3 most recent turns
|
||||
one cache breakpoint is left for tools/system prompt, to be shared across sessions
|
||||
"""
|
||||
|
||||
breakpoints_remaining = 3
|
||||
for message in reversed(messages):
|
||||
if message["role"] == "user" and isinstance(
|
||||
content := message["content"], list
|
||||
):
|
||||
if breakpoints_remaining:
|
||||
breakpoints_remaining -= 1
|
||||
# Use type ignore to bypass TypedDict check until SDK types are updated
|
||||
content[-1]["cache_control"] = BetaCacheControlEphemeralParam( # type: ignore
|
||||
{"type": "ephemeral"}
|
||||
)
|
||||
else:
|
||||
content[-1].pop("cache_control", None)
|
||||
# we'll only every have one extra turn per loop
|
||||
break
|
||||
|
||||
|
||||
def _maybe_filter_to_n_most_recent_images(
|
||||
messages: list[BetaMessageParam],
|
||||
images_to_keep: int,
|
||||
min_removal_threshold: int,
|
||||
):
|
||||
"""
|
||||
With the assumption that images are screenshots that are of diminishing value as
|
||||
the conversation progresses, remove all but the final `images_to_keep` tool_result
|
||||
images in place, with a chunk of min_removal_threshold to reduce the amount we
|
||||
break the implicit prompt cache.
|
||||
"""
|
||||
if images_to_keep is None:
|
||||
return messages
|
||||
|
||||
tool_result_blocks = cast(
|
||||
list[BetaToolResultBlockParam],
|
||||
[
|
||||
item
|
||||
for message in messages
|
||||
for item in (
|
||||
message["content"] if isinstance(message["content"], list) else []
|
||||
)
|
||||
if isinstance(item, dict) and item.get("type") == "tool_result"
|
||||
],
|
||||
)
|
||||
|
||||
total_images = sum(
|
||||
1
|
||||
for tool_result in tool_result_blocks
|
||||
for content in tool_result.get("content", [])
|
||||
if isinstance(content, dict) and content.get("type") == "image"
|
||||
)
|
||||
|
||||
images_to_remove = total_images - images_to_keep
|
||||
# for better cache behavior, we want to remove in chunks
|
||||
images_to_remove -= images_to_remove % min_removal_threshold
|
||||
|
||||
for tool_result in tool_result_blocks:
|
||||
if isinstance(tool_result.get("content"), list):
|
||||
new_content = []
|
||||
for content in tool_result.get("content", []):
|
||||
if isinstance(content, dict) and content.get("type") == "image":
|
||||
if images_to_remove > 0:
|
||||
images_to_remove -= 1
|
||||
continue
|
||||
new_content.append(content)
|
||||
tool_result["content"] = new_content
|
||||
|
||||
|
||||
def _response_to_params(
|
||||
response: BetaMessage,
|
||||
) -> list[BetaContentBlockParam]:
|
||||
res: list[BetaContentBlockParam] = []
|
||||
if response.content:
|
||||
for block in response.content:
|
||||
if isinstance(block, BetaTextBlock):
|
||||
if block.text:
|
||||
res.append(BetaTextBlockParam(type="text", text=block.text))
|
||||
elif getattr(block, "type", None) == "thinking":
|
||||
# Handle thinking blocks - include signature field
|
||||
thinking_block = {
|
||||
"type": "thinking",
|
||||
"thinking": getattr(block, "thinking", None),
|
||||
}
|
||||
if hasattr(block, "signature"):
|
||||
thinking_block["signature"] = getattr(block, "signature", None)
|
||||
res.append(cast(BetaContentBlockParam, thinking_block))
|
||||
else:
|
||||
# Handle tool use blocks normally
|
||||
res.append(cast(BetaToolUseBlockParam, block.model_dump()))
|
||||
return res
|
||||
else:
|
||||
return []
|
||||
|
|
@ -0,0 +1,384 @@
|
|||
"""Script to run end-to-end evaluation on the benchmark.
|
||||
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Process, Manager
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.anthropic import AnthropicAgent as PromptAgent
|
||||
|
||||
# import fake_run_single as lib_run_single
|
||||
# from test_env import DesktopEnv
|
||||
|
||||
# .env
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Logger Configs {{{ #
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
debug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
sdebug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
|
||||
file_handler.setLevel(logging.INFO)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setLevel(logging.INFO)
|
||||
sdebug_handler.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
debug_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
sdebug_handler.setFormatter(formatter)
|
||||
|
||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||
sdebug_handler.addFilter(logging.Filter("desktopenv"))
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.addHandler(stdout_handler)
|
||||
logger.addHandler(sdebug_handler)
|
||||
# }}} Logger Configs #
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation on the benchmark"
|
||||
)
|
||||
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action_space", type=str, default="claude_computer_use", help="Action type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--observation_type",
|
||||
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
default="a11y_tree",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--screen_width", type=int, default=1920)
|
||||
parser.add_argument("--screen_height", type=int, default=1080)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
||||
parser.add_argument("--max_steps", type=int, default=15)
|
||||
|
||||
# agent config
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=3)
|
||||
parser.add_argument(
|
||||
"--test_config_base_dir", type=str, default="evaluation_examples"
|
||||
)
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model", type=str, default="claude-4-sonnet-20250514")
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--top_p", type=float, default=0.9)
|
||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
||||
parser.add_argument("--stop_token", type=str, default=None)
|
||||
|
||||
# example config
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
parser.add_argument(
|
||||
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
|
||||
)
|
||||
|
||||
# logging related
|
||||
parser.add_argument("--result_dir", type=str, default="./results")
|
||||
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
||||
|
||||
# aws config
|
||||
parser.add_argument(
|
||||
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
|
||||
"""Distribute tasks evenly across environments."""
|
||||
# Flatten the tasks into a single list
|
||||
all_tasks = []
|
||||
for domain, examples in test_all_meta.items():
|
||||
for example_id in examples:
|
||||
all_tasks.append((domain, example_id))
|
||||
|
||||
# Calculate tasks per environment
|
||||
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
|
||||
|
||||
# Distribute tasks
|
||||
distributed_tasks = []
|
||||
for i in range(num_envs):
|
||||
env_tasks = {}
|
||||
start_idx = i * tasks_per_env
|
||||
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
|
||||
|
||||
for domain, example_id in all_tasks[start_idx:end_idx]:
|
||||
if domain not in env_tasks:
|
||||
env_tasks[domain] = []
|
||||
env_tasks[domain].append(example_id)
|
||||
|
||||
distributed_tasks.append(env_tasks)
|
||||
|
||||
return distributed_tasks
|
||||
|
||||
|
||||
|
||||
def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
|
||||
"""Run tasks for a single environment."""
|
||||
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
|
||||
|
||||
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
|
||||
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
|
||||
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
|
||||
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
|
||||
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
|
||||
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
lib_run_single.run_single_example(
|
||||
agent,
|
||||
env,
|
||||
example,
|
||||
args.max_steps,
|
||||
example["instruction"],
|
||||
args,
|
||||
example_result_dir,
|
||||
shared_scores,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
# logger traceback
|
||||
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
logger.info("Args: %s", args)
|
||||
|
||||
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
|
||||
|
||||
# First, set up all environments
|
||||
logger.info("Setting up all environments...")
|
||||
envs = []
|
||||
agents = []
|
||||
|
||||
for env_idx in range(args.num_envs):
|
||||
logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
|
||||
|
||||
agent = PromptAgent(
|
||||
model=args.model,
|
||||
max_tokens=args.max_tokens,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
action_space=args.action_space,
|
||||
observation_type=args.observation_type,
|
||||
max_trajectory_length=args.max_trajectory_length,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
)
|
||||
agents.append(agent)
|
||||
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
REGION = "us-east-1"
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=agent.action_space,
|
||||
|
||||
provider_name="aws",
|
||||
region="us-east-1",
|
||||
snapshot_name=IMAGE_ID_MAP[REGION],
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
)
|
||||
envs.append(env)
|
||||
|
||||
logger.info("All environments are ready. Starting parallel task execution...")
|
||||
|
||||
# Create a shared list for scores across processes
|
||||
with Manager() as manager:
|
||||
shared_scores = manager.list()
|
||||
|
||||
# Create and start processes for each environment
|
||||
processes = []
|
||||
for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
|
||||
p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(env_idx, env, agent, env_tasks, args, shared_scores)
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
# Wait for all processes to complete
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# Convert shared list to regular list
|
||||
scores = list(shared_scores)
|
||||
|
||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||
|
||||
|
||||
def get_unfinished(
|
||||
action_space, use_model, observation_type, result_dir, total_file_json
|
||||
):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
|
||||
if not os.path.exists(target_dir):
|
||||
return total_file_json
|
||||
|
||||
finished = {}
|
||||
for domain in os.listdir(target_dir):
|
||||
finished[domain] = []
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
if example_id == "onboard":
|
||||
continue
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" not in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
for file in os.listdir(example_path):
|
||||
os.remove(os.path.join(example_path, file))
|
||||
else:
|
||||
finished[domain].append(example_id)
|
||||
|
||||
if not finished:
|
||||
return total_file_json
|
||||
|
||||
for domain, examples in finished.items():
|
||||
if domain in total_file_json:
|
||||
total_file_json[domain] = [
|
||||
x for x in total_file_json[domain] if x not in examples
|
||||
]
|
||||
|
||||
return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
if not os.path.exists(target_dir):
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
|
||||
all_result = []
|
||||
|
||||
for domain in os.listdir(target_dir):
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
try:
|
||||
all_result.append(
|
||||
float(
|
||||
open(
|
||||
os.path.join(example_path, "result.txt"), "r"
|
||||
).read()
|
||||
)
|
||||
)
|
||||
except:
|
||||
all_result.append(0.0)
|
||||
|
||||
if not all_result:
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
else:
|
||||
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
|
||||
return all_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
args = config()
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
|
||||
|
||||
# path_to_vm can be a list["xxx","xxx"]
|
||||
Loading…
Reference in New Issue