Compare commits

27 Commits — `djlu/qwen3` ... `main`

| SHA1 |
|---|
| 5ef8bdfa35 |
| 439e178a2e |
| 951e1928c8 |
| 02a35be067 |
| 662826f57e |
| 410ec63a89 |
| 031696e83c |
| f593f35b1c |
| ac31778ee3 |
| 60caa52fc4 |
| 41477a9c40 |
| 78433ecfcf |
| 9540454b0a |
| cbc3b590ff |
| 903ed36715 |
| 3167339e45 |
| 00b6468eb7 |
| 6d43dbc532 |
| 8365edc975 |
| 21c2b7629b |
| 3bf54c92a9 |
| a484f2e484 |
| 9f97535ef9 |
| afd29115da |
| 55372c4432 |
| d25464c203 |
| f9e9273b3b |
```diff
@@ -204,4 +204,11 @@ reference/
 draft/
 manual_examine.py
 run_human_examine.sh
 quick_start.py
+result_multi_apps_pengxiang_transformers12
+evaluation_examples/settings/proxy/dataimpulse.json
+
+# Local test configurations (not for public repo)
+evaluation_examples/spiderman.json
+evaluation_examples/test_50_random_proportional.json
+evaluation_examples/test_chrome.json
```
```diff
@@ -228,3 +228,7 @@ Special thanks to the following institutions that provided feedback and particip
 Special thanks to the following students who participated in the specific fixes: [Mengqi Yuan](https://yuanmengqi.github.io/), [Danyang Zhang](https://zdy023.github.io/), [Xinzhuang Xiong](https://thisisxxz.com/), [Zhennan Shen](https://scholar.google.com/citations?user=JPwg5MwAAAAJ&hl=en), [Zilong Zhou](https://github.com/adlsdztony), Yanxu Chen, [Jiaqi Deng](https://millank0817.github.io/), [Tianbao Xie](https://tianbaoxie.com/), Junda Chen, [Jixuan Chen](https://chenjix.github.io/), [Haoyuan Wu](https://www.linkedin.com/in/haoyuan-wu-240878291/).
+
+Special thanks to the following students who participated in running the re-evaluation: [Mengqi Yuan](https://yuanmengqi.github.io/), [Zilong Zhou](https://github.com/adlsdztony), [Xinyuan Wang](https://xinyuanwangcs.github.io/), [Bowen Wang](https://bowenbryanwang.github.io/).
+
+## You might also be interested
+
+- **OSWorld-MCP**: Benchmarking MCP Tool Invocation in Computer-Use Agents. [Website](https://osworld-mcp.github.io/)
```
```diff
@@ -238,12 +238,17 @@ class PythonController:
             "returncode": -1
         }

-    def execute_action(self, action: Dict[str, Any]):
+    def execute_action(self, action):
        """
        Executes an action on the server computer.
        """
+        # Handle string actions
+        if action in ['WAIT', 'FAIL', 'DONE']:
+            return
+
+        # Handle dictionary actions
+        if type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']:
+            return
+
         action_type = action["action_type"]
         parameters = action["parameters"] if "parameters" in action else {param: action[param] for param in action if param != 'action_type'}
```
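For reference, a standalone sketch of the dispatch logic this hunk adds; the function below mirrors the diff but is simplified for illustration, and the sample calls at the bottom are hypothetical:

```python
from typing import Any, Dict, Union

def execute_action(action: Union[str, Dict[str, Any]]) -> None:
    """Standalone sketch of the special-action dispatch added in the diff above."""
    # Handle string actions
    if action in ['WAIT', 'FAIL', 'DONE']:
        return
    # Handle dictionary actions
    if type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']:
        return
    action_type = action["action_type"]
    parameters = action["parameters"] if "parameters" in action else {p: action[p] for p in action if p != "action_type"}
    print(action_type, parameters)

execute_action("WAIT")                                       # string form: no-op
execute_action({"action_type": "DONE"})                      # dict form: no-op
execute_action({"action_type": "click", "x": 10, "y": 20})   # -> click {'x': 10, 'y': 20}
```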
```diff
@@ -391,12 +391,12 @@ class DesktopEnv(gym.Env):
         logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
         # handle the special actions
         if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
-            if action == 'WAIT':
+            if action == 'WAIT' or (type(action) == dict and action.get('action_type') == 'WAIT'):
                 time.sleep(pause)
-            elif action == 'FAIL':
+            elif action == 'FAIL' or (type(action) == dict and action.get('action_type') == 'FAIL'):
                 done = True
                 info = {"fail": True}
-            elif action == 'DONE':
+            elif action == 'DONE' or (type(action) == dict and action.get('action_type') == 'DONE'):
                 done = True
                 info = {"done": True}
```
```diff
@@ -404,7 +404,7 @@ class DesktopEnv(gym.Env):
             # the set of all possible actions defined in the action representation
             self.controller.execute_action(action)
         elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
-            if action in ['WAIT', 'FAIL', 'DONE']:
+            if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']):
                 self.controller.execute_action(action)
             else:
                 # the set of all possible python commands inside `pyautogui`
```
```diff
@@ -434,13 +434,16 @@ class DesktopEnv(gym.Env):
             self.is_environment_used = True

         if self.evaluator['func'] == "infeasible":
-            if len(self.action_history) > 0 and self.action_history[-1] == "FAIL":
-                return 1
-            else:
-                return 0
+            if len(self.action_history) > 0:
+                last_action = self.action_history[-1]
+                if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
+                    return 1
+            return 0
         else:
-            if len(self.action_history) > 0 and self.action_history[-1] == "FAIL":
-                return 0
+            if len(self.action_history) > 0:
+                last_action = self.action_history[-1]
+                if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
+                    return 0

         if type(self.metric) == list:
             # Multiple metrics to evaluate whether the task is successfully completed
```
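The same string-or-dict `FAIL` check now repeats in `step`, `evaluate`, and the controller. A small helper along these lines could factor it out; this is a sketch, not part of the diff:

```python
from typing import Union

def _is_special(action: Union[str, dict], kind: str) -> bool:
    """Hypothetical helper: True if `action` is the special action `kind`,
    in either its string form ('FAIL') or dict form ({'action_type': 'FAIL'})."""
    return action == kind or (isinstance(action, dict) and action.get('action_type') == kind)

assert _is_special("FAIL", "FAIL")
assert _is_special({"action_type": "FAIL"}, "FAIL")
assert not _is_special({"action_type": "DONE"}, "FAIL")
```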
```diff
@@ -0,0 +1,499 @@
+from __future__ import annotations
+
+import logging
+import os
+import time
+import re
+from typing import Callable, Any, Optional, Tuple
+from typing import List, Dict, Union
+
+import gymnasium as gym
+
+from desktop_env.controllers.python import PythonController
+from desktop_env.controllers.setup import SetupController
+from desktop_env.evaluators import metrics, getters
+from desktop_env.providers import create_vm_manager_and_provider
+
+logger = logging.getLogger("desktopenv.env")
+
+Metric = Callable[[Any, Any], float]
+Getter = Callable[[gym.Env, Dict[str, Any]], Any]
+
+MAX_RETRIES = 5  # Maximum retries for environment setup
+
+
+def _fix_pyautogui_less_than_bug(command: str) -> str:
+    """
+    Fix PyAutoGUI '<' character bug by converting it to hotkey("shift", ',') calls.
+
+    This fixes the known PyAutoGUI issue where typing '<' produces '>' instead.
+    References:
+    - https://github.com/asweigart/pyautogui/issues/198
+    - https://github.com/xlang-ai/OSWorld/issues/257
+
+    Args:
+        command (str): The original pyautogui command
+
+    Returns:
+        str: The fixed command with '<' characters handled properly
+    """
+    # Pattern to match press('<') or press('\u003c') calls
+    press_pattern = r'pyautogui\.press\(["\'](?:<|\\u003c)["\']\)'
+
+    # Handle press('<') calls
+    def replace_press_less_than(match):
+        return 'pyautogui.hotkey("shift", ",")'
+
+    # First handle press('<') calls
+    command = re.sub(press_pattern, replace_press_less_than, command)
+
+    # Pattern to match typewrite calls with quoted strings
+    typewrite_pattern = r'pyautogui\.typewrite\((["\'])(.*?)\1\)'
+
+    # Then handle typewrite calls
+    def process_typewrite_match(match):
+        quote_char = match.group(1)
+        content = match.group(2)
+
+        # Preprocess: Try to decode Unicode escapes like \u003c to an actual '<'
+        # This handles cases where '<' is represented as escaped Unicode
+        try:
+            # Attempt to decode unicode escapes
+            decoded_content = content.encode('utf-8').decode('unicode_escape')
+            content = decoded_content
+        except UnicodeDecodeError:
+            # Graceful degradation: fall back to the original content if decoding fails
+            pass
+
+        # Check if content contains '<'
+        if '<' not in content:
+            return match.group(0)
+
+        # Split by '<' and rebuild
+        parts = content.split('<')
+        result_parts = []
+
+        for i, part in enumerate(parts):
+            if i == 0:
+                # First part
+                if part:
+                    result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")
+            else:
+                # Add hotkey for '<' and then typewrite for the rest
+                result_parts.append('pyautogui.hotkey("shift", ",")')
+                if part:
+                    result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")
+
+        return '; '.join(result_parts)
+
+    command = re.sub(typewrite_pattern, process_typewrite_match, command)
+
+    return command
+
+
+class DesktopEnv(gym.Env):
+    """
+    DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting up and evaluating desktop automation tasks.
+    """
+
+    def __init__(
+            self,
+            provider_name: str = "vmware",
+            region: str = None,
+            path_to_vm: str = None,
+            snapshot_name: str = "init_state",
+            action_space: str = "pyautogui",
+            cache_dir: str = "cache",
+            screen_size: Tuple[int] = (int(os.environ.get("SCREEN_WIDTH", 1920)), int(os.environ.get("SCREEN_HEIGHT", 1080))),
+            headless: bool = False,
+            require_a11y_tree: bool = True,
+            require_terminal: bool = False,
+            os_type: str = "Ubuntu",
+            enable_proxy: bool = False,
+            client_password: str = "",
+    ):
+        """
+        Args:
+            provider_name (str): virtualization provider name, defaults to "vmware"
+            region (str): the region for allocating machines, used by cloud services, defaults to "us-east-1"
+            path_to_vm (str): path to the .vmx file
+            snapshot_name (str): snapshot name to revert to, defaults to "init_state"
+            action_space (str): "computer_13" | "pyautogui"
+            cache_dir (str): cache directory for task-related files such as the
+                reference file for evaluation
+            screen_size (Tuple[int]): screen size of the VM
+            headless (bool): whether to run the VM in headless mode
+            require_a11y_tree (bool): whether to require the accessibility tree
+            require_terminal (bool): whether to require terminal output
+            os_type (str): operating system type, defaults to "Ubuntu"
+            enable_proxy (bool): whether to enable proxy support, defaults to False
+        """
+        # Initialize VM manager and virtualization provider
+        self.region = region
+        self.provider_name = provider_name
+        self.enable_proxy = enable_proxy  # Store proxy enablement setting
+        if client_password == "":
+            if self.provider_name == "aws":
+                self.client_password = "osworld-public-evaluation"
+            else:
+                self.client_password = "password"
+        else:
+            self.client_password = client_password
+
+        self.screen_width = screen_size[0]
+        self.screen_height = screen_size[1]
+
+        # Defaults
+        self.server_port = 5000
+        self.chromium_port = 9222
+        self.vnc_port = 8006
+        self.vlc_port = 8080
+
+        # Initialize with the default (no proxy) provider
+        self.current_use_proxy = False
+        self.manager, self.provider = None, None
+        self.os_type = os_type
+        self.path_to_vm = path_to_vm
+        # Track whether the environment has been used (step/setup) to optimize snapshot reverts
+        # docker, aws, gcp, azure are always unused as the emulator starts from a clean state
+        # vmware, virtualbox are always used as the emulator starts from a dirty state
+        if self.provider_name in {"docker", "aws", "gcp", "azure", "aliyun", "volcengine"}:
+            self.is_environment_used = False
+        elif self.provider_name in {"vmware", "virtualbox"}:
+            self.is_environment_used = True
+        else:
+            raise ValueError(f"Invalid provider name: {self.provider_name}")
+
+        self.snapshot_name = snapshot_name
+        self.cache_dir_base: str = cache_dir
+        self.headless = headless
+        self.require_a11y_tree = require_a11y_tree
+        self.require_terminal = require_terminal
+
+        # mode: human or machine
+        self.instruction = None
+        assert action_space in ["computer_13", "pyautogui", "claude_computer_use", "autoglm_computer_use"]
+        self.action_space = action_space  # todo: refactor it to the ActType
+
+        # episodic state, like counters, will be updated or reset
+        # when calling self.reset()
+        self._traj_no: int = -1
+        self._step_no: int = 0
+        self.action_history: List[Dict[str, Any]] = []
+
+    def start(self):
+        # Initialize emulator and controller
+        if not self.manager and not self.provider:
+            logger.info("Initializing...")
+            self.manager, self.provider = create_vm_manager_and_provider(self.provider_name, self.region, use_proxy=False)
+
+        if self.path_to_vm:
+            self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(self.path_to_vm))) \
+                if self.provider_name in {"vmware", "virtualbox"} else self.path_to_vm
+        else:
+            self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=self.region, screen_size=(self.screen_width, self.screen_height))
+
+        self._start_emulator()
+
+    def _start_emulator(self):
+        try:
+            # Power on the virtual machine
+            self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
+
+            # Get the ip from the virtual machine, and set up the controller
+            vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
+            self.vm_ip = vm_ip_ports[0]
+            # Get the ports from the virtual machine (for the Docker provider only)
+            if len(vm_ip_ports) > 1:
+                self.server_port = int(vm_ip_ports[1])
+                self.chromium_port = int(vm_ip_ports[2])
+                self.vnc_port = int(vm_ip_ports[3])
+                self.vlc_port = int(vm_ip_ports[4])
+            self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
+            self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base, client_password=self.client_password, screen_width=self.screen_width, screen_height=self.screen_height)
+
+        except Exception as e:
+            try:
+                self.provider.stop_emulator(self.path_to_vm)
+            except Exception as stop_err:
+                logger.warning(f"Cleanup after interrupt failed: {stop_err}")
+            raise
+
+    def _revert_to_snapshot(self):
+        # Revert to a certain snapshot of the virtual machine, and refresh the path and ip of the vm,
+        # since both could change when implemented by cloud services
+        path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
+        if path_to_vm and not path_to_vm == self.path_to_vm:
+            # path_to_vm has to be a new path
+            self.manager.delete_vm(self.path_to_vm, self.region)
+            self.manager.add_vm(path_to_vm, self.region)
+            self.manager.occupy_vm(path_to_vm, os.getpid(), self.region)
+            self.path_to_vm = path_to_vm
+
+    def _save_state(self, snapshot_name=None):
+        # Save the current virtual machine state under a certain snapshot name
+        self.provider.save_state(self.path_to_vm, snapshot_name)
+
+    def close(self):
+        # Close (release) the virtual machine
+        self.provider.stop_emulator(self.path_to_vm)
+
+    def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
+
+        # Reset to a certain task in OSWorld
+        logger.info("Resetting environment...")
+        logger.info("Switching task...")
+        logger.info("Setting counters...")
+        self._traj_no += 1
+        self._step_no = 0
+        self.action_history.clear()
+
+        for attempt in range(MAX_RETRIES):
+            # Only revert to snapshot if the environment has been used (step/setup)
+            # This optimization is especially important for cloud providers like AWS
+            # where unnecessary snapshot operations are costly and time-consuming
+
+            if task_config is not None:
+                # Only consider the task proxy requirement if proxy is enabled at the system level
+                task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
+                if not self.enable_proxy and task_config.get("proxy", False):
+                    logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")
+
+                if task_use_proxy != self.current_use_proxy:
+                    # keep because get_info_from_website depends on this
+                    self.current_use_proxy = task_use_proxy
+
+            if self.is_environment_used:
+                logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
+                self._revert_to_snapshot()
+                logger.info("Starting emulator...")
+                self._start_emulator()
+                logger.info("Emulator started.")
+                # Reset the usage flag after reverting
+                self.is_environment_used = False
+            else:
+                logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
+
+            if task_config is not None:
+                if task_config.get("proxy", False) and self.enable_proxy:
+                    # If using proxy and proxy is enabled, set up the proxy configuration
+                    self.setup_controller._proxy_setup(self.client_password)
+                self._set_task_info(task_config)
+                self.setup_controller.reset_cache_dir(self.cache_dir)
+                logger.info("Setting up environment...")
+                success = self.setup_controller.setup(self.config, task_config.get("proxy", False) and self.enable_proxy)
+                if success:
+                    # Mark environment as used when setup is successfully executed
+                    if self.config:  # Only mark as used if there were actual setup operations
+                        self.is_environment_used = True
+                    break
+                else:
+                    logger.error(
+                        "Environment setup failed, retrying (%d/%d)...",
+                        attempt + 1,
+                        MAX_RETRIES,
+                    )
+                    time.sleep(5)
+            else:
+                break
+
+        logger.info("Environment setup complete.")
+
+        observation = self._get_obs()
+        return observation
+
+    def _get_obs(self):
+        # We provide screenshot, accessibility_tree (optional), terminal (optional), and instruction.
+        # Can be customized and scaled.
+        return {
+            "screenshot": self.controller.get_screenshot(),
+            "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
+            "terminal": self.controller.get_terminal_output() if self.require_terminal else None,
+            "instruction": self.instruction
+        }
+
+    @property
+    def vm_platform(self):
+        return self.controller.get_vm_platform()
+
+    @property
+    def vm_screen_size(self):
+        return self.controller.get_vm_screen_size()
+
+    def _set_task_info(self, task_config: Dict[str, Any]):
+        """Set task info (proxy logic is handled in the reset method)"""
+        self.task_id: str = task_config["id"]
+        self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
+        os.makedirs(self.cache_dir, exist_ok=True)
+        self.instruction = task_config["instruction"]
+        self.config = task_config["config"] if "config" in task_config else []
+
+        self._set_evaluator_info(task_config)
+
+    def _set_evaluator_info(self, task_config: Dict[str, Any]):
+        """Set evaluator information from the task config"""
+        if "evaluator" not in task_config:
+            return
+        # evaluator dict
+        # func -> metric function string, or list of metric function strings
+        # conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
+        # result -> result getter config, or list of result getter configs
+        # expected (optional) -> expected getter config, or list of expected getter configs
+        # options (optional) -> metric options, or list of metric options
+        # if func is a str list, then result, expected (if present), and options (if present) should also be lists of the same length
+        # even if one of the metrics does not need an expected or options field, it should be included in the list as None
+        self.evaluator = task_config["evaluator"]
+        self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] \
+            if isinstance(self.evaluator["func"], list) \
+            else getattr(metrics, self.evaluator["func"])
+        self.metric_conj: str = self.evaluator.get("conj", "and")  # take conjunction of multiple metrics
+        if "result" in self.evaluator and len(self.evaluator["result"]) > 0:
+            self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
+                                          self.evaluator["result"]] \
+                if isinstance(self.evaluator["result"], list) \
+                else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
+        else:
+            self.result_getter = [None] * len(self.metric) \
+                if isinstance(self.metric, list) \
+                else None
+
+        if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0:
+            self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
+                                            self.evaluator["expected"]] \
+                if isinstance(self.evaluator["expected"], list) \
+                else getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
+        else:
+            self.expected_getter = [None] * len(self.metric) \
+                if isinstance(self.metric, list) \
+                else None
+        self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in
+                                                                            self.evaluator["options"]] \
+            if isinstance(self.evaluator.get("options", {}), list) \
+            else self.evaluator["options"] \
+            if "options" in self.evaluator \
+            else [{}] * len(self.metric) \
+            if isinstance(self.metric, list) \
+            else {}
+
+        assert (not isinstance(self.evaluator["func"], list)
+                or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(
+                    self.metric_options)))
+
+    def step(self, action, pause=2):
+        self._step_no += 1
+        self.action_history.append(action)
+
+        # Mark environment as used when step is called
+        self.is_environment_used = True
+
+        reward = 0  # todo: Define reward calculation for each example
+        done = False  # todo: Define episode termination condition for each example
+        info = {}
+        logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
+        # handle the special actions
+        if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
+            if action == 'WAIT' or (type(action) == dict and action.get('action_type') == 'WAIT'):
+                time.sleep(pause)
+            elif action == 'FAIL' or (type(action) == dict and action.get('action_type') == 'FAIL'):
+                done = True
+                info = {"fail": True}
+            elif action == 'DONE' or (type(action) == dict and action.get('action_type') == 'DONE'):
+                done = True
+                info = {"done": True}
+
+        if self.action_space == "computer_13":
+            # the set of all possible actions defined in the action representation
+            self.controller.execute_action(action)
+        elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
+            if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']):
+                self.controller.execute_action(action)
+            else:
+                # the set of all possible python commands inside `pyautogui`
+                if type(action) == str:
+                    # Fix PyAutoGUI '<' character bug before execution
+                    fixed_command = _fix_pyautogui_less_than_bug(action)
+                    self.controller.execute_python_command(fixed_command)
+                elif type(action) == dict:
+                    # Fix PyAutoGUI '<' character bug before execution
+                    fixed_command = _fix_pyautogui_less_than_bug(action['command'])
+                    self.controller.execute_python_command(fixed_command)
+
+        time.sleep(pause)
+        observation = self._get_obs()
+
+        return observation, reward, done, info
+
+    def evaluate(self):
+        """
+        Evaluate whether the task is successfully completed.
+        """
+
+        postconfig = self.evaluator.get("postconfig", [])
+        self.setup_controller.setup(postconfig, self.enable_proxy)
+        # Mark environment as used if there were postconfig setup operations
+        if postconfig:
+            self.is_environment_used = True
+
+        if self.evaluator['func'] == "infeasible":
+            if len(self.action_history) > 0:
+                last_action = self.action_history[-1]
+                if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
+                    return 1
+            return 0
+        else:
+            if len(self.action_history) > 0:
+                last_action = self.action_history[-1]
+                if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
+                    return 0
+
+        if type(self.metric) == list:
+            # Multiple metrics to evaluate whether the task is successfully completed
+            results = []
+            assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same"
+            if "expected" in self.evaluator:
+                assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same"
+            for idx, metric in enumerate(self.metric):
+                try:
+                    config = self.evaluator["result"][idx]
+                    result_state = self.result_getter[idx](self, config)
+                except FileNotFoundError:
+                    logger.error("File not found!")
+                    if self.metric_conj == 'and':
+                        return 0
+
+                if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
+                    expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx])
+                    metric: int = metric(result_state, expected_state, **self.metric_options[idx])
+                else:
+                    metric: int = metric(result_state, **self.metric_options[idx])
+
+                if self.metric_conj == 'and' and float(metric) == 0.0:
+                    return 0
+                elif self.metric_conj == 'or' and float(metric) == 1.0:
+                    return 1
+                else:
+                    results.append(metric)
+
+            return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
+        else:
+            # Single metric to evaluate whether the task is successfully completed
+            try:
+                result_state = self.result_getter(self, self.evaluator["result"])
+            except FileNotFoundError:
+                logger.error("File not found!")
+                return 0
+
+            if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
+                expected_state = self.expected_getter(self, self.evaluator["expected"])
+                metric: float = self.metric(result_state, expected_state, **self.metric_options)
+            else:
+                metric: float = self.metric(result_state, **self.metric_options)
+
+            return metric
+
+    def render(self, mode='rgb_array'):
+        if mode == 'rgb_array':
+            return self.controller.get_screenshot()
+        else:
+            raise ValueError('Unsupported render mode: {}'.format(mode))
```
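A quick illustration of what `_fix_pyautogui_less_than_bug` above does to affected commands (a sketch, assuming the module above is importable):

```python
# Sketch: the '<' fix rewrites press/typewrite calls into shift-comma hotkeys.
print(_fix_pyautogui_less_than_bug("pyautogui.press('<')"))
# -> pyautogui.hotkey("shift", ",")

print(_fix_pyautogui_less_than_bug("pyautogui.typewrite('a<b')"))
# -> pyautogui.typewrite('a'); pyautogui.hotkey("shift", ","); pyautogui.typewrite('b')
```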
```diff
@@ -827,8 +827,8 @@ def get_active_tab_info(env, config: Dict[str, str]):

     try:
         logger.info(f"[ACTIVE_TAB_INFO] Navigating to URL: {active_tab_url}")
-        page.goto(active_tab_url, wait_until='networkidle', timeout=timeout_ms)
-        page.wait_for_load_state('networkidle', timeout=timeout_ms)  # Wait for the 'load' event to complete
+        page.goto(active_tab_url, wait_until='load', timeout=timeout_ms)
+        page.wait_for_load_state('load', timeout=timeout_ms)  # Wait for the 'load' event to complete

         active_tab_info = {
             'title': page.title(),
```
```diff
@@ -2,6 +2,8 @@ import functools
 import itertools
 import logging
 import os.path
+import re
+import unicodedata

 # import operator
 from numbers import Number
```
```diff
@@ -744,6 +746,18 @@ def compare_table(result: str, expected: str = None, **options) -> float:
 # }}} function compare_table #


+def _normalize_city_string(value: Any) -> str:
+    """Lowercase, strip punctuation, and remove accents for tolerant matching."""
+    if value is None:
+        return ""
+    if not isinstance(value, str):
+        value = str(value)
+    normalized = unicodedata.normalize("NFKD", value)
+    normalized = "".join(ch for ch in normalized if not unicodedata.combining(ch))
+    normalized = re.sub(r"[^a-z0-9]+", " ", normalized.lower())
+    return normalized.strip()
+
+
 def compare_conference_city_in_order(actual_city_list_path, expected_city):
     expected_city_list = expected_city["expected"]
     wb = openpyxl.load_workbook(actual_city_list_path)
```
```diff
@@ -752,38 +766,35 @@ def compare_conference_city_in_order(actual_city_list_path, expected_city):
     for row in sheet["C2:C22"]:
         for cell in row:
             actual_city_list.append(cell.value)
     # expected_city is the city that we want to compare with the actual city list
     # must be in index order
     # debug

     try:
-        for i in range(len(actual_city_list)):
-            if isinstance(expected_city_list[i], str):
-                if expected_city_list[i] not in actual_city_list[i]:
-                    logger.debug(
-                        f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
-                    )
-                    print(
-                        f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
-                    )
-                    return 0.0
-
-            elif isinstance(expected_city_list[i], List):
-                if not any(
-                    possible_str in actual_city_list[i]
-                    for possible_str in expected_city_list[i]
-                ):
-                    logger.debug(
-                        f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
-                    )
-                    print(
-                        f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
-                    )
-                    return 0.0
+        for i, actual_city in enumerate(actual_city_list):
+            actual_normalized = _normalize_city_string(actual_city)
+            expected_entry = expected_city_list[i]
+
+            if isinstance(expected_entry, str):
+                expected_candidates = [expected_entry]
+            elif isinstance(expected_entry, List):
+                expected_candidates = expected_entry
+            else:
+                raise TypeError("Expected city should be a string or a list of strings")
+
+            matched = False
+            for candidate in expected_candidates:
+                normalized_candidate = _normalize_city_string(candidate)
+                if normalized_candidate and normalized_candidate in actual_normalized:
+                    matched = True
+                    break
+
+            if not matched:
+                logger.debug(
+                    f"Expected city {expected_entry}; Actual city {actual_city}"
+                )
+                print(f"Expected city {expected_entry}; Actual city {actual_city}")
+                return 0.0

-    except:
+    except Exception as exc:
+        logger.error(f"Error comparing conference cities: {exc}")
         return 0.0

     return 1.0
```
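A few sanity checks of the tolerant matching added above, on illustrative inputs:

```python
# Sketch of the normalization behavior (values chosen for illustration).
assert _normalize_city_string("São Paulo") == "sao paulo"
assert _normalize_city_string("NEW-YORK, NY") == "new york ny"
assert _normalize_city_string(None) == ""
# "Vancouver" matches "Vancouver, Canada" after normalization:
assert _normalize_city_string("Vancouver") in _normalize_city_string("Vancouver, Canada")
```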
```diff
@@ -10,7 +10,7 @@ from desktop_env.providers.aws.config import ENABLE_TTL, DEFAULT_TTL_MINUTES, AW
 from desktop_env.providers.aws.scheduler_utils import schedule_instance_termination


-INSTANCE_TYPE = "t3.medium"
+INSTANCE_TYPE = "t3.xlarge"

 # Load environment variables from .env file
 dotenv.load_dotenv()
```
```diff
@@ -40,9 +40,9 @@ DEFAULT_REGION = "us-east-1"
 # todo: public the AMI images
 IMAGE_ID_MAP = {
     "us-east-1": {
-        # (1920, 1080): "ami-0d23263edb96951d8"
-        (1920, 1080): "ami-0b505e9d0d99ba88c"
+        (1920, 1080): "ami-0d23263edb96951d8",
+        # For CoACT-1, uncomment to use the following AMI
+        # (1920, 1080): "ami-0b505e9d0d99ba88c"
     },
     "ap-east-1": {
         (1920, 1080): "ami-06850864d18fad836"
```
```diff
@@ -52,7 +52,7 @@
     "type": "rule",
     "rules": {
         "expected": [
-            "united.com/en/us/checked-bag-fee-calculator"
+            "united\\.com/en/us/checked-bag-fee-calculator(/.*)?"
         ]
     }
 }
```
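The expected URL is now treated as a regular expression: the dot is escaped and an optional trailing path is accepted. A quick illustrative check:

```python
import re

rule = r"united\.com/en/us/checked-bag-fee-calculator(/.*)?"
assert re.search(rule, "https://www.united.com/en/us/checked-bag-fee-calculator")
assert re.search(rule, "https://www.united.com/en/us/checked-bag-fee-calculator/results")
# The escaped dot no longer matches look-alike hosts:
assert not re.search(rule, "https://unitedxcom/en/us/checked-bag-fee-calculator")
```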
```diff
@@ -88,7 +88,7 @@
     ],
     "func": "check_image_mirror",
     "expected": {
-        "type": "vm_file",
+        "type": "cloud_file",
         "path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/72f83cdc-bf76-4531-9a1b-eb893a13f8aa/berry.jpeg",
         "dest": "berry.png"
     },
```
```diff
@@ -32,8 +32,8 @@
     "evaluator": {
         "func": "check_file_exists_and_structure_sim",
         "expected": {
-            "type": "vm_file",
-            "path": "/home/user/Desktop/The_Lost_River_Of_Dreams.jpg",
+            "type": "cloud_file",
+            "path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/77b8ab4d-994f-43ac-8930-8ca087d7c4b4/The_Lost_River_Of_Dreams.jpg",
             "dest": "The_Lost_River_Of_Dreams.jpg"
         },
         "result": {
```
```diff
@@ -60,7 +60,7 @@
     "rules": {
         "expected": [
             "Zoom Chrome Extension",
-            "Speechify Text to Speech Voice Reader",
+            "Speechify — Voice AI Assistant",
             "React Developer Tools",
             "Momentum",
             "Google Translate"
```
```diff
@@ -40,8 +40,8 @@
     },
     "result": {
         "type": "vm_file",
-        "path": "/home/user/Recruitment_and_retention_of_health_professionals_across_Europe.zip",
-        "dest": "Recruitment_and_retention_of_health_professionals_across_Europe.zip"
+        "path": "/home/user/essay_submission.zip",
+        "dest": "essay_submission.zip"
     }
 },
 "proxy": false,
```
```diff
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+"""
+Thread-safe results logging for OSWorld evaluations.
+Appends task completion results to results.json in real-time.
+"""
+
+import json
+import os
+import time
+import fcntl
+from pathlib import Path
+from typing import Dict, Any, Optional
+
+
+def extract_domain_from_path(result_path: str) -> str:
+    """
+    Extract the domain/application from a result directory path.
+    Expected structure: results/{action_space}/{observation_type}/{model}/{domain}/{task_id}/
+    """
+    path_parts = Path(result_path).parts
+    if len(path_parts) >= 2:
+        return path_parts[-2]  # The second-to-last part should be the domain
+    return "unknown"
+
+
+def append_task_result(
+    task_id: str,
+    domain: str,
+    score: float,
+    result_dir: str,
+    args: Any,
+    error_message: Optional[str] = None
+) -> None:
+    """
+    Thread-safely append a task result to results.json.
+
+    Args:
+        task_id: UUID of the task
+        domain: Application domain (chrome, vlc, etc.)
+        score: Task score (0.0 or 1.0)
+        result_dir: Full path to the task result directory
+        args: Command line arguments object
+        error_message: Error message if the task failed
+    """
+    # Create result entry
+    result_entry = {
+        "application": domain,
+        "task_id": task_id,
+        "status": "error" if error_message else "success",
+        "score": score,
+        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
+    }
+
+    if error_message:
+        result_entry["err_message"] = error_message
+
+    # Determine summary directory and results file path
+    # Extract the base result directory from args
+    base_result_dir = Path(args.result_dir)
+    summary_dir = base_result_dir / "summary"
+    results_file = summary_dir / "results.json"
+
+    # Ensure the summary directory exists
+    summary_dir.mkdir(parents=True, exist_ok=True)
+
+    # Thread-safe JSON append with file locking
+    try:
+        with open(results_file, 'a+') as f:
+            # Lock the file for exclusive access
+            fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+
+            try:
+                # Move to the beginning to read existing content
+                f.seek(0)
+                content = f.read().strip()
+
+                # Parse the existing JSON array or create a new one
+                if content:
+                    try:
+                        existing_results = json.loads(content)
+                        if not isinstance(existing_results, list):
+                            existing_results = []
+                    except json.JSONDecodeError:
+                        existing_results = []
+                else:
+                    existing_results = []
+
+                # Add the new result
+                existing_results.append(result_entry)
+
+                # Write back the complete JSON array
+                f.seek(0)
+                f.truncate()
+                json.dump(existing_results, f, indent=2)
+                f.write('\n')  # Add newline for readability
+
+            finally:
+                # Always unlock the file
+                fcntl.flock(f.fileno(), fcntl.LOCK_UN)
+
+        print(f"📝 Logged result: {domain}/{task_id} -> {result_entry['status']} (score: {score})")
+
+    except Exception as e:
+        # Don't let logging errors break the main evaluation
+        print(f"⚠️ Failed to log result for {task_id}: {e}")
+
+
+def log_task_completion(example: Dict, result: float, result_dir: str, args: Any) -> None:
+    """
+    Convenience wrapper for logging successful task completion.
+
+    Args:
+        example: Task configuration dictionary
+        result: Task score
+        result_dir: Path to the task result directory
+        args: Command line arguments
+    """
+    task_id = example.get('id', 'unknown')
+    domain = extract_domain_from_path(result_dir)
+    append_task_result(task_id, domain, result, result_dir, args)
+
+
+def log_task_error(example: Dict, error_msg: str, result_dir: str, args: Any) -> None:
+    """
+    Convenience wrapper for logging task errors.
+
+    Args:
+        example: Task configuration dictionary
+        error_msg: Error message
+        result_dir: Path to the task result directory
+        args: Command line arguments
+    """
+    task_id = example.get('id', 'unknown')
+    domain = extract_domain_from_path(result_dir)
+    append_task_result(task_id, domain, 0.0, result_dir, args, error_msg)
```
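A sketch of how this logger might be called directly; `SimpleNamespace` stands in for the runner's argparse namespace, and the IDs and paths are illustrative. Note that the `fcntl` locking makes the module Unix-only.

```python
from types import SimpleNamespace

from lib_results_logger import append_task_result  # module added in the diff above

# Hypothetical stand-in for the argparse namespace; only `result_dir` is read here.
args = SimpleNamespace(result_dir="./results/pyautogui/screenshot/gpt-4o")

append_task_result(
    task_id="94d95f96-9699-4208-98ba-3c3119edf9c2",  # illustrative task UUID
    domain="chrome",
    score=1.0,
    result_dir="./results/pyautogui/screenshot/gpt-4o/chrome/94d95f96-9699-4208-98ba-3c3119edf9c2",
    args=args,
)
# -> appends one entry to ./results/pyautogui/screenshot/gpt-4o/summary/results.json
```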
```diff
@@ -4,18 +4,22 @@ import logging
 import os
 import time
 from wrapt_timeout_decorator import *
+from lib_results_logger import log_task_completion

 logger = logging.getLogger("desktopenv.experiment")


 def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
     runtime_logger = setup_logger(example, example_result_dir)
-    try:
-        agent.reset(runtime_logger)
-    except Exception as e:
-        agent.reset()
-
+    # Reset environment first to get a fresh VM IP
     env.reset(task_config=example)
+
+    # Reset agent with the fresh VM IP (for snapshot reverts)
+    try:
+        agent.reset(runtime_logger, vm_ip=env.vm_ip)
+    except Exception as e:
+        agent.reset(vm_ip=env.vm_ip)
+
     time.sleep(60)  # Wait for the environment to be ready
     obs = env._get_obs()  # Get the initial observation
```
```diff
@@ -29,7 +33,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
     )
     for action in actions:
         # Capture the timestamp before executing the action
-        action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+        action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
         logger.info("Step %d: %s", step_idx + 1, action)
         obs, reward, done, info = env.step(action, args.sleep_after_execution)
```
```diff
@@ -55,11 +59,16 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
             logger.info("The episode is done.")
             break
         step_idx += 1
+    time.sleep(20)  # Wait for the environment to settle
     result = env.evaluate()
     logger.info("Result: %.2f", result)
     scores.append(result)
     with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
         f.write(f"{result}\n")
+
+    # Log task completion to results.json
+    log_task_completion(example, result, example_result_dir, args)
+
     env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
```
```diff
@@ -96,6 +105,67 @@ def run_single_example_human(env, example, max_steps, instruction, args, example



+def run_single_example_agi(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
+    runtime_logger = setup_logger(example, example_result_dir)
+    agent.reset(runtime_logger)
+    env.reset(task_config=example)
+    time.sleep(60)  # Wait for the environment to be ready
+    obs = env._get_obs()  # Get the initial observation
+    done = False
+    step_idx = 0
+    env.controller.start_recording()
+    while not done and step_idx < max_steps:
+        response, actions = agent.predict(
+            instruction,
+            obs
+        )
+
+        done = not response.get('state_correct', False)
+
+        for action in actions:
+            # Capture the timestamp before executing the action
+            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+            logger.info("Step %d: %s", step_idx + 1, action)
+            obs, reward, done, info, step_info = agent.step(action)
+
+            if not done:
+                if not response.get('state_correct', False):
+                    done = True
+
+            logger.info("Reward: %.2f", reward)
+            logger.info("Done: %s", done)
+            # Save screenshot and trajectory information
+            with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
+                      "wb") as _f:
+                _f.write(obs['screenshot'])
+
+            # Remove pending checks if they exist, since they cause issues with JSON serialization
+            if action.get('pending_checks', None):
+                del action['pending_checks']
+
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                f.write(json.dumps({
+                    "step_num": step_idx + 1,
+                    "action_timestamp": action_timestamp,
+                    "action": action,
+                    "reward": reward,
+                    "done": done,
+                    "info": info,
+                    "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
+                }))
+                f.write("\n")
+            if done:
+                logger.info("The episode is done.")
+                break
+        step_idx += 1
+    result = env.evaluate()
+    logger.info("Result: %.2f", result)
+    scores.append(result)
+    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
+        f.write(f"{result}\n")
+    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+
+
 def run_single_example_openaicua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
     runtime_logger = setup_logger(example, example_result_dir)
     agent.reset(runtime_logger)
```
```diff
@@ -186,23 +256,25 @@ def run_single_example_opencua(agent, env, example, max_steps, instruction, args
                       "wb") as _f:
                 _f.write(obs['screenshot'])

-            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
                 f.write(json.dumps({
                     "step_num": step_idx + 1,
-                    "action_timestamp": action_timestamp,
                     "action": action,
                     "natural_language_action": info_dict.get("action"),
+                    "action_timestamp": action_timestamp,
                     "response": response,
                     "reward": reward,
                     "done": done,
                     "info": info,
                     "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
-                }))
+                }, ensure_ascii=False))
                 f.write("\n")
             if done:
                 logger.info("The episode is done.")
                 break
         step_idx += 1

+    time.sleep(20)  # Wait for the environment to settle
     result = env.evaluate()
     logger.info("Result: %.2f", result)
     scores.append(result)
```
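The `ensure_ascii=False` change keeps non-ASCII agent output human-readable in `traj.jsonl`, for example:

```python
import json

record = {"natural_language_action": "Click the save button in the naïve café demo"}
print(json.dumps(record))                      # escapes: "na\u00efve caf\u00e9"
print(json.dumps(record, ensure_ascii=False))  # keeps:   "naïve café"
```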
```diff
@@ -389,3 +461,185 @@ def run_single_example_uipath(agent, env, example, max_steps, instruction, args,
     with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
         f.write(f"{result}\n")
     env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+
+
+from mm_agents.os_symphony.utils.common_utils import draw_coordinates
+from mm_agents.os_symphony.utils.process_context import set_current_result_dir
+
+
+logger = logging.getLogger("desktopenv.experiment")
+
+def run_single_example_os_symphony(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
+    set_current_result_dir(example_result_dir)
+
+    agent.reset(result_dir=example_result_dir)
+    env.reset(task_config=example)
+    time.sleep(30)  # Wait for the environment to be ready
+    obs = env._get_obs()  # Get the initial observation
+    done = False
+    step_idx = 0
+    # env.controller.start_recording()
+    start_time = time.time()
+
+    while not done and step_idx < max_steps:
+        response, actions = agent.predict(
+            instruction,
+            obs,
+            step_idx == max_steps - 1
+        )
+        for action in actions:
+            # Save screenshot and trajectory information
+            if "reflection" in response and response["reflection"].get("is_milestone"):
+                img_name = f"step_{step_idx + 1}_milestone.png"
+            else:
+                img_name = f"step_{step_idx + 1}.png"
+
+            with open(os.path.join(example_result_dir, img_name),
+                      "wb") as _f:
+                _f.write(obs['screenshot'])
+            if "coordinates" in response and response["coordinates"]:
+                draw_coordinates(
+                    image_bytes=obs['screenshot'],
+                    coordinates=response["coordinates"],
+                    save_path=os.path.join(example_result_dir, img_name[:-4] + "_draw.png")
+                )
+
+            logger.info("Step %d: %s", step_idx + 1, action)
+            obs, reward, done, info = env.step(action, args.sleep_after_execution)
+            logger.info("Done: %s", done)
+
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
+                f.write(json.dumps({
+                    "instruction": instruction,
+                    "step_num": step_idx + 1,
+                    "action": action,
+                    "response": response,
+                    "done": done,
+                    "info": info,
+                    "screenshot_file": img_name
+                }))
+                f.write("\n")
+            with open(os.path.join(example_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
+                json.dump({
+                    "step_num": step_idx + 1,
+                    "action": action,
+                    "response": response,
+                    "done": done,
+                    "info": info,
+                    "screenshot_file": img_name
+                }, f, indent=4, ensure_ascii=False)
+            if done:
+                logger.info("The episode is done.")
+                time.sleep(60)
+                break
+        step_idx += 1
+    end_time = time.time()
+    result = float(env.evaluate())
+    logger.info("Result: %.2f", result)
+    scores.append(result)
+    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
+        f.write(f"{result}\n")
+
+    with open(os.path.join(example_result_dir, "time.txt"), "w", encoding="utf-8") as f:
+        f.write(f"{end_time-start_time:.2f}\n")
+
+
+def run_single_example_evocua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
+    """
+    Unified run function for EvoCUAAgent (supporting both S1 and S2 modes).
+    """
+    runtime_logger = setup_logger(example, example_result_dir)
+
+    # Reset Environment
+    env.reset(task_config=example)
+
+    # Reset Agent
+    # Handle agent reset signature differences if any
+    try:
+        agent.reset(runtime_logger, vm_ip=env.vm_ip)
+    except Exception:
+        try:
+            agent.reset(runtime_logger)
+        except Exception:
+            agent.reset()
+
+    time.sleep(60)  # Wait for the environment to be ready
+    obs = env._get_obs()  # Get the initial observation
+    done = False
+    step_idx = 0
+
+    env.controller.start_recording()
+    while not done and step_idx < max_steps:
+        # EvoCUAAgent.predict unified signature: returns (response, actions)
+        # It handles both modes internally.
+        predict_res = agent.predict(instruction, obs)
+
+        # Check return signature logic
+        if len(predict_res) == 3:
+            # Compatibility with the original S1 signature if the agent was updated to match
+            response, actions, info_dict = predict_res
+        else:
+            response, actions = predict_res
+            info_dict = {}
+
+        logger.info(f"Step {step_idx + 1} Actions: {actions}")
+
+        # Break if no actions (fail-safe)
+        if not actions or (len(actions) == 1 and (actions[0] == "" or "error" in actions[0].lower())):
+            # Allow "FAIL" or "DONE" to pass through the execution loop if the agent outputs them as actions
+            if not (actions and actions[0] in ["FAIL", "DONE"]):
+                logger.warning("No valid actions returned. Breaking loop.")
+                break
+
+        for action in actions:
+            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
+            logger.info("Executing action: %s", action)
+
+            # Execute
+            obs, reward, done, info = env.step(action, args.sleep_after_execution)
+
+            logger.info("Reward: %.2f", reward)
+            logger.info("Done: %s", done)
+
+            # Save screenshot
+            screenshot_file = f"step_{step_idx + 1}_{action_timestamp}.png"
+            with open(os.path.join(example_result_dir, screenshot_file), "wb") as _f:
+                _f.write(obs['screenshot'])
+
+            # Log Trajectory
+            log_entry = {
+                "step_num": step_idx + 1,
+                "action_timestamp": action_timestamp,
+                "action": action,
+                "response": response,
+                "reward": reward,
+                "done": done,
+                "info": info,
+                "screenshot_file": screenshot_file
+            }
+            # Add natural language info if available (S1 style)
+            if info_dict:
+                log_entry["natural_language_action"] = info_dict.get("action")
+
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
+                f.write(json.dumps(log_entry, ensure_ascii=False))
+                f.write("\n")
+
+            if done:
+                logger.info("The episode is done.")
+                break
+
+        step_idx += 1
+
+    time.sleep(20)  # Wait for the environment to settle
+    result = env.evaluate()
+    logger.info("Result: %.2f", result)
+    scores.append(result)
+
+    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
+        f.write(f"{result}\n")
+
+    log_task_completion(example, result, example_result_dir, args)
+
+    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
```
```diff
@@ -0,0 +1,84 @@
+import json
+import logging
+import os
+import time
+from wrapt_timeout_decorator import *
+from mm_agents.os_symphony.utils.common_utils import draw_coordinates
+from mm_agents.os_symphony.utils.process_context import set_current_result_dir
+
+
+logger = logging.getLogger("desktopenv.experiment")
+
+def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
+    set_current_result_dir(example_result_dir)
+
+    agent.reset(result_dir=example_result_dir)
+    env.reset(task_config=example)
+    time.sleep(30)  # Wait for the environment to be ready
+    obs = env._get_obs()  # Get the initial observation
+    done = False
+    step_idx = 0
+    # env.controller.start_recording()
+    start_time = time.time()
+
+    while not done and step_idx < max_steps:
+        response, actions = agent.predict(
+            instruction,
+            obs,
+            step_idx == max_steps - 1
+        )
+        for action in actions:
+            # Save screenshot and trajectory information
+            if "reflection" in response and response["reflection"].get("is_milestone"):
+                img_name = f"step_{step_idx + 1}_milestone.png"
+            else:
+                img_name = f"step_{step_idx + 1}.png"
+
+            with open(os.path.join(example_result_dir, img_name),
+                      "wb") as _f:
+                _f.write(obs['screenshot'])
+            if "coordinates" in response and response["coordinates"]:
+                draw_coordinates(
+                    image_bytes=obs['screenshot'],
+                    coordinates=response["coordinates"],
+                    save_path=os.path.join(example_result_dir, img_name[:-4] + "_draw.png")
+                )
+
+            logger.info("Step %d: %s", step_idx + 1, action)
+            obs, reward, done, info = env.step(action, args.sleep_after_execution)
+            logger.info("Done: %s", done)
+
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
+                f.write(json.dumps({
+                    "instruction": instruction,
+                    "step_num": step_idx + 1,
+                    "action": action,
+                    "response": response,
+                    "done": done,
+                    "info": info,
+                    "screenshot_file": img_name
+                }))
+                f.write("\n")
+            with open(os.path.join(example_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
+                json.dump({
+                    "step_num": step_idx + 1,
+                    "action": action,
+                    "response": response,
+                    "done": done,
+                    "info": info,
+                    "screenshot_file": img_name
+                }, f, indent=4, ensure_ascii=False)
+            if done:
+                logger.info("The episode is done.")
+                time.sleep(60)
+                break
+        step_idx += 1
+    end_time = time.time()
+    result = float(env.evaluate())
+    logger.info("Result: %.2f", result)
+    scores.append(result)
+    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
+        f.write(f"{result}\n")
+
+    with open(os.path.join(example_result_dir, "time.txt"), "w", encoding="utf-8") as f:
+        f.write(f"{end_time-start_time:.2f}\n")
```
```diff
@@ -1134,10 +1134,12 @@ class PromptAgent:

         return actions

-    def reset(self, _logger=None):
+    def reset(self, _logger=None, vm_ip=None, **kwargs):
         global logger
         logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")

+        self.vm_ip = vm_ip
+
         self.thoughts = []
         self.actions = []
         self.observations = []
```
@ -0,0 +1,219 @@
|
|||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Tuple, Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Context manager for timing code blocks."""
|
||||
|
||||
def __enter__(self):
|
||||
self.start = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.duration = time.time() - self.start


class AGIAgent:
    """Agent that communicates with your private AGI server for decision-making."""

    def __init__(
        self,
        env,
        server_url: str = "https://your-private-agi-endpoint",  # Contact the authors for access to a private deployment endpoint.
        platform: str = "ubuntu",
        action_space: str = "pyautogui",
        observation_type: str = "screenshot",
        max_trajectory_length: int = 100,
        client_password: str = "",
        provider_name: str = "aws",
        screen_width: int = 1920,
        screen_height: int = 1080,
        timeout: int = 1800,
    ):
        """Initialize the AGI client.

        Args:
            env: The desktop environment
            server_url: URL of your private AGI server
        """
        self.env = env
        self.server_url = server_url.rstrip("/")
        self.platform = platform
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length
        self.client_password = client_password
        self.provider_name = provider_name
        self.screen_width = screen_width
        self.screen_height = screen_height

        # Session management
        self.session_id: Optional[str] = None
        self.instruction: Optional[str] = None

        # HTTP client
        self.client = httpx.Client(timeout=timeout)

        # Tracking
        self.thoughts = []
        self.actions = []
        self.observations = []

        logger.info(f"Initialized AGIAgent with server URL: {self.server_url}")

    def reset(self, runtime_logger=None):
        """Reset the agent and create a new session on the server.

        Args:
            runtime_logger: Optional logger for runtime information
        """
        global logger
        logger = runtime_logger if runtime_logger is not None else logging.getLogger("desktopenv.agent")

        # Clear local state
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.session_id = None

        logger.info("AGIAgent reset complete")

    def _create_session(self, instruction: str) -> str:
        """Create a new session on the server.

        Args:
            instruction: The task instruction

        Returns:
            The session ID

        Equivalent curl request:
            curl -X POST {server_url}/sessions \
                -H "Content-Type: application/json" \
                -d '{"task_description": "{instruction}"}'
        """
        try:
            # print(f"Creating session with instruction: {instruction}")
            # print(f"Server URL: {self.server_url}")
            response = self.client.post(
                f"{self.server_url}/sessions",
                json={"task_description": instruction}
            )
            response.raise_for_status()
            session_id = response.json()["session_id"]
            logger.info(f"Created session: {session_id}")
            return session_id
        except Exception as e:
            logger.error(f"Failed to create session: {e}")
            raise

    def predict(self, instruction: str, obs: Dict) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
        """Predict the next action based on the current observation.

        Args:
            instruction: The task instruction
            obs: Observation dictionary containing 'screenshot' key with image bytes

        Returns:
            Tuple of (predict_info dict, list of action dicts)
        """
        # Create session on first prediction
        if self.session_id is None:
            self.instruction = instruction
            self.session_id = self._create_session(instruction)

        # input("Session created, press Enter to continue")

        # Encode screenshot to base64
        screenshot_bytes = obs["screenshot"]
        screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")

        # Call the server
        with Timer() as model_timer:
            try:
                response = self.client.post(
                    f"{self.server_url}/sessions/{self.session_id}/step",
                    json={
                        "screenshot_base64_png": screenshot_b64,
                        "error": None  # Could be populated from previous step errors
                    }
                )
                response.raise_for_status()
                result = response.json()
                parsed_action = result["parsed_response"]

                logger.info(f"Server returned action: {parsed_action[:100]}...")

            except Exception as e:
                logger.error(f"Error calling server: {e}")
                raise

        # Format response as expected by lib_run_single
        actions = [{
            "action_space": "pyautogui",
            "action": parsed_action,
            "pending_checks": [],
            "call_id": ""
        }]

        # Check if task is complete or failed
        state_correct = parsed_action not in ["FAIL", "DONE"]

        predict_info = {
            "model_usage": {
                "model_time": model_timer.duration,
                "prompt_tokens": 0,  # Server doesn't expose these
                "completion_tokens": 0,
            },
            "messages": [],  # Server manages conversation history
            "response": parsed_action,
            "state_correct": state_correct,
        }

        return predict_info, actions

    def step(self, action: Dict[str, Any]) -> Tuple[Dict, float, bool, Dict, Dict]:
        """Execute an action in the environment.

        Args:
            action: Action dictionary with 'action' key containing PyAutoGUI command

        Returns:
            Tuple of (observation, reward, done, info, step_info)
        """
        try:
            if not action:
                logger.warning("Empty action received, terminating episode")
                # Get observation without executing action
                obs = self.env._get_obs()
                return obs, 0.0, True, {}, {"step_time": 0.0, "action": action}

            action_str = action.get("action", "")
            logger.info(f"Executing action: {action_str[:100]}...")

            with Timer() as step_timer:
                # Execute the action directly (it's already a PyAutoGUI command string)
                obs, reward, terminated, info = self.env.step(action_str)

            logger.debug(f"Action completed in {step_timer.duration:.2f}s")
            if terminated:
                logger.info("Environment signaled termination")

            return obs, reward, terminated, info, {
                "step_time": step_timer.duration,
                "action": action
            }

        except Exception as e:
            logger.exception(f"Environment step failed: {str(e)}")
            raise

    def close(self):
        """Close the HTTP client."""
        self.client.close()
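# Hedged usage sketch (flow inferred from the methods above, not part of the
# original file):
#
#     agent = AGIAgent(env, server_url="https://your-private-agi-endpoint")
#     agent.reset()
#     predict_info, actions = agent.predict(instruction, env._get_obs())
#     obs, reward, done, info, step_info = agent.step(actions[0])
#     agent.close()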

@ -17,7 +17,7 @@ from anthropic.types.beta import (
    BetaMessageParam,
    BetaTextBlockParam,
)
from .utils import COMPUTER_USE_BETA_FLAG, PROMPT_CACHING_BETA_FLAG, SYSTEM_PROMPT, SYSTEM_PROMPT_WINDOWS, APIProvider, PROVIDER_TO_DEFAULT_MODEL_NAME
from .utils import COMPUTER_USE_BETA_FLAG, PROMPT_CACHING_BETA_FLAG, SYSTEM_PROMPT, SYSTEM_PROMPT_WINDOWS, APIProvider, PROVIDER_TO_DEFAULT_MODEL_NAME, get_model_name
from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to_n_most_recent_images

import logging

@ -30,14 +30,18 @@ API_RETRY_INTERVAL = 5
class AnthropicAgent:
    def __init__(self,
                 platform: str = "Ubuntu",
                 model: str = "claude-3-5-sonnet-20241022",
                 provider: APIProvider = APIProvider.BEDROCK,
                 model: str = "claude-sonnet-4-5-20250929",
                 provider: APIProvider = APIProvider.ANTHROPIC,
                 max_tokens: int = 4096,
                 api_key: str = os.environ.get("ANTHROPIC_API_KEY", None),
                 system_prompt_suffix: str = "",
                 only_n_most_recent_images: Optional[int] = 10,
                 action_space: str = "claude_computer_use",
                 screen_size: tuple[int, int] = (1920, 1080),
                 no_thinking: bool = False,
                 use_isp: bool = False,
                 temperature: Optional[float] = None,
                 top_p: Optional[float] = None,
                 *args, **kwargs
                 ):
        self.platform = platform

@ -52,10 +56,24 @@ class AnthropicAgent:
        self.only_n_most_recent_images = only_n_most_recent_images
        self.messages: list[BetaMessageParam] = []
        self.screen_size = screen_size
        self.no_thinking = no_thinking
        self.use_isp = use_isp
        self.temperature = temperature
        self.top_p = top_p

        self.resize_factor = (
            screen_size[0] / 1280,  # Assuming 1280 is the base width
            screen_size[1] / 720    # Assuming 720 is the base height
        )

    def _get_sampling_params(self):
        """Get sampling parameters (temperature and/or top_p); exclusivity is left for the API to validate."""
        params = {}
        if self.temperature is not None:
            params['temperature'] = self.temperature
        if self.top_p is not None:
            params['top_p'] = self.top_p
        return params
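# Hedged illustration (not in the original file): with temperature=0.6 and
# top_p unset, _get_sampling_params() returns {'temperature': 0.6}; if both
# are provided, both are forwarded unchanged and the Anthropic API decides
# whether the combination is allowed.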

    def add_tool_result(self, tool_call_id: str, result: str, screenshot: bytes = None):
        """Add tool result to message history"""

@ -84,6 +102,21 @@ class AnthropicAgent:
            "content": tool_result_content
        })

    def _extract_raw_response_string(self, response) -> str:
        """Extract and concatenate raw response content into a single string."""
        raw_response_str = ""
        if response.content:
            for block in response.content:
                if hasattr(block, 'text') and block.text:
                    raw_response_str += f"[TEXT] {block.text}\n"
                elif hasattr(block, 'thinking') and block.thinking:
                    raw_response_str += f"[THINKING] {block.thinking}\n"
                elif hasattr(block, 'name') and hasattr(block, 'input'):
                    raw_response_str += f"[TOOL_USE] {block.name}: {block.input}\n"
                else:
                    raw_response_str += f"[OTHER] {str(block)}\n"
        return raw_response_str.strip()
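# Hedged illustration (not in the original file): for a response holding a
# thinking block followed by a computer tool call, the concatenated string
# would look roughly like
#     [THINKING] The button is at the top left...
#     [TOOL_USE] computer: {'action': 'left_click', 'coordinate': [12, 34]}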

    def parse_actions_from_tool_call(self, tool_call: Dict) -> str:
        result = ""
        function_args = (

@ -194,13 +227,23 @@ class AnthropicAgent:
            result += (f"pyautogui.keyUp('{key}')\n")
            expected_outcome = f"Key {key} pressed."
        elif action == "type":
            result += (
                f"pyautogui.typewrite(\"\"\"{text}\"\"\", interval=0.01)\n"
            )
            for char in text:
                if char == '\n':
                    result += "pyautogui.press('enter')\n"
                elif char == "'":
                    result += 'pyautogui.press("\'")\n'
                elif char == '\\':
                    result += "pyautogui.press('\\\\')\n"
                elif char == '"':
                    result += "pyautogui.press('\"')\n"
                else:
                    result += f"pyautogui.press('{char}')\n"
            expected_outcome = f"Text {text} written."

        # Handle scroll actions
        elif action == "scroll":
            if text is not None:
                result += (f"pyautogui.keyDown('{text.lower()}')\n")
            if coordinate is None:
                if scroll_direction in ("up", "down"):
                    result += (

@ -221,6 +264,8 @@ class AnthropicAgent:
                result += (
                    f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount}, {x}, {y})\n"
                )
            if text is not None:
                result += (f"pyautogui.keyUp('{text.lower()}')\n")
            expected_outcome = "Scroll action finished"

        # Handle click actions
|
|||
expected_outcome = "Call user"
|
||||
elif action == "screenshot":
|
||||
result += "pyautogui.sleep(0.1)\n"
|
||||
expected_outcome = "Screenshot taken"
|
||||
expected_outcome = "Screenshot taken"
|
||||
else:
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
|
||||
|
|

@ -303,6 +348,9 @@ class AnthropicAgent:
        screenshot_bytes = obs["screenshot"]
        screenshot_image = Image.open(io.BytesIO(screenshot_bytes))

        # Store original unresized screenshot for zoom processing
        obs["screenshot_original"] = screenshot_bytes

        # Calculate new size based on resize factor
        new_width, new_height = 1280, 720

@ -334,23 +382,45 @@ class AnthropicAgent:
            ]
        })

        if self.messages and "tool_use" in [content_block["type"] for content_block in self.messages[-1]["content"]]:
            self.add_tool_result(
                self.messages[-1]["content"][-1]["id"],
                "Success",
                screenshot=obs.get("screenshot") if obs else None
            )
        # Add tool_result for ALL tool_use blocks in the last message
        if self.messages:
            last_message_content = self.messages[-1]["content"]
            tool_use_blocks = [block for block in last_message_content if block.get("type") == "tool_use"]

            for i, tool_block in enumerate(tool_use_blocks):
                tool_input = tool_block.get("input", {})
                action = tool_input.get("action")
                is_last_tool = i == len(tool_use_blocks) - 1

                include_screenshot = None

                if obs:
                    if action == "screenshot":
                        # Screenshot action always gets regular screenshot
                        include_screenshot = obs.get("screenshot")
                    elif is_last_tool:
                        # Auto-screenshot: last tool gets regular screenshot (unless it's zoom, handled above)
                        include_screenshot = obs.get("screenshot")

                self.add_tool_result(
                    tool_block["id"],
                    "Success",
                    screenshot=include_screenshot
                )

        enable_prompt_caching = False
        betas = ["computer-use-2025-01-24"]
        if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
            betas = ["computer-use-2025-01-24"]
        elif self.model_name == "claude-3-5-sonnet-20241022":
            betas = [COMPUTER_USE_BETA_FLAG]
        betas = [COMPUTER_USE_BETA_FLAG]

        # Add interleaved thinking beta if ISP is requested
        if self.use_isp:
            betas.append("interleaved-thinking-2025-05-14")
            logger.info(f"Added interleaved thinking beta. Betas: {betas}")

        image_truncation_threshold = 10
        if self.provider == APIProvider.ANTHROPIC:
            client = Anthropic(api_key=self.api_key, max_retries=4)
            client = Anthropic(api_key=self.api_key, max_retries=4).with_options(
                default_headers={"anthropic-beta": COMPUTER_USE_BETA_FLAG}
            )
            enable_prompt_caching = True
        elif self.provider == APIProvider.VERTEX:
            client = AnthropicVertex()

@ -368,7 +438,7 @@ class AnthropicAgent:
        if enable_prompt_caching:
            betas.append(PROMPT_CACHING_BETA_FLAG)
            _inject_prompt_caching(self.messages)
            image_truncation_threshold = 50
            image_truncation_threshold = 20
            system["cache_control"] = {"type": "ephemeral"}

        if self.only_n_most_recent_images:

@ -378,49 +448,65 @@ class AnthropicAgent:
                min_removal_threshold=image_truncation_threshold,
            )

        try:
            if self.model_name == "claude-3-5-sonnet-20241022":
                tools = [
                    {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                    # {'type': 'bash_20241022', 'name': 'bash'},
                    # {'name': 'str_replace_editor', 'type': 'text_editor_20241022'}
                ] if self.platform == 'Ubuntu' else [
                    {'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                ]
            elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                tools = [
                    {'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                    # {'type': 'bash_20250124', 'name': 'bash'},
                    # {'name': 'str_replace_editor', 'type': 'text_editor_20250124'}
                ] if self.platform == 'Ubuntu' else [
                    {'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
                ]
        # Configure tool settings - use modern computer tool for all models
        tool_config = {
            'name': 'computer',
            'type': 'computer_20250124',
            'display_width_px': 1280,
            'display_height_px': 720,
            'display_number': 1
        }

        tools = [
            tool_config,
        ] if self.platform == 'Ubuntu' else [
            tool_config,
        ]

        # Configure thinking mode based on user preferences
        if self.no_thinking:
            # Disable thinking mode - omit the thinking parameter
            extra_body = {}
            actual_max_tokens = self.max_tokens  # Use default when no thinking
            logger.info("Thinking mode: DISABLED")
        else:
            # Enable thinking mode (regular or interleaved)
            # Use consistent 2048 budget for both regular and ISP thinking
            budget_tokens = 2048

            # For regular thinking: max_tokens > budget_tokens (API requirement)
            # For ISP: budget_tokens can exceed max_tokens (represents total across all thinking blocks)
            if self.max_tokens <= budget_tokens:
                required_max_tokens = budget_tokens + 500  # Give some headroom
                logger.warning(f"Regular thinking requires max_tokens > budget_tokens. Increasing max_tokens from {self.max_tokens} to {required_max_tokens}")
                actual_max_tokens = required_max_tokens
            else:
                actual_max_tokens = self.max_tokens

            extra_body = {
                "thinking": {"type": "enabled", "budget_tokens": 1024}
                "thinking": {"type": "enabled", "budget_tokens": budget_tokens}
            }
            if self.use_isp:
                logger.info("Thinking mode: INTERLEAVED SCRATCHPAD (ISP)")
            else:
                logger.info("Thinking mode: REGULAR SCRATCHPAD")

        try:
            response = None

            for attempt in range(API_RETRY_TIMES):
                try:
                    if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                        response = client.beta.messages.create(
                            max_tokens=self.max_tokens,
                            messages=self.messages,
                            model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                            system=[system],
                            tools=tools,
                            betas=betas,
                            extra_body=extra_body
                        )
                    elif self.model_name == "claude-3-5-sonnet-20241022":
                        response = client.beta.messages.create(
                            max_tokens=self.max_tokens,
                            messages=self.messages,
                            model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                            system=[system],
                            tools=tools,
                            betas=betas,
                        )
                    response = client.beta.messages.create(
                        max_tokens=actual_max_tokens,
                        messages=self.messages,
                        model=get_model_name(self.provider, self.model_name),
                        system=[system],
                        tools=tools,
                        betas=betas,
                        extra_body=extra_body,
                        **self._get_sampling_params()
                    )

                    logger.info(f"Response: {response}")
                    break
                except (APIError, APIStatusError, APIResponseValidationError) as e:

@ -450,26 +536,20 @@ class AnthropicAgent:
                    try:
                        logger.warning("Retrying with backup API key...")

                        backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
                        if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                            response = backup_client.beta.messages.create(
                                max_tokens=self.max_tokens,
                                messages=self.messages,
                                model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
                                system=[system],
                                tools=tools,
                                betas=betas,
                                extra_body=extra_body
                            )
                        elif self.model_name == "claude-3-5-sonnet-20241022":
                            response = backup_client.beta.messages.create(
                                max_tokens=self.max_tokens,
                                messages=self.messages,
                                model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
                                system=[system],
                                tools=tools,
                                betas=betas,
                            )
                        backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4).with_options(
                            default_headers={"anthropic-beta": COMPUTER_USE_BETA_FLAG}
                        )
                        response = backup_client.beta.messages.create(
                            max_tokens=actual_max_tokens,
                            messages=self.messages,
                            model=get_model_name(self.provider, self.model_name),
                            system=[system],
                            tools=tools,
                            betas=betas,
                            extra_body=extra_body,
                            **self._get_sampling_params()
                        )

                        logger.info("Successfully used backup API key")
                    except Exception as backup_e:
                        backup_error_msg = str(backup_e)

@ -497,9 +577,16 @@ class AnthropicAgent:
            logger.exception(f"Error in Anthropic API: {str(e)}")
            return None, None

        if response is None:
            logger.error("Response is None after API call - this should not happen")
            return None, None

        response_params = _response_to_params(response)
        logger.info(f"Received response params: {response_params}")

        # Convert raw response to concatenated string for trajectory logging
        raw_response_str = self._extract_raw_response_string(response)

        # Store response in message history
        self.messages.append({
            "role": "assistant",

@ -518,7 +605,8 @@ class AnthropicAgent:
                    "input": cast(dict[str, Any], content_block["input"]),
                    "id": content_block["id"],
                    "action_type": content_block.get("type"),
                    "command": self.parse_actions_from_tool_call(content_block)
                    "command": self.parse_actions_from_tool_call(content_block),
                    "raw_response": raw_response_str  # Add raw response to each action
                })
            elif content_block["type"] == "text":
                reasonings.append(content_block["text"])

@ -526,10 +614,23 @@ class AnthropicAgent:
                reasonings = reasonings[0]
            else:
                reasonings = ""

            # Check if the model indicated the task is infeasible
            if raw_response_str and "[INFEASIBLE]" in raw_response_str:
                logger.info("Detected [INFEASIBLE] pattern in response, triggering FAIL action")
                # Override actions with FAIL
                actions = [{
                    "action_type": "FAIL",
                    "raw_response": raw_response_str
                }]

            logger.info(f"Received actions: {actions}")
            logger.info(f"Received reasonings: {reasonings}")
            if len(actions) == 0:
                actions = ["DONE"]
                actions = [{
                    "action_type": "DONE",
                    "raw_response": raw_response_str
                }]
            return reasonings, actions
        except Exception as e:
            logger.warning(f"parse_actions_from_tool_call parsing failed (attempt {parse_retry+1}/3), will retry API request: {e}")

@ -539,25 +640,17 @@ class AnthropicAgent:
            response = None
            for attempt in range(API_RETRY_TIMES):
                try:
                    if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
                        response = client.beta.messages.create(
                            max_tokens=self.max_tokens,
                            messages=self.messages,
                            model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                            system=[system],
                            tools=tools,
                            betas=betas,
                            extra_body=extra_body
                        )
                    elif self.model_name == "claude-3-5-sonnet-20241022":
                        response = client.beta.messages.create(
                            max_tokens=self.max_tokens,
                            messages=self.messages,
                            model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
                            system=[system],
                            tools=tools,
                            betas=betas,
                        )
                    response = client.beta.messages.create(
                        max_tokens=actual_max_tokens,
                        messages=self.messages,
                        model=get_model_name(self.provider, self.model_name),
                        system=[system],
                        tools=tools,
                        betas=betas,
                        extra_body=extra_body,
                        **self._get_sampling_params()
                    )

                    logger.info(f"Response: {response}")
                    break  # Success, exit retry loop
                except (APIError, APIStatusError, APIResponseValidationError) as e2:

@ -569,13 +662,20 @@ class AnthropicAgent:
                    raise
            response_params = _response_to_params(response)
            logger.info(f"Received response params: {response_params}")

            # Update raw response string for retry case (will be used in next loop iteration)
            raw_response_str = self._extract_raw_response_string(response)

            self.messages.append({
                "role": "assistant",
                "content": response_params
            })
            if parse_retry == max_parse_retry - 1:
                logger.error(f"parse_actions_from_tool_call parsing failed 3 times consecutively, terminating: {e}")
                actions = ["FAIL"]
                actions = [{
                    "action_type": "FAIL",
                    "raw_response": f"Failed to parse actions from tool call after {max_parse_retry} attempts: {e}"
                }]
                return reasonings, actions

    def reset(self, _logger=None, *args, **kwargs):
        """

@ -27,7 +27,7 @@ from datetime import datetime
from .tools import ToolResult


COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"

@ -47,12 +47,25 @@ PROVIDER_TO_DEFAULT_MODEL_NAME: dict[(APIProvider, str), str] = {
    (APIProvider.ANTHROPIC, "claude-4-opus-20250514"): "claude-4-opus-20250514",
    (APIProvider.BEDROCK, "claude-4-opus-20250514"): "us.anthropic.claude-opus-4-20250514-v1:0",
    (APIProvider.VERTEX, "claude-4-opus-20250514"): "claude-4-opus-v1@20250514",
    # Add mapping for the alternative model name format
    (APIProvider.ANTHROPIC, "claude-opus-4-20250514"): "claude-opus-4-20250514",
    (APIProvider.ANTHROPIC, "claude-opus-4-1-20250805"): "claude-opus-4-1-20250805",
    (APIProvider.ANTHROPIC, "claude-4-sonnet-20250514"): "claude-4-sonnet-20250514",
    (APIProvider.ANTHROPIC, "claude-sonnet-4-20250514"): "claude-sonnet-4-20250514",
    (APIProvider.BEDROCK, "claude-4-sonnet-20250514"): "us.anthropic.claude-sonnet-4-20250514-v1:0",
    (APIProvider.VERTEX, "claude-4-sonnet-20250514"): "claude-sonnet-4-v1@20250514",
}


def get_model_name(provider: APIProvider, model_name: str) -> str:
    """
    Get the actual model name to use for API calls.

    Simply returns the model name as-is for direct API usage.
    """
    return model_name


# This system prompt is optimized for the Docker environment in this repository and
# specific tool combinations enabled.
# We encourage modifying this system prompt to ensure the model has context for the
@ -67,8 +80,15 @@ SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
|
|||
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* TASK FEASIBILITY: You can declare a task infeasible at any point during execution - whether at the beginning after taking a screenshot, or later after attempting some actions and discovering barriers. Carefully evaluate whether the task is feasible given the current system state, available applications, and task requirements. If you determine that a task cannot be completed due to:
|
||||
- Missing required applications or dependencies that cannot be installed
|
||||
- Insufficient permissions or system limitations
|
||||
- Contradictory or impossible requirements
|
||||
- Any other fundamental barriers that make completion impossible
|
||||
Then you MUST output exactly "[INFEASIBLE]" (including the square brackets) anywhere in your response to trigger the fail action. The system will automatically detect this pattern and terminate the task appropriately.
|
||||
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
|
||||
* Home directory of this Ubuntu system is '/home/user'.
|
||||
* If you need a password for sudo, the password of the computer is 'osworld-public-evaluation'.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<IMPORTANT>
|
||||
|
|
@ -82,6 +102,7 @@ SYSTEM_PROMPT_WINDOWS = f"""<SYSTEM_CAPABILITY>
|
|||
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
|
||||
* Home directory of this Windows system is 'C:\\Users\\user'.
|
||||
* When you want to open some applications on Windows, please use Double Click on it instead of clicking once.
|
||||
* If you need a password for sudo, The password of the computer is 'osworld-public-evaluation'.
|
||||
</SYSTEM_CAPABILITY>"""
|
||||
|
||||
|
||||
|
|

@ -154,21 +175,30 @@ def _inject_prompt_caching(
    one cache breakpoint is left for tools/system prompt, to be shared across sessions
    """

    breakpoints_remaining = 3
    breakpoints_remaining = 2  # Use full budget for recent messages
    messages_processed = 0

    for message in reversed(messages):
        if message["role"] == "user" and isinstance(
            content := message["content"], list
        ):
            if breakpoints_remaining:
                breakpoints_remaining -= 1
            messages_processed += 1
            # Check if this message would fit within the remaining budget
            if breakpoints_remaining >= len(content):
                # We have enough budget, spend it and add cache_control
                breakpoints_remaining -= len(content)
                # Use type ignore to bypass TypedDict check until SDK types are updated
                content[-1]["cache_control"] = BetaCacheControlEphemeralParam(  # type: ignore
                    {"type": "ephemeral"}
                )
            else:
                content[-1].pop("cache_control", None)
                # we'll only ever have one extra turn per loop
                break
                # Check if this is the first message (contains image + text with task description)
                is_first_message = messages_processed == len([msg for msg in messages if msg["role"] == "user"])

                if not is_first_message:
                    # Not enough budget, remove any existing cache_control from this message
                    content[-1].pop("cache_control", None)
                # Continue to clean up older messages that might have cache_control from previous turns
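                # Hedged worked example (not part of the original diff): with
                # breakpoints_remaining = 2, a final user turn whose content
                # list holds one block gets cache_control attached and the
                # budget drops to 1; an earlier three-block turn exceeds the
                # remaining budget, so any stale cache_control left over from
                # a previous call is stripped instead of refreshed.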


def _maybe_filter_to_n_most_recent_images(

@ -220,6 +250,105 @@ def _maybe_filter_to_n_most_recent_images(
            tool_result["content"] = new_content


def validate_model_support(model_name: str, api_key: str = None, temperature: float = None, top_p: float = None, no_thinking: bool = False, use_isp: bool = False) -> bool:
    """
    Validate model support with the same API call pattern as the main agent.

    Args:
        model_name: The model name to validate
        api_key: Optional API key, defaults to ANTHROPIC_API_KEY env var
        temperature: Optional temperature parameter for testing
        top_p: Optional top_p parameter for testing
        no_thinking: Disable thinking mode (matches AnthropicAgent)
        use_isp: Use interleaved scratchpad mode (matches AnthropicAgent)

    Returns:
        True if model is supported and API call succeeds, False otherwise
    """
    print(f"🔍 Validating model support: {model_name}")

    try:
        from anthropic import Anthropic
        import os
        import time

        # Same client setup as the main agent (max_retries=4), plus the manual retry loop below
        client = Anthropic(
            api_key=api_key or os.environ.get("ANTHROPIC_API_KEY"),
            max_retries=4
        ).with_options(default_headers={"anthropic-beta": COMPUTER_USE_BETA_FLAG})

        # Same message format as main agent - always use structured format with cache_control
        messages = [{"role": "user", "content": [{"type": "text", "text": "Respond with 'OK'", "cache_control": {"type": "ephemeral"}}]}]

        # Same betas configuration as main agent
        betas = [COMPUTER_USE_BETA_FLAG]
        if use_isp:
            betas.append("interleaved-thinking-2025-05-14")

        system = [{"type": "text", "text": "You are Claude. Respond with 'OK'."}]

        # Same tools configuration as main agent - use modern computer tool for all models
        tools = [{"name": "computer", "type": "computer_20250124",
                  "display_width_px": 1280, "display_height_px": 720, "display_number": 1}]

        # Same thinking configuration as main agent
        max_tokens = 50  # Base validation max_tokens
        if no_thinking:
            extra_body = {}
            actual_max_tokens = max_tokens
        else:
            budget_tokens = 2048
            # Same logic as main agent: if max_tokens <= budget_tokens, increase it
            if max_tokens <= budget_tokens:
                actual_max_tokens = budget_tokens + 500
            else:
                actual_max_tokens = max_tokens
            extra_body = {
                "thinking": {"type": "enabled", "budget_tokens": budget_tokens}
            }

        # Sampling parameters (same logic as main agent)
        sampling_params = {}
        if temperature is not None:
            sampling_params['temperature'] = temperature
        if top_p is not None:
            sampling_params['top_p'] = top_p

        # Retry logic with 5 attempts, 5 second delays
        for attempt in range(5):
            try:
                # Same API call pattern as main agent
                client.beta.messages.create(
                    max_tokens=actual_max_tokens,
                    messages=messages,
                    model=get_model_name(APIProvider.ANTHROPIC, model_name),
                    system=system,
                    tools=tools,
                    betas=betas,
                    extra_body=extra_body,
                    **sampling_params
                )

                print(f"✅ Model {model_name} validated successfully")
                return True
            except Exception as e:
                if attempt < 4:  # Don't print error on final attempt
                    print(f"🔄 Validation attempt {attempt + 1}/5 failed: {e}")
                    print("⏳ Retrying in 5 seconds...")
                    time.sleep(5)
                else:
                    print(f"❌ All validation attempts failed. Final error: {e}")

        return False

    except ValueError:
        return False
    except Exception as e:
        print(f"❌ API validation setup failed: {e}")
        return False


def _response_to_params(
    response: BetaMessage,
) -> list[BetaContentBlockParam]:

@ -0,0 +1,161 @@
COMPUTER_USE_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.

## Output Format
```
Thought: ...
Action: ...
```

## Action Space

click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.

## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
- My computer's password is 'password', feel free to use it when you need sudo rights.

## User Instruction
{instruction}
"""

COMPUTER_USE_PROMPT_WITH_CALL_USER = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.

## Output Format
```
Thought: ...
Action: ...
```

## Action Space

click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.

## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
- My computer's password is 'password', feel free to use it when you need sudo rights.

## User Instruction
{instruction}
"""

UITARS_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
"""

UITARS_CALL_USR_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
"""

UITARS_NORMAL_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
"""

UITARS_USR_PROMPT_NOTHOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## User Instruction
{instruction}
"""

UITARS_USR_PROMPT_THOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.

## Output Format
```
Thought: ...
Action: ...
```

## Action Space
{action_space}

## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.

## User Instruction
{instruction}
"""


FAILURE_INDICATORS = [
    # Direct inability expressions
    "无法", "不能", "不可以", "做不到", "实现不了", "完成不了", "没法",

    # Regret/apology expressions
    "遗憾", "抱歉", "很抱歉", "非常抱歉", "对不起",

    # Not supported/available
    "不直接支持", "不支持", "不提供", "不具备", "没有权限", "权限不足", "不在这里面", "不符合",  # "不存在",

    # Cannot access/handle
    "无权访问", "访问不了", "处理不了", "操作不了", "执行不了", "没找到", "空空如也",

    # Not possible/feasible
    "不可能", "无法实现", "实现不了", "办不到", "做不了", "找不到", "存在技术限制", "没有找到", "没有内置",

    # System limitations
    "超出范围", "不在我的能力范围", "能力有限", "功能限制", "没有成功", "没成功", "硬件的问题",

    # Refusal indicators
    "拒绝", "不允许", "禁止", "不合适", "不恰当",

    # Trying Restart
    "从头开始", "藏在", "浪费时间", "一个更合理的思路", "正确的方向", "没有意义",  # "重新", "重启",
]
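# Hedged usage sketch (FAILURE_INDICATORS is imported and matched elsewhere in
# the repo; the exact call site is not shown in this diff):
#
#     gave_up = any(indicator in model_response for indicator in FAILURE_INDICATORS)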

@ -0,0 +1,202 @@
import asyncio
from typing import List, Optional, Union, Dict, Any
import json
import os
import hashlib
from pathlib import Path
from omegaconf import DictConfig
from dataclasses import dataclass, asdict
import copy
import logging
import random

from prompts import COMPUTER_USE_PROMPT, COMPUTER_USE_PROMPT_WITH_CALL_USER
from log_config import setup_logging

# Set up the unified logging system
setup_logging()
logger = logging.getLogger(__name__)

class TaskLoader:
    def __init__(self, task_cfg: DictConfig, storage_root):
        self.task_file = Path(task_cfg.task_file)
        # self.task_root = Path(task_cfg.task_root)
        self.osworld_root = Path(task_cfg.osworld_root)

        self._latest_sha: Optional[str] = None
        self.storage_root = storage_root
        self.resume = task_cfg.resume

    def poll_for_tasks(self) -> List[Dict]:
        """Look for new task JSON files.

        Returns a list of TaskInfo dicts if there is a new JSON file,
        otherwise an empty list.
        """
        self._maybe_refresh_dataset()

        tasks_list = [task.to_dict() for task in self._tasks]
        random.shuffle(tasks_list)

        return tasks_list

    def _maybe_refresh_dataset_bak(self):

        # Check for a new JSON file
        latest_json = self._find_latest_json()

        if latest_json is None:
            return False  # no JSON file

        sha = self._calc_sha1(latest_json)
        if sha == self._latest_sha:
            return False  # no change

        with open(latest_json) as f:
            data = json.load(f)

        raw_tasks = [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]

        self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
        self._latest_sha = sha

        logger.info(f"Current task file: {str(latest_json)}")
        logger.info(f"Total number of tasks: {len(raw_tasks)}")

        return True

    def _maybe_refresh_dataset(self):

        latest_json = self.task_file
        print("Current tasks file: ", str(latest_json))

        with open(latest_json) as f:
            data = json.load(f)

        raw_tasks = [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]

        if self.resume:
            # Filter out tasks that are already finished or whose type does not match
            filtered_tasks = []
            storage_root = Path(self.storage_root)

            for raw in raw_tasks:
                task_id = str(raw["task_id"])
                task_type_expected = raw["task_type"]

                # Find all subdirectories whose names start with task_id (multiple versions are allowed)
                candidate_dirs = [
                    d for d in storage_root.iterdir()
                    if d.is_dir() and d.name.startswith(task_id)
                ]

                # Assume the task is unfinished by default
                task_finished = False

                for d in candidate_dirs:
                    cfg_path = d / "task_config.json"
                    if not cfg_path.exists():
                        print("Config file not found")
                        continue

                    try:
                        with cfg_path.open("r", encoding="utf-8") as cf:
                            cfg = json.load(cf)
                    except Exception:
                        print("Corrupted config, ignoring this directory")
                        continue

                    # 3.1 Different task_type => not the same task, skip this directory
                    if cfg.get("raw", {}).get("task_type") != task_type_expected:
                        continue

                    # 3.2 Same task_type, check reward.txt
                    if (d / "reward.txt").exists():
                        task_finished = True
                        break  # Completion record found, no need to inspect other directories
                if not task_finished:
                    filtered_tasks.append(raw)
            self._tasks = [build_task(raw, self.osworld_root) for raw in filtered_tasks]
            print(f"Total number of tasks: {len(raw_tasks)}, Remained: {len(filtered_tasks)}")

        else:
            self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
            print(f"Total number of tasks: {len(raw_tasks)}")

        return True

    def _find_latest_json(self) -> Optional[Path]:
        files = list(self.task_root.glob("*.json"))
        return max(files, key=lambda p: p.stat().st_mtime) if files else None

    @staticmethod
    def _calc_sha1(fp: Path, chunk_size=2 << 20) -> str:
        h = hashlib.sha1()
        with fp.open("rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()


@dataclass
class TaskInfo:
    messages: List
    instruction: str
    task_config: Dict

    def to_dict(self):
        return asdict(self)


def build_task(raw: Dict, osworld_root: Path, use_call_user: bool = False) -> TaskInfo:

    task_type = raw["task_type"]
    task_id = raw["task_id"]
    task_path = os.path.join(osworld_root, task_type, task_id + ".json")
    with open(task_path) as f:
        task_data = json.load(f)

    task_data["raw"] = {
        "task_type": task_type,
        "task_id": task_id
    }

    instruction = task_data["instruction"]

    if "human-ground-truth" in task_data and "single-action" in task_data["human-ground-truth"]:
        plan = task_data["human-ground-truth"]["single-action"]
        plan_text = "\n".join(plan)
        instruction = instruction.strip() + "\nHere is an instruction to help you complete the task: \n" + plan_text

    system_prompt = COMPUTER_USE_PROMPT if not use_call_user else COMPUTER_USE_PROMPT_WITH_CALL_USER
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": system_prompt.format(
                        instruction=instruction,
                        language="English"
                    )}
            ]
        }
    ]

    return TaskInfo(
        messages=messages,
        instruction=instruction,
        task_config=task_data
    )
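# Hedged usage sketch (config field names follow TaskLoader.__init__ above;
# the DictConfig values themselves are assumptions):
#
#     loader = TaskLoader(task_cfg, storage_root="results/")
#     for task in loader.poll_for_tasks():
#         run_one(task["instruction"], task["task_config"])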

@ -0,0 +1,511 @@
import ast
import base64
import logging
import math
import re
import xml.etree.ElementTree as ET
from io import BytesIO
from typing import Dict, List

import numpy as np
import openai

from openai import OpenAI
from PIL import Image
from requests.exceptions import SSLError
from mm_agents.dart_gui.prompts import FAILURE_INDICATORS

# Set up the logging system
logger = logging.getLogger(__name__)

FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"

IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200

pure_text_settings = ["a11y_tree"]

attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
# More namespaces defined in OSWorld, please check desktop_env/server/main.py

# Parse a single action string into its function name and keyword arguments
def parse_action(action_str):
    try:
        # Parse the string into an AST node
        node = ast.parse(action_str, mode='eval')

        # Make sure the node is an expression
        if not isinstance(node, ast.Expression):
            raise ValueError("Not an expression")

        # Get the body of the expression
        call = node.body

        # Make sure the body is a function call
        if not isinstance(call, ast.Call):
            raise ValueError("Not a function call")

        # Get the function name
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
        elif isinstance(call.func, ast.Attribute):
            func_name = call.func.attr
        else:
            func_name = None

        # Collect the keyword arguments
        kwargs = {}
        for kw in call.keywords:
            key = kw.arg
            # Handle the different value types; constants are assumed here
            if isinstance(kw.value, ast.Constant):
                value = kw.value.value
            elif isinstance(kw.value, ast.Str):  # Compatibility with older Python versions
                value = kw.value.s
            else:
                value = None
            kwargs[key] = value

        return {
            'function': func_name,
            'args': kwargs
        }

    except Exception as e:
        logger.error(f"Failed to parse action '{action_str}': {e}")
        return None
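# Hedged worked example (not part of the original file):
#     parse_action("click(start_box='(100,200)')")
# returns {'function': 'click', 'args': {'start_box': '(100,200)'}}, while a
# string that is not a single call expression returns None.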

def escape_single_quotes(text):
    # Match unescaped single quotes (do not match \\')
    pattern = r"(?<!\\)'"
    return re.sub(pattern, r"\\'", text)
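# Hedged worked example (not part of the original file): escape_single_quotes
# rewrites "it's" to "it\'s" while leaving an already-escaped \' untouched.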

def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor

def linear_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    if width * height > max_pixels:
        """
        If the image exceeds (or falls below) the pixel limits, compute a scaling
        factor `resize_factor` that shrinks the pixel count to at most max_pixels.
        The factor is taken as a square root so the aspect ratio is preserved,
        which lets the original relative coordinates be reused without conversion.
        """
        resize_factor = math.sqrt(max_pixels / (width * height))
        width, height = int(width * resize_factor), int(height * resize_factor)
    if width * height < min_pixels:
        resize_factor = math.sqrt(min_pixels / (width * height))
        width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)

    return height, width

def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
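# Hedged worked example (not part of the original file): for a 1920x1080
# screenshot with factor=28, smart_resize(1080, 1920) rounds each side to the
# nearest multiple of 28, giving (1092, 1932); the resulting pixel count
# (~2.1 MP) already sits inside [MIN_PIXELS, MAX_PIXELS], so no further
# scaling is applied.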
|
||||
|
||||
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
|
||||
text = text.strip()
|
||||
if model_type == "qwen25vl":
|
||||
smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
# 正则表达式匹配 Action 字符串
|
||||
if text.startswith("Thought:"):
|
||||
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
|
||||
thought_hint = "Thought: "
|
||||
elif text.startswith("Reflection:"):
|
||||
thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
|
||||
thought_hint = "Reflection: "
|
||||
elif text.startswith("Action_Summary:"):
|
||||
thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
|
||||
thought_hint = "Action_Summary: "
|
||||
    else:
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
        thought_hint = "Thought: "
    reflection, thought = None, None
    thought_match = re.search(thought_pattern, text, re.DOTALL)
    if thought_match:
        if len(thought_match.groups()) == 1:
            thought = thought_match.group(1).strip()
        elif len(thought_match.groups()) == 2:
            thought = thought_match.group(2).strip()
            reflection = thought_match.group(1).strip()
    assert "Action:" in text
    action_str = text.split("Action:")[-1]

    tmp_all_action = action_str.split("\n\n")
    all_action = []
    for action_str in tmp_all_action:
        if "type(content" in action_str:
            # Regex replacement that extracts the content value so its single quotes can be escaped
            def escape_quotes(match):
                content = match.group(1)  # the value of content
                return content

            # Apply the regex substitution
            pattern = r"type\(content='(.*?)'\)"  # matches type(content='...')
            content = re.sub(pattern, escape_quotes, action_str)

            # Escape the string and rebuild the action
            action_str = escape_single_quotes(content)
            action_str = "type(content='" + action_str + "')"

        if "finished(content" in action_str:
            # Regex replacement that extracts the content value so its single quotes can be escaped
            def escape_quotes(match):
                content = match.group(1)  # the value of content
                return content

            # Apply the regex substitution
            pattern = r"finished\(content='(.*?)'\)"  # matches finished(content='...')
            content = re.sub(pattern, escape_quotes, action_str)

            # Escape the string and rebuild the action
            action_str = escape_single_quotes(content)
            action_str = "finished(content='" + action_str + "')"
        all_action.append(action_str)

    parsed_actions = [parse_action(action.replace("\n", "\\n").lstrip()) for action in all_action]
    actions = []
    for action_instance, raw_str in zip(parsed_actions, all_action):
        if action_instance is None:
            logger.error(f"Action can't parse: {raw_str}")
            # raise ValueError(f"Action can't parse: {raw_str}")
            continue
        action_type = action_instance["function"]
        params = action_instance["args"]

        action_inputs = {}
        for param_name, param in params.items():
            if param == "":
                continue
            param = param.lstrip()  # strip quotes and extra whitespace
            # Handle start_box / end_box parameters in the format '<bbox>x1 y1 x2 y2</bbox>'
            action_inputs[param_name.strip()] = param

            if "start_box" in param_name or "end_box" in param_name:
                ori_box = param
                # Remove parentheses and split the string by commas
                numbers = ori_box.replace("(", "").replace(")", "").split(",")

                # Convert to float and scale by 1000
                # Qwen2.5-VL outputs absolute coordinates; Qwen2-VL outputs relative coordinates
                if model_type == "qwen25vl":
                    float_numbers = []
                    for num_idx, num in enumerate(numbers):
                        num = float(num)
                        if (num_idx + 1) % 2 == 0:
                            float_numbers.append(float(num / smart_resize_height))
                        else:
                            float_numbers.append(float(num / smart_resize_width))
                else:
                    float_numbers = [float(num) / factor for num in numbers]

                if len(float_numbers) == 2:
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
                action_inputs[param_name.strip()] = str(float_numbers)

        actions.append(
            {
                "reflection": reflection,
                "thought": thought,
                "action_type": action_type,
                "action_inputs": action_inputs,
                "text": text
            })
    return actions

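# For illustration (hypothetical values, assuming a qwen2vl-style relative model and
# factor=1000): a response "Thought: open the browser\nAction: click(start_box='(500,300)')"
# would parse to [{"reflection": None, "thought": "open the browser", "action_type": "click",
# "action_inputs": {"start_box": "[0.5, 0.3, 0.5, 0.3]"}, "text": ...}].
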
def parsing_response_to_pyautogui_code(responses, image_height: int, image_width: int, input_swap: bool = True) -> str:
    '''
    Parse the model output into OSWorld actions and generate a pyautogui code string.

    Args:
        responses: a dict (or list of dicts) with the model output, structured like:
            {
                "action_type": "hotkey",
                "action_inputs": {
                    "hotkey": "v ctrl",
                    "start_box": None,
                    "end_box": None
                }
            }
    Returns:
        The generated pyautogui code string.
    '''

    pyautogui_code = "import pyautogui\nimport time\n"
    if isinstance(responses, dict):
        responses = [responses]
    for response_id, response in enumerate(responses):
        if "observation" in response:
            observation = response["observation"]
        else:
            observation = ""

        if "thought" in response:
            thought = response["thought"]
        else:
            thought = ""

        if response_id == 0:
            pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
        else:
            pyautogui_code += "\ntime.sleep(1)\n"

        action_dict = response
        response_text = action_dict.get("text", "")
        action_type = action_dict.get("action_type")
        action_inputs = action_dict.get("action_inputs", {})

        if action_type == "hotkey":
            # Parsing hotkey action
            if "key" in action_inputs:
                hotkey = action_inputs.get("key", "")
            else:
                hotkey = action_inputs.get("hotkey", "")

            if hotkey == "arrowleft":
                hotkey = "left"
            elif hotkey == "arrowright":
                hotkey = "right"
            elif hotkey == "arrowup":
                hotkey = "up"
            elif hotkey == "arrowdown":
                hotkey = "down"

            if hotkey:
                # Handle other hotkeys
                keys = hotkey.split()  # Split the keys by space
                convert_keys = []
                for key in keys:
                    if key == "space":
                        key = ' '
                    convert_keys.append(key)
                pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"

        elif action_type == "press":
            # Parsing press action
            if "key" in action_inputs:
                key_to_press = action_inputs.get("key", "")
            else:
                key_to_press = action_inputs.get("press", "")

            # Normalize arrow/space key names
            if key_to_press == "arrowleft":
                key_to_press = "left"
            elif key_to_press == "arrowright":
                key_to_press = "right"
            elif key_to_press == "arrowup":
                key_to_press = "up"
            elif key_to_press == "arrowdown":
                key_to_press = "down"
            elif key_to_press == "space":
                key_to_press = " "

            if key_to_press:
                # Simulate pressing a single key
                pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"

        elif action_type == "keyup":
            key_to_up = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"

        elif action_type == "keydown":
            key_to_down = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"

        elif action_type == "type":
            # Parsing typing action using clipboard
            content = action_inputs.get("content", "")
            content = escape_single_quotes(content)
            stripped_content = content
            if content.endswith("\n") or content.endswith("\\n"):
                stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
            if content:
                if input_swap:
                    pyautogui_code += "\nimport pyperclip"
                    pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
                    pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')"
                    pyautogui_code += "\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += "\npyautogui.press('enter')"
                else:
                    pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
                    pyautogui_code += "\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += "\npyautogui.press('enter')"

        elif action_type in ["drag", "select"]:
            # Parsing drag or select action based on start and end boxes
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")
            if start_box and end_box:
                x1, y1, x2, y2 = eval(start_box)  # Assuming box is in [x1, y1, x2, y2]
                sx = round(float((x1 + x2) / 2) * image_width, 3)
                sy = round(float((y1 + y2) / 2) * image_height, 3)
                x1, y1, x2, y2 = eval(end_box)  # Assuming box is in [x1, y1, x2, y2]
                ex = round(float((x1 + x2) / 2) * image_width, 3)
                ey = round(float((y1 + y2) / 2) * image_height, 3)
                pyautogui_code += (
                    f"\npyautogui.moveTo({sx}, {sy})\n"
                    f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
                )

        elif action_type == "scroll":
            # Parsing scroll action
            start_box = action_inputs.get("start_box")
            if start_box:
                x1, y1, x2, y2 = eval(start_box)  # Assuming box is in [x1, y1, x2, y2]
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)

                # # Click the target region first, then scroll
                # pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
            else:
                x = None
                y = None
            direction = action_inputs.get("direction", "")

            if x is None:
                if "up" in direction.lower():
                    pyautogui_code += "\npyautogui.scroll(5)"
                elif "down" in direction.lower():
                    pyautogui_code += "\npyautogui.scroll(-5)"
            else:
                if "up" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
                elif "down" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})"

        elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
            # Parsing mouse click actions
            start_box = action_inputs.get("start_box")
            start_box = str(start_box)
            if start_box:
                start_box = eval(start_box)
                if start_box is None:
                    logger.warning(f"[Warning] start_box is None, unexpected condition:\n{action_inputs}")

                if len(start_box) == 4:
                    x1, y1, x2, y2 = start_box  # Assuming box is in [x1, y1, x2, y2]
                elif len(start_box) == 2:
                    x1, y1 = start_box
                    x2 = x1
                    y2 = y1
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)
                if action_type == "left_single" or action_type == "click":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
                elif action_type == "left_double":
                    pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
                elif action_type == "right_single":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
                elif action_type == "hover":
                    pyautogui_code += f"\npyautogui.moveTo({x}, {y})"

        elif action_type in ["finished"]:
            pyautogui_code = "DONE"
            print(f"FINISHED:response_text: {response_text}")
            print(f"FINISHED:response: {str(response)}")
            for failure_indicator in FAILURE_INDICATORS:
                if failure_indicator in response_text:
                    pyautogui_code = "FAIL"
                    break
        elif action_type in ["wait"]:
            pyautogui_code = "WAIT"

        elif action_type in ["call_user"]:
            pyautogui_code = "FAIL"
        else:
            pyautogui_code += f"\n# Unrecognized action type: {action_type}"

    return pyautogui_code

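# For illustration (hypothetical values): {"action_type": "click",
# "action_inputs": {"start_box": "[0.5, 0.3, 0.5, 0.3]"}} on a 1920x1080 screen
# yields a script ending in pyautogui.click(960.0, 324.0, button='left').
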
def add_box_token(input_string):
    # Step 1: Split the string into individual actions
    if "Action: " in input_string and "start_box=" in input_string:
        prefix = input_string.split("Action: ")[0] + "Action: "
        actions = input_string.split("Action: ")[1:]
        processed_actions = []
        for action in actions:
            action = action.strip()
            # Step 2: Extract coordinates (start_box or end_box) using regex
            coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)

            updated_action = action  # Start with the original action
            for coord_type, x, y in coordinates:
                # Wrap each coordinate pair in box tokens
                updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'")
            processed_actions.append(updated_action)

        # Step 3: Reconstruct the final string
        final_string = prefix + "\n\n".join(processed_actions)
    else:
        final_string = input_string
    # print(f"Input string: {input_string}")
    # print(f"Final string: {final_string}")
    return [{"type": "text", "text": final_string}]

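# For illustration: "Thought: ...\nAction: click(start_box='(137,244)')" becomes
# "... Action: click(start_box='<|box_start|>(137,244)<|box_end|>')", wrapping each
# coordinate pair in box tokens while leaving the rest of the text untouched.
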
def pil_to_base64(image):
    """Convert PIL Image or bytes to base64 string"""
    if isinstance(image, bytes):
        # If it's already bytes, just encode to base64
        return base64.b64encode(image).decode("utf-8")
    else:
        # If it's a PIL Image, convert it
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

@@ -0,0 +1,686 @@
"""
|
||||
Dart Agent - Custom agent for GUI automation using Dart models
|
||||
Based on UITARSAgent structure but using Dart-specific utilities and prompts
|
||||
"""
|
||||
import ast
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Any
|
||||
from PIL import Image
|
||||
from openai import OpenAI
|
||||
import backoff
|
||||
import openai
|
||||
import requests
|
||||
from requests.exceptions import SSLError
|
||||
from google.api_core.exceptions import (
|
||||
BadRequest,
|
||||
InternalServerError,
|
||||
InvalidArgument,
|
||||
ResourceExhausted,
|
||||
)
|
||||
|
||||
# Import Dart-specific utilities and prompts
|
||||
from mm_agents.dart_gui.utils import (
|
||||
pil_to_base64,
|
||||
parse_action_to_structure_output,
|
||||
parsing_response_to_pyautogui_code,
|
||||
parse_action,
|
||||
escape_single_quotes,
|
||||
round_by_factor,
|
||||
ceil_by_factor,
|
||||
floor_by_factor,
|
||||
linear_resize,
|
||||
smart_resize,
|
||||
add_box_token,
|
||||
IMAGE_FACTOR,
|
||||
MIN_PIXELS,
|
||||
MAX_PIXELS,
|
||||
MAX_RATIO,
|
||||
FINISH_WORD,
|
||||
WAIT_WORD,
|
||||
ENV_FAIL_WORD,
|
||||
CALL_USER
|
||||
)
|
||||
|
||||
from mm_agents.dart_gui.prompts import (
|
||||
COMPUTER_USE_PROMPT,
|
||||
COMPUTER_USE_PROMPT_WITH_CALL_USER,
|
||||
UITARS_ACTION_SPACE,
|
||||
UITARS_CALL_USR_ACTION_SPACE,
|
||||
UITARS_USR_PROMPT_THOUGHT,
|
||||
UITARS_USR_PROMPT_NOTHOUGHT
|
||||
)
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
class DartAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
runtime_conf: Dict,
|
||||
platform="ubuntu",
|
||||
max_tokens=1000,
|
||||
top_p=0.9,
|
||||
top_k=1.0,
|
||||
temperature=0.0,
|
||||
action_space="pyautogui",
|
||||
observation_type="screenshot",
|
||||
max_trajectory_length=50,
|
||||
model_type="qwen25vl",
|
||||
**kwargs
|
||||
):
|
||||
self.model = model
|
||||
self.platform = platform
|
||||
self.action_space = action_space
|
||||
self.observation_type = observation_type
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.model_type = model_type
|
||||
self.runtime_conf = runtime_conf
|
||||
|
||||
# Extract runtime configuration parameters
|
||||
self.max_tokens = self.runtime_conf.get("max_tokens", max_tokens)
|
||||
self.top_p = self.runtime_conf.get("top_p", top_p)
|
||||
self.top_k = self.runtime_conf.get("top_k", top_k)
|
||||
self.temperature = self.runtime_conf.get("temperature", temperature)
|
||||
self.infer_mode = self.runtime_conf.get("infer_mode", "dart_mode")
|
||||
self.prompt_style = self.runtime_conf.get("prompt_style", "dart_style")
|
||||
self.input_swap = self.runtime_conf.get("input_swap", False)
|
||||
self.language = self.runtime_conf.get("language", "English")
|
||||
self.max_pixels = self.runtime_conf.get("max_pixels", MAX_PIXELS)
|
||||
self.min_pixels = self.runtime_conf.get("min_pixels", MIN_PIXELS)
|
||||
self.history_n = self.runtime_conf.get("history_n", 5)
|
||||
|
||||
# Dart specific configurations
|
||||
self.max_images = self.runtime_conf.get("max_images", 5)
|
||||
self.max_texts = self.runtime_conf.get("max_texts", 35)
|
||||
|
||||
# Initialize OpenAI client - use Dart API if provided
|
||||
dart_api_key = self.runtime_conf.get("dart_api_key", "")
|
||||
dart_base_url = self.runtime_conf.get("dart_base_url", "")
|
||||
|
||||
if dart_base_url:
|
||||
# 检查是否为直接的生成端点(包含 /generate)
|
||||
if '/generate' in dart_base_url:
|
||||
# 直接使用提供的 URL,不添加 /v1
|
||||
logger.info(f"使用直接生成端点: {dart_base_url}")
|
||||
self.dart_direct_url = dart_base_url
|
||||
self.vlm = None # 不使用 OpenAI 客户端
|
||||
else:
|
||||
# 传统的 OpenAI 兼容端点,确保以 /v1 结尾
|
||||
if not dart_base_url.endswith('/v1'):
|
||||
dart_base_url = dart_base_url.rstrip('/') + '/v1'
|
||||
|
||||
self.vlm = OpenAI(
|
||||
base_url=dart_base_url,
|
||||
api_key=dart_api_key,
|
||||
)
|
||||
self.dart_direct_url = None
|
||||
else:
|
||||
# Fallback to environment variables
|
||||
base_url = os.environ.get('DART_API_URL', os.environ.get('DOUBAO_API_URL'))
|
||||
if base_url:
|
||||
if '/generate' in base_url:
|
||||
# 直接生成端点
|
||||
self.dart_direct_url = base_url
|
||||
self.vlm = None
|
||||
else:
|
||||
if not base_url.endswith('/v1'):
|
||||
base_url = base_url.rstrip('/') + '/v1'
|
||||
self.vlm = OpenAI(
|
||||
base_url=base_url,
|
||||
api_key=os.environ.get('DART_API_KEY', os.environ.get('DOUBAO_API_KEY')),
|
||||
)
|
||||
self.dart_direct_url = None
|
||||
else:
|
||||
self.vlm = None
|
||||
self.dart_direct_url = None
|
||||
|
||||
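        # For illustration: "http://host:8000/generate" is kept verbatim as a direct
        # endpoint, while "http://host:8000" is normalized to "http://host:8000/v1"
        # and wrapped in an OpenAI-compatible client.
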
        # Initialize trajectory storage - similar to trajectory_runner.py
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.history_images = []
        self.history_responses = []

        # Message handling similar to trajectory_runner.py
        self.base_messages = []  # for model client (with base64 images)
        self.base_messages_for_save = []  # for storage (with file paths)
        self.prompt_dialogue = []  # for model client
        self.save_dialogue = []  # for storage
        self.save_dialogue_full = []  # for full storage (keeps every image path)
        self.image_refs = []  # record image position

        # All image paths storage - to keep track of all images even when trimmed
        self.all_image_paths = []

        # Current screenshot file path for proper saving
        self.current_screenshot_path = None

        # Configure prompt and action space based on mode
        if self.infer_mode == "dart_mode":
            self.prompt_action_space = UITARS_ACTION_SPACE
            self.prompt_template = COMPUTER_USE_PROMPT
        else:
            # For qwen2vl_user mode
            self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
            if self.prompt_style == "qwen2vl_user":
                self.prompt_template = UITARS_USR_PROMPT_THOUGHT
            elif self.prompt_style == "qwen2vl_no_thought":
                self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
            else:
                self.prompt_template = UITARS_USR_PROMPT_THOUGHT

        self.action_parse_res_factor = 1000

        logger.info(f"Initialized DartAgent with model: {self.model}, mode: {self.infer_mode}")

    def reset(self, runtime_logger=None):
        """Reset the agent state"""
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.history_images = []
        self.history_responses = []

        # Reset message handling
        self.base_messages = []
        self.base_messages_for_save = []
        self.prompt_dialogue = []
        self.save_dialogue = []
        self.save_dialogue_full = []
        self.image_refs = []
        self.all_image_paths = []
        self.current_screenshot_path = None

        logger.info("DartAgent reset")

    def set_base_messages(self, instruction: str):
        """Initialize base messages similar to task_loader.py"""
        system_prompt = COMPUTER_USE_PROMPT

        self.base_messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": system_prompt.format(
                            instruction=instruction,
                            language=self.language
                        )
                    }
                ]
            }
        ]

        # Copy for save version
        from copy import deepcopy
        self.base_messages_for_save = deepcopy(self.base_messages)

    def set_current_screenshot_path(self, screenshot_path: str):
        """Set the current screenshot file path for proper saving"""
        self.current_screenshot_path = screenshot_path

    def predict(
        self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
    ) -> tuple:
        """
        Predict the next action(s) based on the current observation.
        Returns: (response_text, actions_list)
        """
        # Initialize base messages if not set
        if not self.base_messages:
            self.set_base_messages(instruction)

        # Store current observation
        self._add_observation(obs)

        # For first step, set the first frame
        if len(self.observations) == 1:
            self._set_first_frame(obs["screenshot"], self.current_screenshot_path)
        else:
            # For subsequent steps, add the new image to dialogue
            # This represents the result of the previous action
            self._add_image(obs["screenshot"], self.current_screenshot_path)

        # Build prompt messages (base_messages + prompt_dialogue)
        messages = self._build_messages()

        # Call model to get response
        prediction = self._call_model(messages)
        if prediction is None:
            return "client error", ["DONE"]

        # Store response and parse actions
        self._add_text(prediction)

        # Parse response to actions
        try:
            image_size = self._get_current_image_size()
            actions = self._parse_and_convert_actions(prediction, image_size)

            # Check for terminal actions
            terminal_action = self._check_terminal_actions(actions)
            if terminal_action:
                self.actions.append(actions)
                return prediction, [terminal_action]

        except Exception as e:
            logger.error(f"Parsing action error: {prediction}, error: {e}")
            return f"Parsing action error: {prediction}, error: {e}", ["DONE"]

        self.actions.append(actions)
        # Check max steps
        if len(self.history_responses) >= self.max_trajectory_length:
            actions = ["FAIL"]

        return prediction, actions

    @backoff.on_exception(
        backoff.constant,
        (
            # General exceptions
            SSLError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            # Google exceptions
            InvalidArgument,
            ResourceExhausted,
            InternalServerError,
            BadRequest,
        ),
        interval=30,
        max_tries=10,
    )
    def predict_with_backoff(self, instruction: str, obs: Dict, last_action_after_obs: Dict = None):
        """Predict with backoff for rate limiting and temporary errors"""
        return self.predict(instruction, obs, last_action_after_obs)

    def get_trajectory(self) -> List[Dict]:
        """Get the current trajectory for saving"""
        trajectory = []
        for i in range(len(self.observations)):
            trajectory.append({
                "observation": self.observations[i],
                "thought": self.thoughts[i] if i < len(self.thoughts) else "",
                "action": self.actions[i] if i < len(self.actions) else []
            })
        return trajectory

    def get_full_messages(self) -> List[Dict]:
        """Get the complete conversation messages for saving (including base messages and dialogue)"""
        # Combine base_messages_for_save with save_dialogue_full to get the complete conversation
        full_messages = []

        # Add base messages (system prompt and initial user message)
        full_messages.extend(self.base_messages_for_save)

        # Add dialogue messages (user images + assistant responses) with all images
        full_messages.extend(self.save_dialogue_full)

        return full_messages

    def get_all_image_paths(self) -> List[str]:
        """Get all image paths that have been used throughout the conversation"""
        return self.all_image_paths.copy()

    # ========== Private Methods ==========

    def _validate_trajectory(self):
        """Validate trajectory consistency"""
        assert len(self.observations) == len(self.actions) and len(self.actions) == len(
            self.thoughts
        ), "The number of observations and actions should be the same."

    def _add_observation(self, obs: Dict):
        """Process observation and add to history"""
        # Store observation
        if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
            base64_image = obs["screenshot"]
            try:
                # Handle accessibility tree if needed
                linearized_accessibility_tree = None
                if self.observation_type == "screenshot_a11y_tree" and "accessibility_tree" in obs:
                    # For now, we'll skip accessibility tree processing in Dart mode
                    linearized_accessibility_tree = None
            except:
                linearized_accessibility_tree = None

            if self.observation_type == "screenshot_a11y_tree":
                self.observations.append({
                    "screenshot": base64_image,
                    "accessibility_tree": linearized_accessibility_tree,
                })
            else:
                self.observations.append({
                    "screenshot": base64_image,
                    "accessibility_tree": None
                })
        else:
            raise ValueError("Invalid observation_type: " + self.observation_type)

    def _build_messages(self) -> List[Dict]:
        """Build messages for model API call - similar to trajectory_runner._build_messages"""
        return self.base_messages + self.prompt_dialogue

    def _call_model(self, messages: List[Dict]) -> str:
        """Call model with retry logic"""
        try_times = 3
        while try_times > 0:
            try:
                # If a direct generation endpoint is configured, use it
                if hasattr(self, 'dart_direct_url') and self.dart_direct_url:
                    prediction = self._call_direct_generate_endpoint(messages)
                else:
                    # Use the standard OpenAI client
                    response = self.vlm.chat.completions.create(
                        model=self.model,
                        messages=messages,
                        frequency_penalty=1,
                        max_tokens=self.max_tokens,
                        temperature=self.temperature,
                        top_p=self.top_p
                    )
                    prediction = response.choices[0].message.content

                logger.info(f"Model response: {prediction}")
                return prediction

            except Exception as e:
                logger.error(f"Error when fetching response from client: {e}")
                try_times -= 1
                if try_times <= 0:
                    logger.error("Reached max retry times to fetch response from client")
                    return None
        return None

    def _call_direct_generate_endpoint(self, messages: List[Dict]) -> str:
        """Call the generation endpoint directly."""
        try:
            # Build the request payload
            payload = {
                "messages": messages,
                "model": self.model,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "frequency_penalty": 1
            }

            # Add the API key to the headers
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.runtime_conf.get('dart_api_key', '')}"
            }

            # Retry logic: up to 3 attempts with a 60-second timeout per request
            max_retries = 3
            response = None

            for attempt in range(max_retries):
                try:
                    logger.info(f"Request attempt {attempt + 1}...")
                    response = requests.post(
                        self.dart_direct_url,
                        json=payload,
                        headers=headers,
                        timeout=60
                    )
                    response.raise_for_status()
                    break  # success, exit the retry loop
                except Exception as e:
                    logger.warning(f"Request attempt {attempt + 1} failed: {e}")
                    if attempt == max_retries - 1:  # the last attempt failed
                        logger.error(f"All {max_retries} attempts failed")
                        raise e
                    else:
                        logger.info("Waiting before retrying...")
                        time.sleep(2)  # wait 2 seconds before retrying

            # Parse the response
            result = response.json()

            # Try several possible response formats
            if 'choices' in result and len(result['choices']) > 0:
                # OpenAI-compatible format
                return result['choices'][0]['message']['content']
            elif 'response' in result:
                # plain "response" field
                return result['response']
            elif 'text' in result:
                # "text" field
                return result['text']
            elif 'content' in result:
                # "content" field
                return result['content']
            else:
                # If no standard field is found, return the whole response as a string
                logger.warning(f"Unknown response format: {result}")
                return str(result)

        except Exception as e:
            logger.error(f"Direct endpoint call failed: {e}")
            raise e

    def _add_text(self, assistant_txt: str):
        """Add text response to history - similar to trajectory_runner.py"""
        self.history_responses.append(assistant_txt)
        self.thoughts.append(assistant_txt)

        # Add to dialogue similar to trajectory_runner._add_text
        msg = {
            "role": "assistant",
            "content": add_box_token(assistant_txt)
        }
        self.prompt_dialogue.append(msg)
        self.save_dialogue.append(msg)
        self.save_dialogue_full.append(msg)
        self._trim()

    def _set_first_frame(self, obs_img: bytes, frame_path: str = None):
        """Set first frame in base_messages - similar to trajectory_runner._set_first_frame"""
        self.base_messages[1]["content"].append(
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64," + pil_to_base64(obs_img)}
            }
        )

        # Use actual frame path if provided, otherwise use current_screenshot_path or a placeholder
        if frame_path:
            first_frame_path = frame_path
        elif self.current_screenshot_path:
            first_frame_path = self.current_screenshot_path
        else:
            first_frame_path = "first_frame.png"

        # Store in all_image_paths
        self.all_image_paths.append(first_frame_path)

        self.base_messages_for_save[1]["content"].append(
            {
                "type": "image_url",
                "image_url": first_frame_path
            }
        )

        self.image_refs.append(
            {"source": "base", "msg_idx": 1,
             "content_idx": len(self.base_messages[1]["content"]) - 1}
        )

    def _add_image(self, img_bytes: bytes, frame_path: str = None):
        """Add image to dialogue - similar to trajectory_runner._add_image"""
        self.prompt_dialogue.append({
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64," + pil_to_base64(img_bytes)}
            }]
        })

        # Use actual frame path if provided, otherwise use current_screenshot_path
        if frame_path:
            image_url = frame_path
        elif self.current_screenshot_path:
            image_url = self.current_screenshot_path
        else:
            # Fallback to a placeholder - this should rarely happen in practice
            image_url = f"frame_{len(self.save_dialogue)}.png"

        # Store in all_image_paths for a complete record
        self.all_image_paths.append(image_url)

        # Add to save_dialogue (trimmed version)
        self.save_dialogue.append({
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": image_url
            }]
        })

        # Add to save_dialogue_full (complete version - never trimmed)
        self.save_dialogue_full.append({
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": image_url
            }]
        })

        self.image_refs.append(
            {"source": "dialogue", "msg_idx": len(self.prompt_dialogue) - 1,
             "content_idx": None}
        )

        self._trim()

    def _trim(self):
        """Ensure image num ≤ max_images and assistant text num ≤ max_texts - similar to trajectory_runner._trim"""
        img_cnt = len(self.image_refs)
        txt_cnt = sum(m["role"] == "assistant" for m in self.prompt_dialogue)

        while img_cnt > self.max_images or txt_cnt > self.max_texts:
            # Too many images: drop the earliest one
            if img_cnt > self.max_images:
                ref = self.image_refs.pop(0)
                if ref["source"] == "base":
                    self.base_messages[ref["msg_idx"]]["content"].pop(ref["content_idx"])
                else:  # dialogue image
                    self._remove_dialogue_msg(ref["msg_idx"])
                img_cnt -= 1
                continue

            # Too many texts: drop the earliest assistant message
            if txt_cnt > self.max_texts:
                for i, m in enumerate(self.prompt_dialogue):
                    if m["role"] == "assistant":
                        self._remove_dialogue_msg(i)
                        txt_cnt -= 1
                        break

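    # For illustration: with max_images=5, adding a 6th screenshot pops the oldest
    # image_refs entry, so the earliest frame is dropped from the prompt while
    # save_dialogue_full keeps every frame for the saved trajectory.
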
    def _remove_dialogue_msg(self, idx: int):
        """Remove dialogue message and update refs - similar to trajectory_runner._remove_dialogue_msg"""
        self.prompt_dialogue.pop(idx)
        self.save_dialogue.pop(idx)
        # Note: save_dialogue_full is never trimmed, so we don't remove from it

        # Update image_refs
        self.image_refs = [
            r if not (r["source"] == "dialogue" and r["msg_idx"] == idx)
            else None  # drop refs that point at the removed message
            for r in self.image_refs
        ]
        self.image_refs = [
            (
                {**r, "msg_idx": r["msg_idx"] - 1}
                if r and r["source"] == "dialogue" and r["msg_idx"] > idx  # shift indices after idx down by one
                else r
            )
            for r in self.image_refs
            if r  # filter out the None entries
        ]

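    # For illustration: removing the message at idx=2 discards the ref with
    # msg_idx 2 and shifts a ref with msg_idx 4 down to 3, so the remaining refs
    # stay aligned with the shortened prompt_dialogue.
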
    def _get_current_image_size(self) -> tuple:
        """Get current image size for coordinate conversion"""
        if len(self.observations) > 0:
            try:
                current_image_bytes = self.observations[-1]["screenshot"]
                if isinstance(current_image_bytes, bytes):
                    current_image = Image.open(BytesIO(current_image_bytes))
                    return (current_image.height, current_image.width)
            except Exception as e:
                logger.warning(f"Error getting image size: {e}")

        # Fallback to default screen size
        return (1080, 1920)

    def _parse_and_convert_actions(self, prediction: str, image_size: tuple) -> List[str]:
        """Parse response and convert to pyautogui actions - similar to trajectory_runner._parse"""
        image_height, image_width = image_size

        # Parse the response to structured actions
        parsed_responses = parse_action_to_structure_output(
            prediction,
            factor=self.action_parse_res_factor,
            origin_resized_height=image_height,
            origin_resized_width=image_width,
            model_type=self.model_type,
            max_pixels=self.max_pixels,
            min_pixels=self.min_pixels
        )

        # Convert parsed responses to pyautogui actions
        actions = []
        for parsed_response in parsed_responses:
            try:
                pyautogui_code = parsing_response_to_pyautogui_code(
                    parsed_response,
                    image_height=image_height,
                    image_width=image_width,
                    input_swap=self.input_swap
                )

                actions.append(pyautogui_code)

            except Exception as e:
                logger.error(f"Error generating pyautogui code: {e}")
                actions.append("FAIL")

        return actions

    def _check_terminal_actions(self, actions: List[str]) -> str:
        """Check if any action is terminal and return the appropriate code"""
        for action in actions:
            if isinstance(action, dict) and "action_type" in action:
                action_type = action["action_type"]
                if action_type == FINISH_WORD:
                    return "DONE"
                elif action_type == WAIT_WORD:
                    return "WAIT"
                elif action_type == ENV_FAIL_WORD:
                    return "FAIL"
                elif action_type == CALL_USER:
                    return "FAIL"
        return None

@@ -0,0 +1,653 @@
import os
import re
import json
import logging
import backoff
import openai
from typing import Dict, List, Tuple, Optional

from io import BytesIO
from PIL import Image

from mm_agents.evocua.utils import (
    process_image,
    encode_image,
    rewrite_pyautogui_text_inputs,
    project_coordinate_to_absolute_scale,
    log_messages
)

from mm_agents.evocua.prompts import (
    S1_SYSTEM_PROMPT,
    S1_INSTRUTION_TEMPLATE,
    S1_STEP_TEMPLATE,
    S1_ACTION_HISTORY_TEMPLATE,
    S2_ACTION_DESCRIPTION,
    S2_DESCRIPTION_PROMPT_TEMPLATE,
    S2_SYSTEM_PROMPT,
    build_s2_tools_def
)

logger = logging.getLogger("desktopenv.evocua")

class EvoCUAAgent:
    """
    EvoCUA - A native GUI agent model for desktop automation.
    """

    def __init__(
        self,
        model: str = "EvoCUA-S2",
        max_tokens: int = 32768,
        top_p: float = 0.9,
        temperature: float = 0.0,
        action_space: str = "pyautogui",
        observation_type: str = "screenshot",
        max_steps: int = 50,
        prompt_style: str = "S2",  # "S1" or "S2"
        max_history_turns: int = 4,
        screen_size: Tuple[int, int] = (1920, 1080),
        coordinate_type: str = "relative",
        password: str = "osworld-public-evaluation",
        resize_factor: int = 32,
        **kwargs
    ):
        self.model = model
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_steps = max_steps

        self.prompt_style = prompt_style
        assert self.prompt_style in ["S1", "S2"], f"Invalid prompt_style: {self.prompt_style}"

        self.max_history_turns = max_history_turns

        self.screen_size = screen_size
        self.coordinate_type = coordinate_type
        self.password = password
        self.resize_factor = resize_factor

        # Action space assertion
        assert self.action_space == "pyautogui", f"Invalid action space: {self.action_space}"
        assert self.observation_type == "screenshot", f"Invalid observation type: {self.observation_type}"

        # State
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.responses = []
        self.screenshots = []  # Stores encoded strings
        self.cots = []  # For S1 style history

    def reset(self, _logger=None, vm_ip=None):
        global logger
        if _logger:
            logger = _logger

        self.thoughts = []
        self.actions = []
        self.observations = []
        self.responses = []
        self.screenshots = []
        self.cots = []

    def predict(self, instruction: str, obs: Dict) -> Tuple:
        """
        Main prediction loop.
        """

        logger.info(f"========================== {self.model} ===================================")
        logger.info(f"Instruction: \n{instruction}")

        screenshot_bytes = obs["screenshot"]

        try:
            original_img = Image.open(BytesIO(screenshot_bytes))
            original_width, original_height = original_img.size
        except Exception as e:
            logger.warning(f"Failed to read screenshot size, falling back to screen_size: {e}")
            original_width, original_height = self.screen_size

        if self.prompt_style == "S1":
            raw_b64 = encode_image(screenshot_bytes)
            self.screenshots.append(raw_b64)
            return self._predict_s1(instruction, obs, raw_b64)
        else:
            processed_b64, p_width, p_height = process_image(screenshot_bytes, factor=self.resize_factor)
            self.screenshots.append(processed_b64)
            return self._predict_s2(
                instruction,
                obs,
                processed_b64,
                p_width,
                p_height,
                original_width,
                original_height,
            )

    def _predict_s2(self, instruction, obs, processed_b64, p_width, p_height, original_width, original_height):
        current_step = len(self.actions)
        current_history_n = self.max_history_turns

        response = None

        if self.coordinate_type == "absolute":
            resolution_info = f"* The screen's resolution is {p_width}x{p_height}."
        else:
            resolution_info = "* The screen's resolution is 1000x1000."

        description_prompt = S2_DESCRIPTION_PROMPT_TEMPLATE.format(resolution_info=resolution_info)

        tools_def = build_s2_tools_def(description_prompt)

        system_prompt = S2_SYSTEM_PROMPT.format(tools_xml=json.dumps(tools_def))

        # Retry loop for context length
        while True:
            messages = self._build_s2_messages(
                instruction,
                processed_b64,
                current_step,
                current_history_n,
                system_prompt
            )

            try:
                response = self.call_llm({
                    "model": self.model,
                    "messages": messages,
                    "max_tokens": self.max_tokens,
                    "top_p": self.top_p,
                    "temperature": self.temperature,
                })
                break
            except Exception as e:
                # Handle "context too large" by shrinking the history window
                if self._should_giveup_on_context_error(e) and current_history_n > 0:
                    current_history_n -= 1
                    logger.warning(f"Context too large, retrying with history_n={current_history_n}")
                else:
                    logger.error(f"Error in predict: {e}")
                    break

        self.responses.append(response)

        low_level_instruction, pyautogui_code = self._parse_response_s2(
            response, p_width, p_height, original_width, original_height
        )

        # Enforce the step budget
        current_step = len(self.actions) + 1
        first_action = pyautogui_code[0] if pyautogui_code else ""
        if current_step >= self.max_steps and str(first_action).upper() not in ("DONE", "FAIL"):
            logger.warning(f"Reached maximum steps {self.max_steps}. Forcing termination with FAIL.")
            low_level_instruction = "Fail the task because reaching the maximum step limit."
            pyautogui_code = ["FAIL"]

        logger.info(f"Low level instruction: {low_level_instruction}")
        logger.info(f"Pyautogui code: {pyautogui_code}")

        self.actions.append(low_level_instruction)
        return response, pyautogui_code

    def _build_s2_messages(self, instruction, current_img, step, history_n, system_prompt):
        messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]

        previous_actions = []
        history_start_idx = max(0, step - history_n)
        for i in range(history_start_idx):
            if i < len(self.actions):
                previous_actions.append(f"Step {i+1}: {self.actions[i]}")
        previous_actions_str = "\n".join(previous_actions) if previous_actions else "None"

        # Add history
        history_len = min(history_n, len(self.responses))
        if history_len > 0:
            hist_responses = self.responses[-history_len:]
            hist_imgs = self.screenshots[-history_len-1:-1]

            for i in range(history_len):
                if i < len(hist_imgs):
                    screenshot_b64 = hist_imgs[i]
                    if i == 0:
                        # First history item: inject instruction + previous actions context
                        img_url = f"data:image/png;base64,{screenshot_b64}"
                        instruction_prompt = f"""
Please generate the next move according to the UI screenshot, instruction and previous actions.

Instruction: {instruction}

Previous actions:
{previous_actions_str}"""
                        messages.append({
                            "role": "user",
                            "content": [
                                {"type": "image_url", "image_url": {"url": img_url}},
                                {"type": "text", "text": instruction_prompt}
                            ]
                        })
                    else:
                        img_url = f"data:image/png;base64,{screenshot_b64}"
                        messages.append({
                            "role": "user",
                            "content": [
                                {"type": "image_url", "image_url": {"url": img_url}},
                            ]
                        })

                messages.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": hist_responses[i]}]
                })

        # Current turn
        # We reuse the previous_actions_str logic for the case where history_len == 0

        if history_len == 0:
            # First turn logic: include instruction + previous actions
            instruction_prompt = f"""
Please generate the next move according to the UI screenshot, instruction and previous actions.

Instruction: {instruction}

Previous actions:
{previous_actions_str}"""
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{current_img}"}},
                    {"type": "text", "text": instruction_prompt}
                ]
            })
        else:
            # Subsequent turns (context already in the first history message): image only
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{current_img}"}}
                ]
            })

        return messages

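    # Sketch of the resulting layout (illustrative): [system] then, per history turn,
    # a user message (screenshot, plus instruction text on the first turn only) and the
    # assistant response, ending with a user message carrying the current screenshot.
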
    def _parse_response_s2(
        self,
        response: str,
        processed_width: int = None,
        processed_height: int = None,
        original_width: Optional[int] = None,
        original_height: Optional[int] = None,
    ) -> Tuple[str, List[str]]:
        """
        Parse the LLM response and convert it to a low-level action and pyautogui code.
        """
        # Prefer the real screenshot resolution (passed from predict), fall back to the configured screen_size.
        if not (original_width and original_height):
            original_width, original_height = self.screen_size
        low_level_instruction = ""
        pyautogui_code: List[str] = []

        if response is None or not response.strip():
            return low_level_instruction, pyautogui_code

        def adjust_coordinates(x: float, y: float) -> Tuple[int, int]:
            if not (original_width and original_height):
                return int(x), int(y)
            if self.coordinate_type == "absolute":
                # scale from processed pixels to the original resolution
                if processed_width and processed_height:
                    x_scale = original_width / processed_width
                    y_scale = original_height / processed_height
                    return int(x * x_scale), int(y * y_scale)
                return int(x), int(y)
            # relative: scale from the 0..999 grid
            x_scale = original_width / 999
            y_scale = original_height / 999
            return int(x * x_scale), int(y * y_scale)

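        # Worked example (illustrative): in relative mode on a 1920x1080 screen,
        # adjust_coordinates(500, 500) maps the 0..999 grid point to
        # (int(500 * 1920 / 999), int(500 * 1080 / 999)) = (960, 540).
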
        def process_tool_call(json_str: str) -> None:
            try:
                tool_call = json.loads(json_str)
                if tool_call.get("name") == "computer_use":
                    args = tool_call["arguments"]
                    action = args["action"]

                    def _clean_keys(raw_keys):
                        keys = raw_keys if isinstance(raw_keys, list) else [raw_keys]
                        cleaned_keys = []
                        for key in keys:
                            if isinstance(key, str):
                                if key.startswith("keys=["):
                                    key = key[6:]
                                if key.endswith("]"):
                                    key = key[:-1]
                                if key.startswith("['") or key.startswith('["'):
                                    key = key[2:] if len(key) > 2 else key
                                if key.endswith("']") or key.endswith('"]'):
                                    key = key[:-2] if len(key) > 2 else key
                                key = key.strip()
                                cleaned_keys.append(key)
                            else:
                                cleaned_keys.append(key)
                        return cleaned_keys

                    if action == "left_click" or action == "click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.click({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.click()")

                    elif action == "right_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.rightClick({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.rightClick()")

                    elif action == "middle_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.middleClick({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.middleClick()")

                    elif action == "double_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.doubleClick({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.doubleClick()")

                    elif action == "triple_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.tripleClick({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.tripleClick()")

                    elif action == "type":
                        text = args.get("text", "")

                        try:
                            text = text.encode('latin-1', 'backslashreplace').decode('unicode_escape')
                        except Exception as e:
                            logger.error(f"Failed to unescape text: {e}")

                        logger.info(f"Pyautogui code[before rewrite]: {text}")

                        result = ""
                        for char in text:
                            if char == '\n':
                                result += "pyautogui.press('enter')\n"
                            elif char == "'":
                                result += 'pyautogui.press("\'")\n'
                            elif char == '\\':
                                result += "pyautogui.press('\\\\')\n"
                            elif char == '"':
                                result += "pyautogui.press('\"')\n"
                            else:
                                result += f"pyautogui.press('{char}')\n"

                        pyautogui_code.append(result)
                        logger.info(f"Pyautogui code[after rewrite]: {pyautogui_code}")

                    elif action == "key":
                        keys = _clean_keys(args.get("keys", []))

                        keys_str = ", ".join([f"'{key}'" for key in keys])
                        if len(keys) > 1:
                            pyautogui_code.append(f"pyautogui.hotkey({keys_str})")
                        else:
                            pyautogui_code.append(f"pyautogui.press({keys_str})")

                    elif action == "key_down":
                        keys = _clean_keys(args.get("keys", []))
                        for k in keys:
                            pyautogui_code.append(f"pyautogui.keyDown('{k}')")

                    elif action == "key_up":
                        keys = _clean_keys(args.get("keys", []))
                        for k in reversed(keys):
                            pyautogui_code.append(f"pyautogui.keyUp('{k}')")

                    elif action == "scroll":
                        pixels = args.get("pixels", 0)
                        pyautogui_code.append(f"pyautogui.scroll({pixels})")

                    elif action == "wait":
                        pyautogui_code.append("WAIT")

                    elif action == "terminate":
                        # Termination should respect status:
                        # - success -> DONE
                        # - failure -> FAIL
                        # Backward compatible: a missing status defaults to success.
                        status = args.get("status", "success")
                        if str(status).lower() == "failure":
                            pyautogui_code.append("FAIL")
                        else:
                            pyautogui_code.append("DONE")

                    elif action == "mouse_move":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.moveTo({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.moveTo(0, 0)")

                    elif action == "left_click_drag":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            duration = args.get("duration", 0.5)
                            pyautogui_code.append(
                                f"pyautogui.dragTo({adj_x}, {adj_y}, duration={duration})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.dragTo(0, 0)")
            except (json.JSONDecodeError, KeyError) as e:
                logger.error(f"Failed to parse tool call: {e}")

        lines = response.split("\n")
        inside_tool_call = False
        current_tool_call: List[str] = []

        for line in lines:
            line = line.strip()
            if not line:
                continue

            if line.lower().startswith("action:"):
                if not low_level_instruction:
                    low_level_instruction = line.split("Action:")[-1].strip()
                continue

            if line.startswith("<tool_call>"):
                inside_tool_call = True
                continue
            elif line.startswith("</tool_call>"):
                if current_tool_call:
                    process_tool_call("\n".join(current_tool_call))
                    current_tool_call = []
                inside_tool_call = False
                continue

            if inside_tool_call:
                current_tool_call.append(line)
                continue

            if line.startswith("{") and line.endswith("}"):
                try:
                    json_obj = json.loads(line)
                    if "name" in json_obj and "arguments" in json_obj:
                        process_tool_call(line)
                except json.JSONDecodeError:
                    pass

        if current_tool_call:
            process_tool_call("\n".join(current_tool_call))

        if not low_level_instruction and len(pyautogui_code) > 0:
            first_action = pyautogui_code[0]
            if "." in first_action:
                action_type = first_action.split(".", 1)[1].split("(", 1)[0]
            else:
                action_type = first_action.lower()
            low_level_instruction = f"Performing {action_type} action"

        return low_level_instruction, pyautogui_code

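    # For illustration (hypothetical payload): a response containing
    # <tool_call>{"name": "computer_use", "arguments": {"action": "left_click",
    # "coordinate": [500, 500]}}</tool_call> yields ["pyautogui.click(960, 540)"]
    # on a 1920x1080 screen in relative mode.
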
    def _predict_s1(self, instruction, obs, processed_b64):
        messages = [{"role": "system", "content": S1_SYSTEM_PROMPT.format(password=self.password)}]

        # Reconstruct history logic for S1 mode
        history_step_texts = []

        for i in range(len(self.actions)):
            cot = self.cots[i] if i < len(self.cots) else {}

            # Step content string
            step_content = S1_STEP_TEMPLATE.format(step_num=i+1) + S1_ACTION_HISTORY_TEMPLATE.format(action=cot.get('action', ''))

            if i > len(self.actions) - self.max_history_turns:
                # Recent history: add user (image) and assistant (text) messages
                if i < len(self.screenshots) - 1:  # A screenshot exists for this step
                    img = self.screenshots[i]
                    messages.append({
                        "role": "user",
                        "content": [
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}
                        ]
                    })
                messages.append({"role": "assistant", "content": step_content})
            else:
                # Old history: collect text
                history_step_texts.append(step_content)
                # If this is the last step before the recent window, flush the collected texts
                if i == len(self.actions) - self.max_history_turns:
                    messages.append({
                        "role": "assistant",
                        "content": "\n".join(history_step_texts)
                    })

        # Current turn
        messages.append({
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{processed_b64}"}},
                {"type": "text", "text": S1_INSTRUTION_TEMPLATE.format(instruction=instruction)}
            ]
        })

        response = self.call_llm({
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens
        })

        low_level, codes, cot_data = self._parse_response_s1(response)

        self.observations.append(obs)
        self.cots.append(cot_data)
        self.actions.append(low_level)
        self.responses.append(response)

        return response, codes

    def _parse_response_s1(self, response):
        sections = {}
        # Simple regex parsing of the Observation/Thought/Action sections
        for key, pattern in [
            ('observation', r'#{1,2}\s*Observation\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)'),
            ('thought', r'#{1,2}\s*Thought\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)'),
            ('action', r'#{1,2}\s*Action\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)')
        ]:
            m = re.search(pattern, response, re.DOTALL | re.MULTILINE)
            if m:
                sections[key] = m.group(1).strip()

        code_blocks = re.findall(r'```(?:code|python)?\s*(.*?)\s*```', response, re.DOTALL | re.IGNORECASE)
        code = code_blocks[-1].strip() if code_blocks else "FAIL"

        sections['code'] = code

        # Post-process code
        if "computer.terminate" in code:
            final_code = ["DONE"] if "success" in code.lower() else ["FAIL"]
        elif "computer.wait" in code:
            final_code = ["WAIT"]
        else:
            # Project coordinates to the absolute screen scale
            code = project_coordinate_to_absolute_scale(
                code,
                self.screen_size[0],
                self.screen_size[1],
                self.coordinate_type,
                self.resize_factor
            )
            logger.info(f"[rewrite before]: {code}")
            final_code = [rewrite_pyautogui_text_inputs(code)]
            logger.info(f"[rewrite after]: {final_code}")

        return sections.get('action', 'Acting'), final_code, sections

    @staticmethod
    def _should_giveup_on_context_error(e):
        """For context-length related errors, give up retrying immediately and let the caller handle it."""
        error_str = str(e)
        return "Too Large" in error_str or "context_length_exceeded" in error_str or "413" in error_str

    @backoff.on_exception(backoff.constant, Exception, interval=30, max_tries=10, giveup=_should_giveup_on_context_error.__func__)
    def call_llm(self, payload):
        """Unified OpenAI-compatible API call"""
        # Get env vars
        base_url = os.environ.get("OPENAI_BASE_URL", "url-xxx")
        api_key = os.environ.get("OPENAI_API_KEY", "sk-xxx")

        client = openai.OpenAI(base_url=base_url, api_key=api_key)

        messages = payload["messages"]
        log_messages(messages, "LLM Request")

        params = {
            "model": payload["model"],
            "messages": messages,
            "max_tokens": payload["max_tokens"],
            "temperature": self.temperature,
            "top_p": self.top_p
        }

        try:
            resp = client.chat.completions.create(**params)
            content = resp.choices[0].message.content
            logger.info(f"LLM Response:\n{content}")
            return content
        except Exception as e:
            logger.error(f"LLM Call failed: {e}")
            raise e

@@ -0,0 +1,148 @@
S1_SYSTEM_PROMPT = """You are a GUI agent. You are given a task, a screenshot of the screen and your previous interactions with the computer. You need to perform a series of actions to complete the task. The password of the computer is "{password}", use it when you need sudo rights. You need to **wait** explicitly for installation, waiting website loading or running commands to finish. Don't terminate the task unless you are sure the task is finished. If you find that you can't finish the task, or the task is not finished exactly as the instruction indicates (you have made progress but not finished the task completely), or the task is impossible to complete, you must report **failure**.
|
||||
|
||||
For each step, provide your response in this format:
|
||||
# Step: {{step number}}
|
||||
## Thought:
|
||||
{{thought}}
|
||||
## Action:
|
||||
{{action}}
|
||||
## Code:
|
||||
{{code}}
|
||||
|
||||
For the Thought section, you should include the following parts:
|
||||
- Reflection on the task when there is previous action:
|
||||
- Consider the correnctness of previous action and its outcomes
|
||||
- If the previous action was correct, describe the change in the state of the computer and reason
|
||||
- If the previous action was incorrect, reflect on what went wrong and why
|
||||
- Step by Step Progress Assessment:
|
||||
- Add necessary information according to the history screenshots, former actions and current screenshot.
|
||||
- Analyze what parts of the task have already been completed and how they contribute to the overall goal.
|
||||
- Make a plan on how to complete the task based on the history and currect screenshot.
|
||||
- Next Action Prediction:
|
||||
- Propose the most possible next action and state the reason
|
||||
- For Text Input Actions:
|
||||
- Note current cursor position
|
||||
- Consolidate repetitive actions (specify count for multiple keypresses)
|
||||
- Describe expected final text outcome
|
||||
- Use first-person perspective in reasoning
|
||||
|
||||
For the action section, you should provide clear, concise, and actionable instructions in one sentence.
|
||||
- If the action involves interacting with a specific target:
|
||||
- Describe target explicitly (if multiple elements share that name, you should distinguish the target) without using coordinates
|
||||
- Specify element names when possible (use original language if non-English)
|
||||
- Describe features (shape, color, position) if name unavailable
|
||||
- If the action involves keyboard actions like 'press', 'write', 'hotkey':
|
||||
- Consolidate repetitive keypresses with count
|
||||
- Specify expected text outcome for typing actions
|
||||
|
||||
For the code section, you should output the corresponding code for the action. The code should be either PyAutoGUI code or one of the following functions warped in the code block:
|
||||
- {{"name": "computer.wait", "description": "Make the computer wait for 20 seconds for installation, running code, etc.", "parameters": {{"type": "object", "properties": {{}}, "required": []}}}}
|
||||
- {{"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {{"type": "object", "properties": {{"status": {{"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}}, {{"answer": {{"type": "string", "description": "The answer of the task"}}}}, "required": ["status"]}}}}
|
||||
Examples for the code section:
|
||||
```python
|
||||
pyautogui.click(x=123, y=456)
|
||||
```
|
||||
```code
|
||||
computer.terminate(status="success")
|
||||
```
|
||||
```code
|
||||
computer.terminate(status="success", answer='''text''')
|
||||
```"""


# S1 prompt templates for generating trajectories
S1_STEP_TEMPLATE = "# Step {step_num}:\n"
S1_INSTRUTION_TEMPLATE = "# Task Instruction:\n{instruction}\n\nPlease generate the next move according to the screenshot, task instruction and previous steps (if provided).\n"

S1_ACTION_HISTORY_TEMPLATE = "## Action:\n{action}\n"


# S2 Prompts
S2_ACTION_DESCRIPTION = """
* `key`: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.
* `key_down`: Press and HOLD the specified key(s) down in order (no release). Use this for stateful holds like holding Shift while clicking.
* `key_up`: Release the specified key(s) in reverse order.
* `type`: Type a string of text on the keyboard.
* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen.
* `left_click`: Click the left mouse button at a specified (x, y) pixel coordinate on the screen.
* `left_click_drag`: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.
* `right_click`: Click the right mouse button at a specified (x, y) pixel coordinate on the screen.
* `middle_click`: Click the middle mouse button at a specified (x, y) pixel coordinate on the screen.
* `double_click`: Double-click the left mouse button at a specified (x, y) pixel coordinate on the screen.
* `triple_click`: Triple-click the left mouse button at a specified (x, y) pixel coordinate on the screen.
* `scroll`: Performs a scroll of the mouse scroll wheel.
* `hscroll`: Performs a horizontal scroll (mapped to regular scroll).
* `wait`: Wait specified seconds for the change to happen.
* `terminate`: Terminate the current task and report its completion status.
* `answer`: Answer a question.
"""

S2_DESCRIPTION_PROMPT_TEMPLATE = """Use a mouse and keyboard to interact with a computer, and take screenshots.
* This is an interface to a desktop GUI. You must click on desktop icons to start applications.
* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try waiting and taking another screenshot.
{resolution_info}
* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.
* If you tried clicking on a program or link but it failed to load even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.
* Make sure to click any buttons, links, icons, etc. with the cursor tip in the center of the element. Don't click boxes on their edges unless asked."""

S2_SYSTEM_PROMPT = """# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tools_xml}
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>

# Response format

Response format for every step:
1) Action: a short imperative describing what to do in the UI.
2) A single <tool_call>...</tool_call> block containing only the JSON: {{"name": <function-name>, "arguments": <args-json-object>}}.

Rules:
- Output exactly in the order: Action, <tool_call>.
- Be brief: one sentence for Action.
- Do not output anything else outside those parts.
- If finishing, use action=terminate in the tool call."""
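
# A minimal example of a response that satisfies the format above (illustrative
# values only; the coordinates and element name are made up):
#
#   Action: Click the Firefox icon on the desktop.
#   <tool_call>
#   {"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [123, 456]}}
#   </tool_call>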


def build_s2_tools_def(description_prompt):
    return {
        "type": "function",
        "function": {
            "name_for_human": "computer_use",
            "name": "computer_use",
            "description": description_prompt,
            "parameters": {
                "properties": {
                    "action": {
                        "description": S2_ACTION_DESCRIPTION,
                        "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag",
                                 "right_click", "middle_click", "double_click", "triple_click", "scroll",
                                 "wait", "terminate", "key_down", "key_up"],
                        "type": "string"
                    },
                    "keys": {"description": "Required only by `action=key`.", "type": "array"},
                    "text": {"description": "Required only by `action=type`.", "type": "string"},
                    "coordinate": {"description": "The x,y coordinates for mouse actions.", "type": "array"},
                    "pixels": {"description": "The amount of scrolling.", "type": "number"},
                    "time": {"description": "The seconds to wait.", "type": "number"},
                    "status": {
                        "description": "The status of the task.",
                        "type": "string",
                        "enum": ["success", "failure"]
                    }
                },
                "required": ["action"],
                "type": "object"
            },
            "args_format": "Format the arguments as a JSON object."
        }
    }
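
# Usage sketch (hypothetical wiring, not part of this module): the tool definition
# is typically serialized into the <tools> block of S2_SYSTEM_PROMPT. The
# resolution_info value below is an assumed example.
#
#   import json
#   description = S2_DESCRIPTION_PROMPT_TEMPLATE.format(resolution_info="* The screen resolution is 1920x1080.")
#   system_prompt = S2_SYSTEM_PROMPT.format(tools_xml=json.dumps(build_s2_tools_def(description)))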
@@ -0,0 +1,302 @@
import base64
import re
import ast
import logging
from io import BytesIO
import json
from PIL import Image

from mm_agents.utils.qwen_vl_utils import smart_resize

logger = logging.getLogger("desktopenv.evocua.utils")

def encode_image(image_content):
    """Encode image bytes to base64 string."""
    return base64.b64encode(image_content).decode("utf-8")


def process_image(image_bytes, factor=32):
    """
    Process an image for VL models.
    factor: 32 for S2 mode, 28 for S1 mode (default is 32)
    """
    image = Image.open(BytesIO(image_bytes))
    width, height = image.size

    resized_height, resized_width = smart_resize(
        height=height,
        width=width,
        factor=factor,
        max_pixels=16 * 16 * 4 * 12800,  # Large buffer
    )

    image = image.resize((resized_width, resized_height))

    buffer = BytesIO()
    image.save(buffer, format="PNG")
    processed_bytes = buffer.getvalue()

    return base64.b64encode(processed_bytes).decode("utf-8"), resized_width, resized_height
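
# Usage sketch (assumed input file): resize a raw screenshot so both sides are
# multiples of `factor`, then embed it as a data URL for an OpenAI-style API.
#
#   with open("screenshot.png", "rb") as f:  # hypothetical path
#       b64, w, h = process_image(f.read(), factor=32)
#   image_item = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}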

def _fallback_rewrite_pyautogui_text_inputs(code: str) -> str:
    """
    Regex-based fallback to handle malformed pyautogui.write/typewrite calls.
    """
    logger.info(f"SyntaxError detected in code, using regex fallback. Original code: {code}")

    def _replacer(match):
        call_content = match.group(0)
        m = re.search(r'pyautogui\.(?:write|typewrite)\s*\(', call_content)
        if not m:
            return call_content

        args_part = call_content[m.end():].strip()
        args_part = re.sub(r'^(?:message|text)\s*=\s*', '', args_part)

        text_content = ""
        if args_part.startswith(("'''", '"""')):
            quote_type = args_part[:3]
            content = args_part[3:]
            end_idx = content.rfind(quote_type)
            if end_idx != -1:
                text_content = content[:end_idx]
            else:
                text_content = content[:-1] if content.endswith(')') else content
        elif args_part.startswith(("'", '"')):
            quote_type = args_part[0]
            content = args_part[1:]
            if content.endswith(quote_type + ")"):
                text_content = content[:-2]
            elif content.endswith(")"):
                if len(content) > 1 and content[-2] == quote_type:
                    text_content = content[:-2]
                else:
                    text_content = content[:-1]
            elif content.endswith(quote_type):
                text_content = content[:-1]
            else:
                text_content = content
        else:
            text_content = args_part[:-1] if args_part.endswith(')') else args_part

        new_cmds = []
        for char in text_content:
            p = "enter" if char == "\n" else char
            p_esc = p.replace("'", "\\'")
            new_cmds.append(f"pyautogui.press('{p_esc}')")

        return "; ".join(new_cmds)

    pattern = r"pyautogui\.(?:write|typewrite)\s*\(.*?(?=\s*;|\s*$|\n)"
    new_code = re.sub(pattern, _replacer, code)

    if new_code == code and ("pyautogui.write" in code or "pyautogui.typewrite" in code):
        new_code = re.sub(r"pyautogui\.(?:write|typewrite)\s*\(.*", _replacer, code)

    return new_code

def rewrite_pyautogui_text_inputs(code: str) -> str:
    """
    Expand pyautogui.write/typewrite string literals into per-character presses.
    """
    try:
        tree = ast.parse(code)

        class _TextCallRewriter(ast.NodeTransformer):
            def _extract_text(self, call: ast.Call):
                if not (
                    isinstance(call.func, ast.Attribute)
                    and isinstance(call.func.value, ast.Name)
                    and call.func.value.id == "pyautogui"
                    and call.func.attr in ("write", "typewrite")
                ):
                    return None

                message_node = call.args[0] if call.args else None
                if message_node is None:
                    for kw in call.keywords:
                        if kw.arg in ("message", "text"):
                            message_node = kw.value
                            break

                if isinstance(message_node, ast.Constant) and isinstance(message_node.value, str):
                    return message_node.value
                return None

            def visit_Expr(self, node):
                self.generic_visit(node)
                if isinstance(node.value, ast.Call):
                    text = self._extract_text(node.value)
                    if text is not None:
                        new_nodes = []
                        for char in text:
                            press_value = "enter" if char == "\n" else char
                            press_call = ast.Expr(
                                value=ast.Call(
                                    func=ast.Attribute(
                                        value=ast.Name(id="pyautogui", ctx=ast.Load()),
                                        attr="press",
                                        ctx=ast.Load(),
                                    ),
                                    args=[ast.Constant(value=press_value)],
                                    keywords=[],
                                )
                            )
                            new_nodes.append(press_call)
                        return new_nodes if new_nodes else node
                return node

        tree = _TextCallRewriter().visit(tree)
        tree = ast.fix_missing_locations(tree)
        new_code = ast.unparse(tree)
        return new_code

    except Exception:  # SyntaxError and any other parse/rewrite failure
        return _fallback_rewrite_pyautogui_text_inputs(code)
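
# Behavior sketch: a single write() call is expanded into per-character presses,
# with newlines mapped to the "enter" key (statements joined by newlines by ast.unparse).
#
#   rewrite_pyautogui_text_inputs("pyautogui.write('hi\\n')")
#   # -> "pyautogui.press('h')\npyautogui.press('i')\npyautogui.press('enter')"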


def project_coordinate_to_absolute_scale(pyautogui_code_relative_coordinates, screen_width, screen_height, coordinate_type="relative", resize_factor=28):
    """
    Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.
    """
    def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
        if coordinate_type == "qwen25":
            height, width = smart_resize(
                height=screen_height,
                width=screen_width,
                factor=resize_factor,
                min_pixels=3136,
                max_pixels=12845056
            )
            if 0 <= x <= 1 and 0 <= y <= 1:
                # If already normalized, treat like "relative"
                return int(round(x * width)), int(round(y * height))
            return int(x / width * screen_width), int(y / height * screen_height)
        else:
            raise ValueError(f"Invalid coordinate type: {coordinate_type}. Expected 'qwen25'")

    pattern = r'(pyautogui\.\w+\([^\)]*\))'
    matches = re.findall(pattern, pyautogui_code_relative_coordinates)

    new_code = pyautogui_code_relative_coordinates

    for full_call in matches:
        func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
        func_match = re.match(func_name_pattern, full_call, re.DOTALL)
        if not func_match:
            continue

        func_name = func_match.group(1)
        args_str = func_match.group(2)

        try:
            parsed = ast.parse(f"func({args_str})").body[0].value
            parsed_args = parsed.args
            parsed_keywords = parsed.keywords

        except SyntaxError:
            continue

        function_parameters = {
            'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
            'rightClick': ['x', 'y', 'duration', 'tween', 'pause'],
            'middleClick': ['x', 'y', 'duration', 'tween', 'pause'],
            'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
            'tripleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
            'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
            'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
        }

        func_base_name = func_name.split('.')[-1]

        param_names = function_parameters.get(func_base_name, [])

        args = {}
        for idx, arg in enumerate(parsed_args):
            if idx < len(param_names):
                param_name = param_names[idx]
                try:
                    arg_value = ast.literal_eval(arg)
                    args[param_name] = arg_value
                except Exception:
                    pass

        try:
            for kw in parsed_keywords:
                param_name = kw.arg
                arg_value = ast.literal_eval(kw.value)
                args[param_name] = arg_value
        except Exception as e:
            logger.error(f"Error parsing keyword arguments: {e}")
            continue

        updated = False
        if 'x' in args and 'y' in args:
            try:
                x_rel = float(args['x'])
                y_rel = float(args['y'])
                # Only project if they look like relative coords (e.g. <= 1.0 or depending on type)
                # Projection applies unconditionally if type is relative
                x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)

                # Apply coordinate transformation
                args['x'] = x_abs
                args['y'] = y_abs
                updated = True
            except (ValueError, TypeError):
                pass

        if updated:
            reconstructed_args = []
            for idx, param_name in enumerate(param_names):
                if param_name in args:
                    arg_value = args[param_name]
                    if isinstance(arg_value, str):
                        arg_repr = f"'{arg_value}'"
                    else:
                        arg_repr = str(arg_value)
                    reconstructed_args.append(arg_repr)
                else:
                    break

            used_params = set(param_names[:len(reconstructed_args)])
            for kw in parsed_keywords:
                if kw.arg not in used_params:
                    arg_value = args[kw.arg]
                    if isinstance(arg_value, str):
                        arg_repr = f"{kw.arg}='{arg_value}'"
                    else:
                        arg_repr = f"{kw.arg}={arg_value}"
                    reconstructed_args.append(arg_repr)

            new_args_str = ', '.join(reconstructed_args)
            new_full_call = f"{func_name}({new_args_str})"
            new_code = new_code.replace(full_call, new_full_call)

    return new_code
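
# Worked example (assuming smart_resize rounds each side of 1920x1080 to the
# nearest multiple of factor=28, giving a 1932x1092 model space): a model click
# at (966, 546) projects back to (960, 540), and the call is re-emitted with
# positional arguments.
#
#   project_coordinate_to_absolute_scale("pyautogui.click(x=966, y=546)", 1920, 1080, coordinate_type="qwen25")
#   # -> "pyautogui.click(960, 540)"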


def log_messages(messages, prefix="LLM Messages"):
    """Log messages with truncated base64 images"""
    try:
        log_msgs = []
        for msg in messages:
            msg_copy = msg.copy()
            content = msg.get("content")
            if isinstance(content, list):
                new_content = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "image_url":
                        item_copy = item.copy()
                        url = item_copy.get("image_url", {}).get("url", "")
                        if len(url) > 100:
                            item_copy["image_url"] = {"url": url[:30] + "...[base64_truncated]..." + url[-10:]}
                        new_content.append(item_copy)
                    else:
                        new_content.append(item)
                msg_copy["content"] = new_content
            log_msgs.append(msg_copy)
        logger.info(f"{prefix}:\n{json.dumps(log_msgs, indent=2, ensure_ascii=False)}")
    except Exception as e:
        logger.warning(f"Failed to log messages: {e}")
@@ -0,0 +1,190 @@
"""
|
||||
Hosted GBOX Agent Client
|
||||
Thin HTTP wrapper that calls the hosted GBOX service
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger("hosted-gbox-agent")
|
||||
|
||||
|
||||
class HostedGboxAgent:
|
||||
"""
|
||||
Client wrapper for hosted GBOX service.
|
||||
Follows the same interface as other OSWorld agents but delegates execution to remote service.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
api_key: str,
|
||||
vm_ip: str,
|
||||
platform: str = "ubuntu",
|
||||
model: str = "claude-sonnet-4-5",
|
||||
max_steps: int = 15,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initialize hosted agent client
|
||||
|
||||
Args:
|
||||
server_url: URL of hosted GBOX service (e.g., "http://44.201.221.203:8000")
|
||||
api_key: API key for authentication
|
||||
vm_ip: IP address of the VM to control
|
||||
platform: OS platform (ubuntu/windows)
|
||||
model: Claude model to use
|
||||
max_steps: Maximum steps per task
|
||||
"""
|
||||
self.server_url = server_url.rstrip('/')
|
||||
self.api_key = api_key
|
||||
self.vm_ip = vm_ip
|
||||
self.platform = platform
|
||||
self.model = model
|
||||
self.max_steps = max_steps
|
||||
self.runtime_logger = None
|
||||
|
||||
# HTTP client with timeout
|
||||
self.client = requests.Session()
|
||||
self.client.headers.update({"X-API-Key": api_key})
|
||||
|
||||
logger.info(f"Initialized hosted agent client for VM {vm_ip}")
|
||||
logger.info(f"Server: {server_url}, Model: {model}")

    def reset(self, runtime_logger=None, vm_ip: str = None):
        """
        Reset agent state (called by OSWorld before each task)

        Args:
            runtime_logger: Logger instance for OSWorld runtime logs
            vm_ip: Updated VM IP (in case of snapshot revert)
        """
        self.runtime_logger = runtime_logger

        if vm_ip:
            self.vm_ip = vm_ip
            if self.runtime_logger:
                self.runtime_logger.info(f"[HOSTED] Updated VM IP to {vm_ip}")

        if self.runtime_logger:
            self.runtime_logger.info(f"[HOSTED] Agent reset for VM {self.vm_ip}")

    def predict(self, instruction: str, obs: Dict) -> Tuple[str, List[str]]:
        """
        Execute task prediction (one call = full task execution)

        Args:
            instruction: Task instruction
            obs: Observation dict (not used - agent fetches its own screenshots)

        Returns:
            (reasoning_text, actions_list)
            - reasoning_text: Claude's reasoning/explanation
            - actions_list: ["DONE"] or ["FAIL"] or PyAutoGUI code
        """
        try:
            # Prepare request (no screenshot needed - agent fetches its own)
            payload = {
                "vm_ip": self.vm_ip,
                "instruction": instruction,
                "platform": self.platform,
                "model": self.model,
                "max_steps": self.max_steps
            }

            # Log request
            if self.runtime_logger:
                self.runtime_logger.info("[HOSTED] Sending task to service...")
                self.runtime_logger.info(f"[HOSTED] Instruction: {instruction[:100]}...")

            # Call hosted service (this may take several minutes)
            response = self.client.post(
                f"{self.server_url}/execute",
                json=payload,
                timeout=3600  # 60-minute timeout for full task execution
            )

            # Check for errors
            if response.status_code == 401:
                raise RuntimeError("Authentication failed - invalid API key")
            elif response.status_code != 200:
                raise RuntimeError(f"Service returned {response.status_code}: {response.text}")

            # Parse response
            result = response.json()
            reasoning = result.get("reasoning", "")
            actions = result.get("actions", ["FAIL"])
            logs = result.get("logs", "")
            session_id = result.get("session_id", "unknown")

            # Forward server logs to OSWorld's runtime logger
            if logs and self.runtime_logger:
                for line in logs.split('\n'):
                    if line.strip():
                        self.runtime_logger.info(f"[SERVER] {line}")

            # Log results
            if self.runtime_logger:
                self.runtime_logger.info(f"[HOSTED] Session ID: {session_id}")
                self.runtime_logger.info(f"[HOSTED] Actions: {actions}")
                self.runtime_logger.info(f"[HOSTED] Reasoning: {reasoning[:200]}...")

            return reasoning, actions

        except requests.Timeout:
            error_msg = "Service timeout (task took longer than 60 minutes)"
            logger.error(error_msg)
            if self.runtime_logger:
                self.runtime_logger.error(f"[HOSTED] {error_msg}")
            return f"ERROR: {error_msg}", ["FAIL"]

        except requests.ConnectionError as e:
            error_msg = f"Cannot connect to service at {self.server_url}: {str(e)}"
            logger.error(error_msg)
            if self.runtime_logger:
                self.runtime_logger.error(f"[HOSTED] {error_msg}")
            return f"ERROR: {error_msg}", ["FAIL"]

        except Exception as e:
            error_msg = f"Hosted agent error: {str(e)}"
            logger.error(error_msg, exc_info=True)
            if self.runtime_logger:
                self.runtime_logger.error(f"[HOSTED] {error_msg}")
            return f"ERROR: {error_msg}", ["FAIL"]

    def close(self):
        """Close HTTP session"""
        self.client.close()

    def __del__(self):
        """Cleanup on deletion"""
        try:
            self.close()
        except Exception:
            pass


# Factory function for compatibility with OSWorld runner
def create_agent(vm_ip: str, **kwargs) -> HostedGboxAgent:
    """
    Factory function to create hosted agent

    Expects environment variables:
    - GBOX_SERVICE_URL: URL of hosted service
    - GBOX_SERVICE_API_KEY: API key for authentication
    """
    server_url = os.getenv("GBOX_SERVICE_URL")
    api_key = os.getenv("GBOX_SERVICE_API_KEY")

    if not server_url:
        raise ValueError("GBOX_SERVICE_URL environment variable not set")
    if not api_key:
        raise ValueError("GBOX_SERVICE_API_KEY environment variable not set")

    return HostedGboxAgent(
        server_url=server_url,
        api_key=api_key,
        vm_ip=vm_ip,
        **kwargs
    )
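
# Usage sketch (assumed environment; the URL and key below are placeholders):
#
#   os.environ["GBOX_SERVICE_URL"] = "http://localhost:8000"   # hypothetical
#   os.environ["GBOX_SERVICE_API_KEY"] = "sk-placeholder"      # hypothetical
#   agent = create_agent(vm_ip="10.0.0.5", model="claude-sonnet-4-5")
#   reasoning, actions = agent.predict("Open Firefox", obs={})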
@@ -0,0 +1,3 @@
from mm_agents.opencua.opencua_agent import OpenCUAAgent

__all__ = ["OpenCUAAgent"]
@@ -0,0 +1,470 @@
"""
|
||||
OpenCUA Agent Implementation
|
||||
|
||||
This module implements an OpenCUA agent for desktop automation tasks, building upon
|
||||
existing frameworks and integrating multiple coordinate mapping systems.
|
||||
|
||||
Framework and Implementation Sources:
|
||||
- Main framework structure follows: https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/agent.py
|
||||
- Agent implementation adapted from: https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/aguvis_agent.py
|
||||
- Qwen2.5-VL coordinate mapping from: https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
"""
|
||||
|
||||
import re
|
||||
import os
|
||||
import ast
|
||||
import time
|
||||
import math
|
||||
import httpx
|
||||
import base64
|
||||
import backoff
|
||||
import traceback
|
||||
from loguru import logger
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from mm_agents.opencua.utils import (
|
||||
encode_image,
|
||||
smart_resize,
|
||||
)
|
||||
from mm_agents.opencua.prompts import (
|
||||
INSTRUTION_TEMPLATE,
|
||||
STEP_TEMPLATE,
|
||||
ACTION_HISTORY_TEMPLATE,
|
||||
THOUGHT_HISTORY_TEMPLATE,
|
||||
OBSERVATION_HISTORY_TEMPLATE,
|
||||
# OpenCUA-7B, 32B system prompts
|
||||
SYSTEM_PROMPT_V1_L1,
|
||||
SYSTEM_PROMPT_V1_L2,
|
||||
SYSTEM_PROMPT_V1_L3,
|
||||
# OpenCUA-72B system prompts
|
||||
build_sys_prompt,
|
||||
)

def parse_response_to_cot_and_action(input_string, screen_size, coordinate_type) -> Tuple[str, List[str], dict]:
    """Parse response including Observation, Thought, Action and code block"""
    sections = {}
    try:
        obs_match = re.search(r'^##\s*Observation\s*:?[\n\r]+(.*?)(?=^##\s*Thought:|^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
        if obs_match:
            sections['observation'] = obs_match.group(1).strip()

        thought_match = re.search(r'^##\s*Thought\s*:?[\n\r]+(.*?)(?=^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
        if thought_match:
            sections['thought'] = thought_match.group(1).strip()

        action_match = re.search(r'^##\s*Action\s*:?[\n\r]+(.*?)(?=^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
        if action_match:
            action = action_match.group(1).strip()
            sections['action'] = action.strip()

        code_blocks = re.findall(r'```(?:code|python)?\s*(.*?)\s*```', input_string, re.DOTALL | re.IGNORECASE)
        if not code_blocks:
            logger.error("No code blocks found in the input string")
            return f"<Error>: no code blocks found in the input string: {input_string}", ["FAIL"], sections
        code_block = code_blocks[-1].strip()
        sections['original_code'] = code_block

        if "computer.wait" in code_block.lower():
            sections["code"] = "WAIT"
            return sections['action'], ["WAIT"], sections

        elif "computer.terminate" in code_block.lower():
            lower_block = code_block.lower()
            if ("failure" in lower_block) or ("fail" in lower_block):
                sections['code'] = "FAIL"
                return code_block, ["FAIL"], sections
            elif "success" in lower_block:
                sections['code'] = "DONE"
                return code_block, ["DONE"], sections
            else:
                logger.error("Terminate action found but no specific status provided in code block")
                return f"<Error>: terminate action found but no specific status provided in code block: {input_string}", ["FAIL"], sections

        # corrected_code = correct_pyautogui_arguments(code_block)
        corrected_code = code_block
        sections['code'] = corrected_code
        sections['code'] = project_coordinate_to_absolute_scale(corrected_code, screen_width=screen_size[0], screen_height=screen_size[1], coordinate_type=coordinate_type)

        if ('code' not in sections or sections['code'] is None or sections['code'] == "") or ('action' not in sections or sections['action'] is None or sections['action'] == ""):
            logger.error("Missing required action or code section")
            return f"<Error>: no code parsed: {input_string}", ["FAIL"], sections

        return sections['action'], [sections['code']], sections

    except Exception as e:
        error_message = f"<Error>: parsing response: {str(e)}\nTraceback:\n{traceback.format_exc()}\nInput string: {input_string}"
        logger.error(error_message)
        return error_message, ['FAIL'], sections
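
# Parsing sketch (illustrative model output): a well-formed response yields the
# action text plus the projected pyautogui code. With coordinate_type="relative",
# (0.5, 0.5) on a 1920x1080 screen becomes (960, 540), re-emitted positionally.
#
#   response = (
#       "## Thought:\nI should open the menu.\n"
#       "## Action:\nClick the menu.\n"
#       "## Code:\n```python\npyautogui.click(x=0.5, y=0.5)\n```"
#   )
#   action, codes, sections = parse_response_to_cot_and_action(response, (1920, 1080), "relative")
#   # action -> "Click the menu."; codes -> ["pyautogui.click(960, 540)"]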

def project_coordinate_to_absolute_scale(pyautogui_code_relative_coordinates, screen_width, screen_height, coordinate_type="relative"):
    """
    Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.
    """
    def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
        if coordinate_type == "relative":
            return int(round(x * screen_width)), int(round(y * screen_height))
        elif coordinate_type == "qwen25":
            height, width = smart_resize(
                height=screen_height,
                width=screen_width,
                factor=28,
                min_pixels=3136,
                max_pixels=12845056
            )
            if 0 <= x <= 1 and 0 <= y <= 1:
                # If already normalized, treat like "relative"
                return int(round(x * width)), int(round(y * height))
            return int(x / width * screen_width), int(y / height * screen_height)
        else:
            raise ValueError(f"Invalid coordinate type: {coordinate_type}. Expected one of ['relative', 'relative1000', 'absolute', 'qwen25'].")

    pattern = r'(pyautogui\.\w+\([^\)]*\))'
    matches = re.findall(pattern, pyautogui_code_relative_coordinates)

    new_code = pyautogui_code_relative_coordinates

    for full_call in matches:
        func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
        func_match = re.match(func_name_pattern, full_call, re.DOTALL)
        if not func_match:
            continue

        func_name = func_match.group(1)
        args_str = func_match.group(2)

        try:
            parsed = ast.parse(f"func({args_str})").body[0].value
            parsed_args = parsed.args
            parsed_keywords = parsed.keywords

        except SyntaxError:
            return pyautogui_code_relative_coordinates

        function_parameters = {
            'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
            'rightClick': ['x', 'y', 'duration', 'tween', 'pause'],
            'middleClick': ['x', 'y', 'duration', 'tween', 'pause'],
            'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
            'tripleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
            'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
            'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
        }

        func_base_name = func_name.split('.')[-1]

        param_names = function_parameters.get(func_base_name, [])

        args = {}
        for idx, arg in enumerate(parsed_args):
            if idx < len(param_names):
                param_name = param_names[idx]
                try:
                    arg_value = ast.literal_eval(arg)
                    args[param_name] = arg_value
                except (ValueError, SyntaxError):
                    # Skip non-literal positional arguments instead of crashing
                    pass

        try:
            for kw in parsed_keywords:
                param_name = kw.arg
                arg_value = ast.literal_eval(kw.value)
                args[param_name] = arg_value
        except Exception as e:
            logger.error(f"Error parsing keyword arguments: {e}")
            return pyautogui_code_relative_coordinates

        updated = False
        if 'x' in args and 'y' in args:
            try:
                x_rel = float(args['x'])
                y_rel = float(args['y'])
                x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
                logger.warning(f"Projecting coordinates: ({x_rel}, {y_rel}) to ({x_abs}, {y_abs}) using {coordinate_type} projection.")
                args['x'] = x_abs
                args['y'] = y_abs
                updated = True
            except ValueError:
                pass

        if updated:
            reconstructed_args = []
            for idx, param_name in enumerate(param_names):
                if param_name in args:
                    arg_value = args[param_name]
                    if isinstance(arg_value, str):
                        arg_repr = f"'{arg_value}'"
                    else:
                        arg_repr = str(arg_value)
                    reconstructed_args.append(arg_repr)
                else:
                    break

            used_params = set(param_names[:len(reconstructed_args)])
            for kw in parsed_keywords:
                if kw.arg not in used_params:
                    arg_value = args[kw.arg]
                    if isinstance(arg_value, str):
                        arg_repr = f"{kw.arg}='{arg_value}'"
                    else:
                        arg_repr = f"{kw.arg}={arg_value}"
                    reconstructed_args.append(arg_repr)

            new_args_str = ', '.join(reconstructed_args)
            new_full_call = f"{func_name}({new_args_str})"
            new_code = new_code.replace(full_call, new_full_call)

    return new_code

def transform_agnet_action_to_code_block(action):
    if any(keyword in action for keyword in ["computer.terminate", "computer.wait", "browser.select_option", "browser.clear"]):
        return f"```code\n{action}\n```"
    else:
        return f"```python\n{action}\n```"

class OpenCUAAgent:
    """
    OpenCUA Agent for desktop automation tasks.

    This class implements an OpenCUA-model-based agent that can observe
    desktop environments through screenshots and execute mouse/keyboard actions
    via PyAutoGUI to complete automation tasks.

    Attributes:
        model (str): Name of the language model being used
        history_type (str): Type of history recording mechanism
        actions (list): History of executed actions
        observations (list): History of environment observations
        cots (list): Chain of thought reasoning records
    """
    def __init__(
        self,
        model: str,  # OpenCUA model name
        history_type: str,  # History step type: action_history, thought_history, observation_history
        max_steps: int,  # The max number of steps to finish the task
        max_image_history_length: int = 3,  # The max number of images in the history
        platform: str = "ubuntu",  # The platform of the computer
        max_tokens: int = 1500,  # The max number of tokens in the response
        top_p: float = 0.9,  # The top p value in the response
        temperature: float = 0,  # The temperature value in the response
        action_space: str = "pyautogui",  # The action space: pyautogui
        observation_type: str = "screenshot",  # The observation type: screenshot
        cot_level: str = "l2",  # The CoT level: l1, l2, l3
        screen_size: Tuple[int, int] = (1920, 1080),  # The screen size
        coordinate_type: str = "relative",  # The coordinate type: relative, absolute, qwen25
        use_old_sys_prompt: bool = False,  # Whether to use the old system prompt
        password="osworld-public-evaluation",  # The password for the ubuntu platform
        **kwargs
    ):
        assert coordinate_type in ["relative", "absolute", "qwen25"]
        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"
        assert history_type in ["action_history", "thought_history", "observation_history"]
        assert model is not None, "Model cannot be None"

        self.model = model
        self.platform = platform
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.history_type = history_type
        self.coordinate_type = coordinate_type
        self.cot_level = cot_level
        self.screen_size = screen_size
        self.max_image_history_length = max_image_history_length
        self.max_steps = max_steps
        self.password = password

        if history_type == "action_history":
            self.HISTORY_TEMPLATE = ACTION_HISTORY_TEMPLATE
        elif history_type == "thought_history":
            self.HISTORY_TEMPLATE = THOUGHT_HISTORY_TEMPLATE
        elif history_type == "observation_history":
            self.HISTORY_TEMPLATE = OBSERVATION_HISTORY_TEMPLATE
        else:
            raise ValueError(f"Invalid history type: {history_type}")

        if use_old_sys_prompt:
            if cot_level == "l1":
                self.system_prompt = SYSTEM_PROMPT_V1_L1
            elif cot_level == "l2":
                self.system_prompt = SYSTEM_PROMPT_V1_L2
            elif cot_level == "l3":
                self.system_prompt = SYSTEM_PROMPT_V1_L3
            else:
                raise ValueError("Invalid cot_level. Choose from 'l1', 'l2', or 'l3'.")
        else:
            self.system_prompt = build_sys_prompt(
                level=self.cot_level,
                password=self.password,
                use_random=False
            )

        self.actions = []
        self.observations = []
        self.cots = []

    def reset(self, _logger=None):
        global logger
        logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")

        self.observations = []
        self.cots = []
        self.actions = []

    def _scale_scroll_for_windows(self, code: str, factor: int = 50) -> str:
        """pyautogui.scroll uses a different scale on Ubuntu and Windows; multiply the amount by `factor` when scrolling on a Windows system."""
        if self.platform.lower() != "windows":
            return code

        pattern_pos = re.compile(r'(pyautogui\.scroll\()\s*([-+]?\d+)\s*\)')
        code = pattern_pos.sub(lambda m: f"{m.group(1)}{int(m.group(2))*factor})", code)
        return code
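
    # Behavior sketch: on Windows, integer scroll amounts are multiplied by `factor`
    # (default 50), e.g. "pyautogui.scroll(-3)" becomes "pyautogui.scroll(-150)".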

    def predict(self, instruction: str, obs: Dict, **kwargs) -> Tuple[str, List[str], Dict]:
        """
        Predict the next action(s) based on the current observation.
        """
        if "step_idx" in kwargs:
            logger.info(f"========= {self.model} Step {kwargs['step_idx']} =======")
        else:
            logger.info(f"========================== {self.model} ===================================")
        logger.info(f"Instruction: \n{instruction}")

        messages = []
        messages.append({
            "role": "system",
            "content": self.system_prompt
        })
        instruction_prompt = INSTRUTION_TEMPLATE.format(instruction=instruction)

        history_step_texts = []
        for i in range(len(self.actions)):
            if i > len(self.actions) - self.max_image_history_length:
                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}"}
                        }
                    ]
                })

                history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
                    observation=self.cots[i].get('observation'),
                    thought=self.cots[i].get('thought'),
                    action=self.cots[i].get('action')
                )

                messages.append({
                    "role": "assistant",
                    "content": history_content
                })
            else:
                history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
                    observation=self.cots[i].get('observation'),
                    thought=self.cots[i].get('thought'),
                    action=self.cots[i].get('action')
                )
                history_step_texts.append(history_content)
                if i == len(self.actions) - self.max_image_history_length:
                    messages.append({
                        "role": "assistant",
                        "content": "\n".join(history_step_texts)
                    })

        messages.append({
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}"}
                },
                {
                    "type": "text",
                    "text": instruction_prompt
                }
            ]
        })

        max_retry = 5
        retry_count = 0
        low_level_instruction = None
        pyautogui_actions = None
        other_cot = {}

        while retry_count < max_retry:
            try:
                response = self.call_llm({
                    "model": self.model,
                    "messages": messages,
                    "max_tokens": self.max_tokens,
                    "top_p": self.top_p,
                    "temperature": self.temperature if retry_count == 0 else max(0.2, self.temperature)
                }, self.model)

                logger.info(f"Model Output: \n{response}")
                if not response:
                    logger.error("Empty response from the model.")
                    raise ValueError(f"Empty response from the model:\n{response}.")

                low_level_instruction, pyautogui_actions, other_cot = parse_response_to_cot_and_action(response, self.screen_size, self.coordinate_type)
                if "<Error>" in low_level_instruction or not pyautogui_actions:
                    logger.error(f"Error parsing response: {low_level_instruction}")
                    raise ValueError(f"Error parsing response: {low_level_instruction}")
                break

            except Exception as e:
                logger.error(f"Error during LLM call or response parsing: {e}")
                retry_count += 1
                if retry_count == max_retry:
                    logger.error("Maximum retries reached. Exiting.")
                    return str(e), ['FAIL'], other_cot

        pyautogui_actions = [
            self._scale_scroll_for_windows(code) for code in pyautogui_actions
        ]
        logger.info(f"Action: \n{low_level_instruction}")
        logger.info(f"Code: \n{pyautogui_actions}")

        self.observations.append(obs)
        self.actions.append(low_level_instruction)
        self.cots.append(other_cot)

        current_step = len(self.actions)
        if current_step >= self.max_steps and 'computer.terminate' not in pyautogui_actions[0].lower():
            logger.warning(f"Reached maximum steps {self.max_steps}. Forcing termination.")
            low_level_instruction = 'Fail the task because of reaching the maximum step limit.'
            pyautogui_actions = ['FAIL']
            other_cot['code'] = 'FAIL'

        return response, pyautogui_actions, other_cot
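
    # History windowing note: only the most recent `max_image_history_length` steps
    # keep their screenshots as image messages; older steps are folded into a single
    # text-only assistant message, which bounds the request payload size.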

    def call_llm(self, payload, model):
        """Call the LLM API"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ['OPENCUA_API_KEY']}"
        }

        for _ in range(20):
            response = httpx.post(
                f"https://{self.model}.app.msh.team/v1/chat/completions",
                headers=headers,
                json=payload,
                timeout=500,
                verify=False
            )

            if response.status_code != 200:
                logger.error("Failed to call LLM: " + response.text)
                logger.error("Retrying...")
                time.sleep(5)
            else:
                response = response.json()
                finish_reason = response["choices"][0].get("finish_reason")
                if finish_reason == "stop":  # most of the time, the length will not exceed max_tokens
                    return response['choices'][0]['message']['content']
                else:
                    logger.error("LLM did not finish properly, retrying...")
                    time.sleep(5)
@@ -0,0 +1,349 @@
import random

# System prompt for OpenCUA-7B, OpenCUA-32B
# System prompts used in the training data
SYSTEM_PROMPT_V1_L1 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}".strip()
SYSTEM_PROMPT_V1_L2 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}".strip()
SYSTEM_PROMPT_V1_L3 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nObservation:\n - Describe the current computer state based on the full screenshot in detail. \n - Application Context:\n - The active application\n - The active window or page\n - Overall layout and visible interface\n - Key Elements:\n - Menu items and toolbars \n - Buttons and controls\n - Text fields and content\n - Dialog boxes or popups\n - Error messages or notifications\n - Loading states\n - Other key elements\n - Describe any content, elements, options, information or clues that are possibly relevant to achieving the task goal, including their name, content, or shape (if possible).\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}\n".strip()

# Testing prompt on OSWorld-Verified
SYSTEM_PROMPT_V1_L2 = """You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. The password of the computer is "osworld-public-evaluation". If the task is not possible to do, output the action computer.terminate(status='failure').
|
||||
|
||||
For each step, provide your response in this format:
|
||||
|
||||
Thought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning
|
||||
|
||||
Action:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize "—", maximize "□", close "X")\n - if the action involves keyboard actions like \'press\', \'write\', \'hotkey\':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions
|
||||
|
||||
Finally, output the action as PyAutoGUI code or the following functions:
|
||||
- {"name": "computer.triple_click", "description": "Triple click on the screen", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The x coordinate of the triple click"}, "y": {"type": "number", "description": "The y coordinate of the triple click"}}, "required": ["x", "y"]}}
|
||||
- {"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {"type": "object", "properties": {"status": {"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}}, "required": ["status"]}}
|
||||
""".strip()


# SYSTEM_PROMPT for OpenCUA-72B
general_computer_instructions = [
    """
You are a GUI agent. You are given a task, a screenshot of the screen and your previous interactions with the computer. You need to perform a series of actions to complete the task. The password of the computer is "{password}", use it when you need sudo rights. You need to **wait** explicitly for installation, website loading or running commands to finish. Don\'t terminate the task unless you are sure the task is finished. If you find that you can\'t finish the task, or the task is not finished exactly as the instruction indicates (you have made progress but not finished the task completely), or the task is impossible to complete, you must report **failure**.
    """.strip(),
    """
You are acting as a GUI agent. A task description, a screenshot, and your past interactions will be supplied. Execute the necessary steps to fulfil the task. Whenever sudo operations are required, use the computer’s password "{password}". Insert an explicit **wait** after launching any installation, website loading or long-running command to let it finish. Do not output the terminate action unless you are certain the task is complete. If you realise the task can\'t be finished or is impossible to do, you should report **failure**.
    """.strip(),
    """
Your mission as a GUI agent is to complete the provided task using the current screen image and the history of interactions. For commands requiring elevated privileges, supply "{password}" as the sudo password. Explicitly invoke **wait** after launching any installation or command that may take time to finish. Do not terminate the session unless success is certain. If the task cannot be fully executed, or turns out impossible, you must declare **failure**.
    """.strip(),
]

l3_format_instruction = """For each step, provide your response in this format:
# Step: {step number}
## Observation:
{observation}
## Thought:
{thought}
## Action:
{action}
## Code:
{code}"""

l2_format_instruction = """For each step, provide your response in this format:
# Step: {step number}
## Thought:
{thought}
## Action:
{action}
## Code:
{code}"""

l1_format_instruction = """For each step, provide your response in this format:
# Step: {step number}
## Action:
{action}
## Code:
{code}"""

observation_instructions = [
    """For the Observation section, you should include the following parts if helpful:
- Describe the current computer state based on the full screenshot in detail.
    - Application Context:
        - The active application
        - The active window or page
        - Overall layout and visible interface
    - Key Elements:
        - Menu items and toolbars
        - Buttons and controls
        - Text fields and content
        - Dialog boxes or popups
        - Error messages or notifications
        - Loading states
        - Other key elements
- Describe any content, elements, options, information or clues that are possibly relevant to achieving the task goal, including their name, content, or shape (if possible).
    """.strip(),

    """In the Observation section, outline everything visible on screen that could influence your next move:
• Current system state as seen in the screenshot.
• Application context:
    - Which application is running in the foreground
    - Specific window, tab, or page being displayed
    - High-level layout of panels, sidebars, and work areas
• Salient interface elements:
    - Menus, ribbons, and toolbars
    - Actionable buttons, icons, toggles, and controls
    - Input areas such as text boxes or code editors
    - Pop-up dialogs, modals, alerts, or system notifications
    - Progress bars, spinners, or other loading indicators
• Any text, labels, shapes, or on-screen cues that might help accomplish the task (cite names or visual traits when available).
    """.strip(),

    # ── Variant 3 ──────────────────────────────────────────────────────────
    """Write the Observation section as a thorough snapshot of the UI:
- Start with a full-screen description: what the user sees at a glance.
- Give application details: title, active workspace, and structural layout.
- Enumerate critical elements:
    * Navigation menus and context bars
    * Primary and secondary buttons or icons
    * Editable fields, lists, tables, or rich-text areas
    * Dialogs, pop-ups, warnings, or confirmations
    * Indicators of loading or processing activity
- Note any evidence, hints, or data (textual or visual) that could guide the task toward completion, referencing names, colors, shapes, or positions when explicit identifiers are missing.
    """.strip(),
]
|
||||
|
||||
thought_instructions = [
|
||||
"""For the Thought section, you should include the following parts:
|
||||
- Reflection on the task when there is previous action:
|
||||
- Consider the correctness of the previous action and its outcomes
|
||||
- If the previous action was correct, describe the change in the state of the computer and reason
|
||||
- If the previous action was incorrect, reflect on what went wrong and why
|
||||
- Step by Step Progress Assessment:
|
||||
- Add necessary information according to the historical screenshots, former actions, and the current screenshot.
|
||||
- Analyze what parts of the task have already been completed and how they contribute to the overall goal.
|
||||
- Make a plan on how to complete the task based on the history and current screenshot.
|
||||
- Next Action Prediction:
|
||||
- Propose the most likely next action and state the reason
|
||||
- For Text Input Actions:
|
||||
- Note current cursor position
|
||||
- Consolidate repetitive actions (specify count for multiple keypresses)
|
||||
- Describe expected final text outcome
|
||||
- Use first-person perspective in reasoning
|
||||
""".strip(),
|
||||
|
||||
"""
|
||||
In the **Thought** block, cover these topics:
|
||||
|
||||
1. **Last-Step Reflection** (when a prior action exists)
|
||||
• Was my previous action correct? What evidence shows this?
|
||||
• If it succeeded, what state change occurred and why?
|
||||
• If it failed, where did I go wrong?
|
||||
|
||||
2. **Incremental Progress Audit**
|
||||
• Which sub-tasks are completed and how do they advance the mission?
|
||||
• Make a plan to finish the task based on past actions and the current UI state.
|
||||
|
||||
3. **Foresight for the Coming Action**
|
||||
• Predict the most logical next step.
|
||||
• State the reason why it is the best choice given the current context.
|
||||
|
||||
4. **Guidance for Text Entry**
|
||||
• Note the cursor location
|
||||
• Compress multiple identical keystrokes (e.g., “press Backspace ×3”)
|
||||
• Clarify the exact text expected after input
|
||||
|
||||
Use first-person inner dialogue throughout.
|
||||
""".strip(),
|
||||
|
||||
"""
|
||||
Compose your **Thought** section as an internal monologue that includes:
|
||||
|
||||
- **Retrospective** (if a prior step exists):
|
||||
* Evaluate the accuracy and effect of the last action.
|
||||
* If it was successful, reason about the resulting interface change.
|
||||
* If it was faulty, diagnose the misstep and its cause.
|
||||
|
||||
- **Ongoing Progress Evaluation**:
|
||||
* Outline which parts of the task are done and their impact on the overall objective.
|
||||
* Suggest a plan to complete the task based on past history and the current screen.
|
||||
|
||||
- **Decision Framework for the Next Move**:
|
||||
* Brainstorm possible next actions given the present state.
|
||||
* Explain why this action is the most logical choice.
|
||||
|
||||
- **Special Rules for Keyboard Input**:
|
||||
* Specify current cursor focus or field.
|
||||
* Merge repeated keypresses into counts for brevity.
|
||||
* Describe the intended final text after typing.
|
||||
|
||||
Maintain a first-person voice for clarity of reasoning.
|
||||
""".strip(),
|
||||
]
|
||||
|
||||
action_instructions = [
|
||||
"""For the action section, you should provide clear, concise, and actionable instructions in one sentence.
|
||||
- If the action involves interacting with a specific target:
|
||||
- Describe target explicitly (if multiple elements share that name, you should distinguish the target) without using coordinates
|
||||
- Specify element names when possible (use original language if non-English)
|
||||
- Describe features (shape, color, position) if name unavailable
|
||||
- If the action involves keyboard actions like 'press', 'write', 'hotkey':
|
||||
- Consolidate repetitive keypresses with count
|
||||
- Specify expected text outcome for typing actions
|
||||
""".strip(),
|
||||
|
||||
"""
|
||||
Write the **Action** in one short, direct sentence.
|
||||
|
||||
• When clicking or otherwise interacting with a UI element:
|
||||
- Name the element explicitly — and, if multiple elements share that name, add a distinguishing detail.
|
||||
- Do **not** give coordinates.
|
||||
- Use the element's label (keep original language when it isn't English).
|
||||
- If unnamed, describe recognisable traits (shape, colour, on-screen position).
|
||||
|
||||
• When using the keyboard (press, type, hotkey):
|
||||
- Collapse repeated key presses into counts.
|
||||
- For typing, specify the text that should appear.
|
||||
""".strip(),
|
||||
|
||||
"""
|
||||
Provide the **Action** as a single, crisp imperative sentence.
|
||||
|
||||
- Mouse/GUI interactions:
|
||||
* Identify the target by name, and if duplicate names exist, clarify which one you mean.
|
||||
* Do not supply XY coordinates.
|
||||
* Preserve non-English labels verbatim.
|
||||
* If unnamed, describe the element's look or location (colour, shape, relative position).
|
||||
|
||||
- Keyboard operations (press, write, hotkey):
|
||||
* Combine repeated keystrokes with a multiplier.
|
||||
* State the exact text that will be entered.
|
||||
""".strip(),
|
||||
]
|
||||
|
||||
code_instrucion = """For the code section, you should output the corresponding code for the action. The code should be either PyAutoGUI code or one of the following functions warped in the code block:
|
||||
- {"name": "computer.wait", "description": "Make the computer wait for 20 seconds for installation, running code, etc.", "parameters": {"type": "object", "properties": {}, "required": []}}
|
||||
- {"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {"type": "object", "properties": {"status": {"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}, {"answer": {"type": "string", "description": "The answer of the task"}}, "required": ["status"]}}
|
||||
Examples for the code section:
|
||||
```python
|
||||
pyautogui.click(x=123, y=456)
|
||||
```
|
||||
```code
|
||||
computer.terminate(status="success")
|
||||
```
|
||||
```code
|
||||
computer.terminate(status="success", answer='''text''')
|
||||
```"""
|
||||
|
||||
SYSTEM_PROMPT_V2_L1 = """
|
||||
{general_computer_instruction}
|
||||
|
||||
{format_instruction}
|
||||
|
||||
{action_instruction}
|
||||
|
||||
{code_instruction}
|
||||
""".strip()
|
||||
|
||||
SYSTEM_PROMPT_V2_L2 = """
|
||||
{general_computer_instruction}
|
||||
|
||||
{format_instruction}
|
||||
|
||||
{thought_instruction}
|
||||
|
||||
{action_instruction}
|
||||
|
||||
{code_instruction}
|
||||
""".strip()
|
||||
|
||||
SYSTEM_PROMPT_V2_L3 = """
|
||||
{general_computer_instruction}
|
||||
|
||||
{format_instruction}
|
||||
|
||||
{observation_instruction}
|
||||
|
||||
{thought_instruction}
|
||||
|
||||
{action_instruction}
|
||||
|
||||
{code_instruction}
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_sys_prompt(level, password="password", use_random=False):
|
||||
if not use_random:
|
||||
if level == "l1":
|
||||
return SYSTEM_PROMPT_V2_L1.format(
|
||||
general_computer_instruction=general_computer_instructions[0].format(
|
||||
password=password
|
||||
),
|
||||
format_instruction=l1_format_instruction,
|
||||
action_instruction=action_instructions[0],
|
||||
code_instruction=code_instruction,
|
||||
)
|
||||
elif level == "l2":
|
||||
return SYSTEM_PROMPT_V2_L2.format(
|
||||
general_computer_instruction=general_computer_instructions[0].format(
|
||||
password=password
|
||||
),
|
||||
format_instruction=l2_format_instruction,
|
||||
thought_instruction=thought_instructions[0],
|
||||
action_instruction=action_instructions[0],
|
||||
code_instruction=code_instruction,
|
||||
)
|
||||
elif level == "l3":
|
||||
return SYSTEM_PROMPT_V2_L3.format(
|
||||
general_computer_instruction=general_computer_instructions[0].format(
|
||||
password=password
|
||||
),
|
||||
format_instruction=l3_format_instruction,
|
||||
observation_instruction=observation_instructions[0],
|
||||
thought_instruction=thought_instructions[0],
|
||||
action_instruction=action_instructions[0],
|
||||
code_instruction=code_instruction,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid level. Choose from 'l1', 'l2', or 'l3'.")
|
||||
else:
|
||||
if level == "l1":
|
||||
return SYSTEM_PROMPT_V2_L1.format(
|
||||
general_computer_instruction=random.choice(
|
||||
general_computer_instructions
|
||||
),
|
||||
format_instruction=l1_format_instruction,
|
||||
action_instruction=random.choice(action_instructions),
|
||||
code_instruction=code_instruction,
|
||||
)
|
||||
elif level == "l2":
|
||||
return SYSTEM_PROMPT_V2_L2.format(
|
||||
general_computer_instruction=random.choice(
|
||||
general_computer_instructions
|
||||
),
|
||||
format_instruction=l2_format_instruction,
|
||||
thought_instruction=random.choice(thought_instructions),
|
||||
action_instruction=random.choice(action_instructions),
|
||||
code_instruction=code_instruction,
|
||||
)
|
||||
elif level == "l3":
|
||||
return SYSTEM_PROMPT_V2_L3.format(
|
||||
general_computer_instruction=random.choice(
|
||||
general_computer_instructions
|
||||
),
|
||||
format_instruction=l3_format_instruction,
|
||||
observation_instruction=random.choice(observation_instructions),
|
||||
thought_instruction=random.choice(thought_instructions),
|
||||
action_instruction=random.choice(action_instructions),
|
||||
code_instruction=code_instruction,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid level. Choose from 'l1', 'l2', or 'l3'.")
|
||||
|
||||
|
||||
# Modeling prompt templates for generating trajectories
|
||||
STEP_TEMPLATE = "# Step {step_num}:\n"
|
||||
INSTRUTION_TEMPLATE = "# Task Instruction:\n{instruction}\n\nPlease generate the next move according to the screenshot, task instruction and previous steps (if provided).\n"
|
||||
|
||||
ACTION_HISTORY_TEMPLATE = "## Action:\n{action}\n"
|
||||
THOUGHT_HISTORY_TEMPLATE = "## Thought:\n{thought}\n\n## Action:\n{action}\n"
|
||||
OBSERVATION_HISTORY_TEMPLATE = "## Observation:\n{observation}\n\n## Thought:\n{thought}\n\n## Action:\n{action}\n"
|
||||
|
||||
ACTION_HISTORY_TEMPLATE_WITH_CODE = "## Action:\n{action}\n\n## Code:\n{code}\n"
|
||||
THOUGHT_HISTORY_TEMPLATE_WITH_CODE = "## Thought:\n{thought}\n\n## Action:\n{action}\n\n## Code:\n{code}\n"
|
||||
OBSERVATION_HISTORY_TEMPLATE_WITH_CODE = "## Observation:\n{observation}\n\n## Thought:\n{thought}\n\n## Action:\n{action}\n\n## Code:\n{code}\n"
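# Editor's example: one rendered history step using the action-level template.
_example_history_step = STEP_TEMPLATE.format(step_num=1) + ACTION_HISTORY_TEMPLATE.format(
    action="Click the Firefox icon"
)
# -> "# Step 1:\n## Action:\nClick the Firefox icon\n"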
|
||||
|
|
@ -0,0 +1,483 @@
|
|||
import re
|
||||
import base64
|
||||
from loguru import logger
|
||||
from typing import List, Optional
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import tempfile
|
||||
import os
|
||||
import math
import json
import time
import requests
|
||||
|
||||
def encode_image(image_content):
|
||||
return base64.b64encode(image_content).decode("utf-8")
|
||||
|
||||
def smart_resize(
|
||||
height: int,
|
||||
width: int,
|
||||
factor: int = 28,
|
||||
min_pixels: int = 56 * 56,
|
||||
max_pixels: int = 14 * 14 * 4 * 1280,
|
||||
max_aspect_ratio_allowed: Optional[float] = None,
|
||||
size_can_be_smaller_than_factor: bool = False,
|
||||
):
|
||||
"""Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
|
||||
"""
|
||||
if not size_can_be_smaller_than_factor and (height < factor or width < factor):
|
||||
raise ValueError(
|
||||
f"height:{height} or width:{width} must be larger than factor:{factor} "
|
||||
f"(when size_can_be_smaller_than_factor is False)"
|
||||
)
|
||||
elif (
|
||||
max_aspect_ratio_allowed is not None
|
||||
and max(height, width) / min(height, width) > max_aspect_ratio_allowed
|
||||
):
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {max_aspect_ratio_allowed}, "
|
||||
f"got {max(height, width) / min(height, width)}"
|
||||
f"(when max_aspect_ratio_allowed is not None)"
|
||||
)
|
||||
h_bar = max(1, round(height / factor)) * factor
|
||||
w_bar = max(1, round(width / factor)) * factor
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = max(1, math.floor(height / beta / factor)) * factor
|
||||
w_bar = max(1, math.floor(width / beta / factor)) * factor
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = math.ceil(height * beta / factor) * factor
|
||||
w_bar = math.ceil(width * beta / factor) * factor
|
||||
return h_bar, w_bar
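# Editor's sanity check (not in the original source), using the Qwen2.5-VL
# settings passed by _coordinate_projection below: a 1920x1080 screenshot is
# snapped to the nearest multiples of 28 without further rescaling.
_example_hw = smart_resize(height=1080, width=1920, factor=28,
                           min_pixels=3136, max_pixels=12845056)
# -> (1092, 1932)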
|
||||
|
||||
def call_openai_naive(model, payload, address_hint=None):
|
||||
"""
|
||||
Naive OpenAI API call using requests.
|
||||
"""
|
||||
# Extract fields from payload
|
||||
model = payload.get("model")
|
||||
payload["model"] = model.model_id if hasattr(model, "model_id") else "None"
|
||||
# address_hint not used here
|
||||
base_url = model.base_url
|
||||
# logger.warning(f"Base URL: {base_url}, Payload model: {payload['model']}")
|
||||
url = f"{base_url}/chat/completions"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {
|
||||
**payload,
|
||||
"n": 1,
|
||||
}
|
||||
max_retry = 5
|
||||
chat_completions = None
|
||||
success = False
|
||||
while not success and max_retry > 0:
|
||||
try:
|
||||
json_data = json.dumps(data)
|
||||
response = requests.post(
|
||||
url, headers=headers, data=json_data, timeout=120, verify=False
|
||||
)
|
||||
if response.status_code == 200:
|
||||
chat_completions = response.json()
|
||||
try:
|
||||
finish_reason = chat_completions["choices"][0].get("finish_reason")
|
||||
if (
|
||||
finish_reason is not None and finish_reason == "stop"
|
||||
): # for most of the time, length will not exceed max_tokens
|
||||
success = True
|
||||
else:
|
||||
time.sleep(5)
|
||||
max_retry -= 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error in processing chat completion: {e}")
|
||||
time.sleep(5)
|
||||
max_retry -= 1
|
||||
else:
|
||||
logger.error(f"Failed to call OpenAI API: {response.text}")
|
||||
time.sleep(5)
|
||||
max_retry -= 1
|
||||
except requests.exceptions.ReadTimeout:
|
||||
# timeout is normal, don't print trace
|
||||
max_retry -= 1
|
||||
logger.warning(f"Timeout in OpenAI API call, left retries: {max_retry}")
|
||||
time.sleep(5)
|
||||
|
||||
except Exception as e:
|
||||
max_retry -= 1
|
||||
logger.exception(f"Failed to call OpenAI API: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
if chat_completions is None:
|
||||
raise RuntimeError("Failed to call OpenAI API, max_retry used up")
|
||||
try:
|
||||
infos = {}
|
||||
if "choices" in chat_completions:
|
||||
infos["finish_reason"] = chat_completions["choices"][0].get("finish_reason")
|
||||
infos["n"] = len(chat_completions["choices"])
|
||||
if "tool_calls" in chat_completions["choices"][0]["message"]:
|
||||
infos["tool_calls"] = chat_completions["choices"][0]["message"][
|
||||
"tool_calls"
|
||||
]
|
||||
infos["choices"] = chat_completions["choices"] # for the case of n > 1
|
||||
if "usage" in chat_completions:
|
||||
infos["usage"] = chat_completions["usage"]
|
||||
return chat_completions["choices"][0]["message"]["content"], infos
|
||||
except Exception as e:
|
||||
logger.error(f"Error in processing chat completion {e}")
|
||||
return "", {"n": 1, "usage": 0, "finish_reason": f"error {e}"}
|
||||
|
||||
|
||||
# Note: written to be bound as a method; `self` is expected to expose an
# `openai_client` attribute that replaces a string model name in the payload.
def preprocess_for_naive_openai(self, payload):
|
||||
if isinstance(payload["model"], str):
|
||||
payload["model"] = getattr(self, "openai_client", None)
|
||||
return payload
|
||||
|
||||
def encoded_img_to_pil_img(data_str):
|
||||
base64_str = data_str.replace("data:image/png;base64,", "")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
return Image.open(BytesIO(image_data))
|
||||
|
||||
|
||||
def save_to_tmp_img_file(data_str):
|
||||
base64_str = data_str.replace("data:image/png;base64,", "")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
image = Image.open(BytesIO(image_data))
|
||||
|
||||
tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
|
||||
image.save(tmp_img_path)
|
||||
|
||||
return tmp_img_path
|
||||
|
||||
|
||||
def bbox_to_center_1000(bbox: str) -> tuple[int, int]:
|
||||
regex_list = [
|
||||
r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|>", # '<|box_start|>(576,12),(592,42)<|box_end|>'
|
||||
r"<\|box_start\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|box_end\|>", # '<|box_start|>[[576, 12, 592, 42]]<|box_end|>'
|
||||
r"<\|box_start\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]<\|box_end\|>", # '<|box_start|>[[576, 12, 592, 42]<|box_end|>', this is actually wrong format, but we parse it anyway
|
||||
r"<\|box_start\|>\((\d+),\s*(\d+),\s*(\d+),\s*(\d+)\)<\|box_end\|>", # '<|box_start|>(576, 12, 592, 42)<|box_end|>', this is actually wrong format, but we parse it anyway
|
||||
r"\((\d+),(\d+)\),\((\d+),(\d+)\)", # Versions without the 'bbox' special tokens
|
||||
r"\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]",
|
||||
r"\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]",
|
||||
r"\((\d+),\s*(\d+),\s*(\d+),\s*(\d+)\)",
|
||||
]
|
||||
for regex in regex_list:
|
||||
match = re.search(regex, bbox)
|
||||
if match:
|
||||
break
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"Bounding box coordinates not found in the input string: {bbox}"
|
||||
)
|
||||
x_top_left, y_top_left, x_bottom_right, y_bottom_right = map(int, match.groups())
|
||||
x_center = (x_top_left + x_bottom_right) // 2
|
||||
y_center = (y_top_left + y_bottom_right) // 2
|
||||
return x_center, y_center
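# Editor's example: a Qwen-style box string maps to its integer center on the
# model's native grid.
_example_center_1000 = bbox_to_center_1000("<|box_start|>(576,12),(592,42)<|box_end|>")
# -> (584, 27)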
|
||||
|
||||
|
||||
def bbox_to_center_1(bbox: str) -> tuple[int, int]:
|
||||
regex_list = [
|
||||
r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]",
|
||||
]
|
||||
for regex in regex_list:
|
||||
match = re.search(regex, bbox)
|
||||
if match:
|
||||
break
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"Bounding box coordinates not found in the input string: {bbox}"
|
||||
)
|
||||
coordinates = tuple(map(float, match.groups()))
|
||||
coordinates = [int(coord * 1000) for coord in coordinates]
|
||||
x_center = (coordinates[0] + coordinates[2]) // 2
|
||||
y_center = (coordinates[1] + coordinates[3]) // 2
|
||||
return x_center, y_center
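# Editor's example: normalised [x1, y1, x2, y2] boxes are first scaled to the
# 0-1000 grid, then reduced to their center.
_example_center_1 = bbox_to_center_1("[0.25, 0.50, 0.75, 1.00]")
# -> (500, 750)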
|
||||
|
||||
def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
|
||||
if coordinate_type == "relative":
|
||||
return int(round(x * screen_width)), int(round(y * screen_height))
|
||||
elif coordinate_type == "absolute":
|
||||
return x, y
|
||||
elif coordinate_type == "qwen25":
|
||||
height, width = smart_resize(
|
||||
height=screen_height,
|
||||
width=screen_width,
|
||||
factor=28,
|
||||
min_pixels=3136,
|
||||
max_pixels=12845056,
|
||||
)
|
||||
return int(x / width * screen_width), int(y / height * screen_height)
|
||||
elif coordinate_type == "relative1000":
|
||||
if screen_width == 0 or screen_height == 0:
|
||||
raise ValueError(
|
||||
"Screen width and height must be greater than zero for relative1000 coordinates."
|
||||
)
|
||||
x_abs = int(round(x * screen_width / 1000))
|
||||
y_abs = int(round(y * screen_height / 1000))
|
||||
return x_abs, y_abs
|
||||
else:
|
||||
raise ValueError(f"Unsupported coordinate type: {coordinate_type}")
|
||||
|
||||
|
||||
def rescale_coord(
|
||||
coord: tuple[int, int],
|
||||
original_width: int,
|
||||
original_height: int,
|
||||
scaled_width=1000,
|
||||
scaled_height=1000,
|
||||
) -> tuple[int, int]:
|
||||
# According to https://huggingface.co/spaces/maxiw/OS-ATLAS/blob/398c3256a4fec409a074e0e4b5ac1d1d5bf7c240/app.py#L36
|
||||
# It seems that OS-ATLAS models output coordinates on a 1000x1000 grid,
|
||||
# so we need to rescale the coordinates back to the original image size
|
||||
x_scale = original_width / scaled_width
|
||||
y_scale = original_height / scaled_height
|
||||
return int(coord[0] * x_scale), int(coord[1] * y_scale)
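# Editor's example with illustrative screen sizes (2000/1000 scales exactly):
_example_os_atlas = rescale_coord((500, 250), 2000, 1000)
# -> (1000, 250)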
|
||||
|
||||
|
||||
def _pyautogui_code_to_absolute_coordinates(
|
||||
pyautogui_code_relative_coordinates,
|
||||
logical_screen_size,
|
||||
coordinate_type="relative",
|
||||
model_input_size=None,
|
||||
):
|
||||
"""
|
||||
Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.
|
||||
"""
|
||||
import re
|
||||
import ast
|
||||
|
||||
if coordinate_type not in ["relative", "relative1000", "absolute", "qwen25"]:
|
||||
raise ValueError(
|
||||
f"Invalid coordinate type: {coordinate_type}. Expected one of ['relative', 'relative1000', 'absolute', 'qwen25']."
|
||||
)
|
||||
|
||||
screen_width, screen_height = logical_screen_size
|
||||
if model_input_size is not None:
|
||||
model_width, model_height = model_input_size
|
||||
width_scale, height_scale = (
|
||||
screen_width / model_width,
|
||||
screen_height / model_height,
|
||||
)
|
||||
else:
|
||||
width_scale, height_scale = 1, 1
|
||||
|
||||
pattern = r"(pyautogui\.\w+\([^\)]*\))"
|
||||
|
||||
matches = re.findall(pattern, pyautogui_code_relative_coordinates)
|
||||
|
||||
new_code = pyautogui_code_relative_coordinates
|
||||
|
||||
for full_call in matches:
|
||||
func_name_pattern = r"(pyautogui\.\w+)\((.*)\)"
|
||||
func_match = re.match(func_name_pattern, full_call, re.DOTALL)
|
||||
if not func_match:
|
||||
continue
|
||||
|
||||
func_name = func_match.group(1)
|
||||
args_str = func_match.group(2)
|
||||
|
||||
try:
|
||||
parsed = ast.parse(f"func({args_str})").body[0].value
|
||||
parsed_args = parsed.args
|
||||
parsed_keywords = parsed.keywords
|
||||
except SyntaxError:
|
||||
return pyautogui_code_relative_coordinates
|
||||
|
||||
function_parameters = {
|
||||
"click": ["x", "y", "clicks", "interval", "button", "duration", "pause"],
|
||||
"moveTo": ["x", "y", "duration", "tween", "pause"],
|
||||
"moveRel": ["xOffset", "yOffset", "duration", "tween", "pause"],
|
||||
"dragTo": ["x", "y", "duration", "button", "mouseDownUp", "pause"],
|
||||
"dragRel": [
|
||||
"xOffset",
|
||||
"yOffset",
|
||||
"duration",
|
||||
"button",
|
||||
"mouseDownUp",
|
||||
"pause",
|
||||
],
|
||||
"doubleClick": ["x", "y", "interval", "button", "duration", "pause"],
|
||||
}
|
||||
|
||||
func_base_name = func_name.split(".")[-1]
|
||||
|
||||
param_names = function_parameters.get(func_base_name, [])
|
||||
|
||||
args = {}
|
||||
for idx, arg in enumerate(parsed_args):
|
||||
if idx < len(param_names):
|
||||
param_name = param_names[idx]
|
||||
arg_value = ast.literal_eval(arg)
|
||||
args[param_name] = arg_value
|
||||
|
||||
try:
|
||||
for kw in parsed_keywords:
|
||||
param_name = kw.arg
|
||||
arg_value = ast.literal_eval(kw.value)
|
||||
args[param_name] = arg_value
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing keyword arguments: {e}")
|
||||
return pyautogui_code_relative_coordinates
|
||||
|
||||
updated = False
|
||||
if "x" in args and "y" in args:
|
||||
try:
|
||||
x_rel = float(args["x"])
|
||||
y_rel = float(args["y"])
|
||||
x_abs, y_abs = _coordinate_projection(
|
||||
x_rel, y_rel, screen_width, screen_height, coordinate_type
|
||||
)
|
||||
# logger.warning(f"Projecting coordinates: ({x_rel}, {y_rel}) to ({x_abs}, {y_abs}) using {coordinate_type} projection.")
|
||||
args["x"] = x_abs * width_scale
|
||||
args["y"] = y_abs * height_scale
|
||||
updated = True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if "xOffset" in args and "yOffset" in args:
|
||||
try:
|
||||
x_rel = float(args["xOffset"])
|
||||
y_rel = float(args["yOffset"])
|
||||
x_abs, y_abs = _coordinate_projection(
|
||||
x_rel, y_rel, screen_width, screen_height, coordinate_type
|
||||
)
|
||||
args["xOffset"] = x_abs * width_scale
|
||||
args["yOffset"] = y_abs * height_scale
|
||||
updated = True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if updated:
|
||||
reconstructed_args = []
|
||||
for idx, param_name in enumerate(param_names):
|
||||
if param_name in args:
|
||||
arg_value = args[param_name]
|
||||
if isinstance(arg_value, str):
|
||||
arg_repr = f"'{arg_value}'"
|
||||
else:
|
||||
arg_repr = str(arg_value)
|
||||
reconstructed_args.append(arg_repr)
|
||||
else:
|
||||
break
|
||||
|
||||
used_params = set(param_names[: len(reconstructed_args)])
|
||||
for kw in parsed_keywords:
|
||||
if kw.arg not in used_params:
|
||||
arg_value = args[kw.arg]
|
||||
if isinstance(arg_value, str):
|
||||
arg_repr = f"{kw.arg}='{arg_value}'"
|
||||
else:
|
||||
arg_repr = f"{kw.arg}={arg_value}"
|
||||
reconstructed_args.append(arg_repr)
|
||||
|
||||
new_args_str = ", ".join(reconstructed_args)
|
||||
new_full_call = f"{func_name}({new_args_str})"
|
||||
new_code = new_code.replace(full_call, new_full_call)
|
||||
|
||||
return new_code
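# Editor's before/after sketch, assuming relative coordinates on a 1920x1080
# screen. Note that keyword arguments are re-emitted positionally:
_example_abs = _pyautogui_code_to_absolute_coordinates(
    "pyautogui.click(x=0.5, y=0.5)", (1920, 1080), coordinate_type="relative"
)
# -> "pyautogui.click(960, 540)"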
|
||||
|
||||
|
||||
def split_args(args_str: str) -> List[str]:
|
||||
args = []
|
||||
current_arg = ""
|
||||
within_string = False
|
||||
string_char = ""
|
||||
prev_char = ""
|
||||
for char in args_str:
|
||||
if char in ['"', "'"]:
|
||||
if not within_string:
|
||||
within_string = True
|
||||
string_char = char
|
||||
elif within_string and prev_char != "\\" and char == string_char:
|
||||
within_string = False
|
||||
if char == "," and not within_string:
|
||||
args.append(current_arg)
|
||||
current_arg = ""
|
||||
else:
|
||||
current_arg += char
|
||||
prev_char = char
|
||||
if current_arg:
|
||||
args.append(current_arg)
|
||||
return args
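# Editor's example: commas inside quoted strings do not split arguments.
_example_args = split_args("'a,b', interval=0.5")
# -> ["'a,b'", " interval=0.5"]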
|
||||
|
||||
|
||||
def correct_pyautogui_arguments(code: str) -> str:
|
||||
function_corrections = {
|
||||
"write": {
|
||||
"incorrect_args": ["text", "content"],
|
||||
"correct_args": [],
|
||||
"keyword_arg": "message",
|
||||
},
|
||||
"press": {
|
||||
"incorrect_args": ["key", "button"],
|
||||
"correct_args": [],
|
||||
"keyword_arg": None,
|
||||
},
|
||||
"hotkey": {
|
||||
"incorrect_args": ["key1", "key2", "keys"],
|
||||
"correct_args": [],
|
||||
"keyword_arg": None,
|
||||
},
|
||||
}
|
||||
|
||||
lines = code.strip().split("\n")
|
||||
corrected_lines = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
match = re.match(r"(pyautogui\.(\w+))\((.*)\)", line)
|
||||
if match:
|
||||
full_func_call = match.group(1)
|
||||
func_name = match.group(2)
|
||||
args_str = match.group(3)
|
||||
|
||||
if func_name in function_corrections:
|
||||
func_info = function_corrections[func_name]
|
||||
args = split_args(args_str)
|
||||
corrected_args = []
|
||||
|
||||
for arg in args:
|
||||
arg = arg.strip()
|
||||
kwarg_match = re.match(r"(\w+)\s*=\s*(.*)", arg)
|
||||
if kwarg_match:
|
||||
arg_name = kwarg_match.group(1)
|
||||
arg_value = kwarg_match.group(2)
|
||||
|
||||
if arg_name in func_info["incorrect_args"]:
|
||||
if func_info["keyword_arg"]:
|
||||
corrected_args.append(
|
||||
f"{func_info['keyword_arg']}={arg_value}"
|
||||
)
|
||||
else:
|
||||
corrected_args.append(arg_value)
|
||||
else:
|
||||
corrected_args.append(f"{arg_name}={arg_value}")
|
||||
else:
|
||||
corrected_args.append(arg)
|
||||
|
||||
corrected_args_str = ", ".join(corrected_args)
|
||||
corrected_line = f"{full_func_call}({corrected_args_str})"
|
||||
corrected_lines.append(corrected_line)
|
||||
else:
|
||||
corrected_lines.append(line)
|
||||
else:
|
||||
corrected_lines.append(line)
|
||||
|
||||
corrected_code = "\n".join(corrected_lines)
|
||||
return corrected_code
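# Editor's example: the common model mistake `text=` is rewritten to the
# `message=` keyword that pyautogui.write actually accepts.
_example_fixed = correct_pyautogui_arguments("pyautogui.write(text='hello')")
# -> "pyautogui.write(message='hello')"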
|
||||
|
||||
def image_message_from_obs(obs, for_training=False):
|
||||
if not for_training:
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
|
||||
"detail": "high",
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {"type": "image_url", "image_url": {"url": obs["screenshot_path"]}}
|
||||
|
|
@ -1,736 +0,0 @@
|
|||
"""
|
||||
OpenCUA Agent Implementation
|
||||
|
||||
This module implements an OpenCUA agent for desktop automation tasks, building upon
|
||||
existing frameworks and integrating multiple coordinate mapping systems.
|
||||
|
||||
Framework and Implementation Sources:
|
||||
- Main framework structure follows: https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/agent.py
|
||||
- Agent implementation adapted from: https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/aguvis_agent.py
|
||||
- Qwen2.5-VL coordinate mapping from: https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
"""
|
||||
|
||||
import re
|
||||
import os
|
||||
import ast
|
||||
import time
|
||||
import math
|
||||
import httpx
|
||||
import base64
|
||||
import backoff
|
||||
from loguru import logger
|
||||
import logging
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
# System prompts used in the training data
|
||||
AGNET_SYS_PROMPT_L1 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}".strip()
|
||||
# AGNET_SYS_PROMPT_L2 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}".strip()
|
||||
AGNET_SYS_PROMPT_L3 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nObservation:\n - Describe the current computer state based on the full screenshot in detail. \n - Application Context:\n - The active application\n - The active window or page\n - Overall layout and visible interface\n - Key Elements:\n - Menu items and toolbars \n - Buttons and controls\n - Text fields and content\n - Dialog boxes or popups\n - Error messages or notifications\n - Loading states\n - Other key elements\n - Describe any content, elements, options, information or clues that are possibly relevant to achieving the task goal, including their name, content, or shape (if possible).\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}\n".strip()
|
||||
|
||||
# Testing prompt on OSWorld-Verified
|
||||
AGNET_SYS_PROMPT_L2 = """You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. The password of the computer is "osworld-public-evaluation". If the task is not possible to do, output the action computer.terminate(status='failure').
|
||||
|
||||
For each step, provide your response in this format:
|
||||
|
||||
Thought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning
|
||||
|
||||
Action:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize "—", maximize "□", close "X")\n - if the action involves keyboard actions like \'press\', \'write\', \'hotkey\':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions
|
||||
|
||||
Finally, output the action as PyAutoGUI code or the following functions:
|
||||
- {"name": "computer.triple_click", "description": "Triple click on the screen", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The x coordinate of the triple click"}, "y": {"type": "number", "description": "The y coordinate of the triple click"}}, "required": ["x", "y"]}}
|
||||
- {"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {"type": "object", "properties": {"status": {"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}}, "required": ["status"]}}
|
||||
""".strip()
|
||||
|
||||
|
||||
STEP_TEMPLATE = "# Step {step_num}:\n"
|
||||
INSTRUTION_TEMPLATE = "# Task Instruction:\n{instruction}\n\nPlease generate the next move according to the screenshot, task instruction and previous steps (if provided).\n"
|
||||
|
||||
ACTION_HISTORY_TEMPLATE = "## Action:\n{action}\n"
|
||||
THOUGHT_HISTORY_TEMPLATE = "## Thought:\n{thought}\n\n## Action:\n{action}\n"
|
||||
OBSERVATION_HISTORY_TEMPLATE = "## Observation:\n{observation}\n\n## Thought:\n{thought}\n\n## Action:\n{action}\n"
|
||||
DETAIL_HISTORY_TEMPLATE = "## Thought:\n{thought}\n\n## Action:\n{action}\n\n## Code:\n{code}\n"
|
||||
|
||||
|
||||
def encode_image(image_content):
|
||||
"""Encode the image to base64"""
|
||||
return base64.b64encode(image_content).decode('utf-8')
|
||||
|
||||
def parse_response_to_cot_and_action(input_string, screen_size, coordinate_type) -> Tuple[str, List[str], dict]:
|
||||
"""Parse response including Observation, Thought, Action and code block"""
|
||||
try:
|
||||
sections = {}
|
||||
|
||||
obs_match = re.search(r'^##\s*Observation\s*:?[\n\r]+(.*?)(?=^##\s*Thought:|^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
|
||||
if obs_match:
|
||||
sections['observation'] = obs_match.group(1).strip()
|
||||
|
||||
thought_match = re.search(r'^##\s*Thought\s*:?[\n\r]+(.*?)(?=^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
|
||||
if thought_match:
|
||||
sections['thought'] = thought_match.group(1).strip()
|
||||
|
||||
action_match = re.search(r'^##\s*Action\s*:?[\n\r]+(.*?)(?=^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
|
||||
if action_match:
|
||||
action = action_match.group(1).strip()
|
||||
sections['action'] = action.strip()
|
||||
|
||||
if "computer.terminate" in input_string.lower():
|
||||
# Look for code blocks that might contain terminate command
|
||||
code_blocks = re.findall(r'```(?:code|python)?\s*(.*?)\s*```', input_string, re.DOTALL | re.IGNORECASE)
|
||||
if code_blocks:
|
||||
last_code = code_blocks[-1].strip().lower()
|
||||
if "fail" in last_code:
|
||||
sections['code'] = "FAIL"
|
||||
return "FAIL", ["FAIL"], sections
|
||||
elif "success" in last_code:
|
||||
sections['code'] = "DONE"
|
||||
return "DONE", ["DONE"], sections
|
||||
# Default to DONE if terminate is mentioned but no specific status
|
||||
sections['code'] = "DONE"
|
||||
return "DONE", ["DONE"], sections
|
||||
|
||||
code_blocks = re.findall(r'```(?:python)\s*(.*?)\s*```', input_string, re.DOTALL)
|
||||
if code_blocks:
|
||||
code = code_blocks[-1].strip()
|
||||
sections['original_code'] = transform_agnet_action_to_code_block(code)
|
||||
corrected_code = correct_pyautogui_arguments(code)
|
||||
sections['code'] = corrected_code
|
||||
sections['code'] = project_coordinate_to_absolute_scale(corrected_code, screen_width=screen_size[0], screen_height=screen_size[1], coordinate_type=coordinate_type)
|
||||
else:
|
||||
# No code blocks found
|
||||
sections['code'] = "WAIT"
|
||||
return "WAIT", ["WAIT"], sections
|
||||
|
||||
if 'code' not in sections:
|
||||
logger.error("Missing required action or code section")
|
||||
return None, None, {}
|
||||
|
||||
if 'action' not in sections:
|
||||
sections['action'] = ""
|
||||
|
||||
return sections['action'], [sections['code']], sections
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error parsing response: {str(e)}\nInput string: {input_string}")
|
||||
return None, None, {}
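# Editor's sketch of a typical L2 response being parsed (relative coordinates on
# a 1920x1080 screen). Left commented out because the helpers it relies on are
# defined further down in this file:
#
# response = ("## Thought:\nI should open the browser.\n\n"
#             "## Action:\nClick the Firefox icon in the dock.\n\n"
#             "```python\npyautogui.click(x=0.5, y=0.5)\n```")
# parse_response_to_cot_and_action(response, (1920, 1080), "relative")
# -> (action, ["pyautogui.click(960, 540)"], sections)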
|
||||
|
||||
def correct_pyautogui_arguments(code: str) -> str:
|
||||
"""Correct the pyautogui arguments"""
|
||||
function_corrections = {
|
||||
'write': {
|
||||
'incorrect_args': ['text', 'content'],
|
||||
'correct_args': [],
|
||||
'keyword_arg': 'message'
|
||||
},
|
||||
'press': {
|
||||
'incorrect_args': ['key', 'button'],
|
||||
'correct_args': [],
|
||||
'keyword_arg': None
|
||||
},
|
||||
'hotkey': {
|
||||
'incorrect_args': ['key1', 'key2', 'keys'],
|
||||
'correct_args': [],
|
||||
'keyword_arg': None
|
||||
},
|
||||
}
|
||||
|
||||
lines = code.strip().split('\n')
|
||||
corrected_lines = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
match = re.match(r'(pyautogui\.(\w+))\((.*)\)', line)
|
||||
if match:
|
||||
full_func_call = match.group(1)
|
||||
func_name = match.group(2)
|
||||
args_str = match.group(3)
|
||||
|
||||
if func_name in function_corrections:
|
||||
func_info = function_corrections[func_name]
|
||||
args = split_args(args_str)
|
||||
corrected_args = []
|
||||
|
||||
for arg in args:
|
||||
arg = arg.strip()
|
||||
kwarg_match = re.match(r'(\w+)\s*=\s*(.*)', arg)
|
||||
if kwarg_match:
|
||||
arg_name = kwarg_match.group(1)
|
||||
arg_value = kwarg_match.group(2)
|
||||
|
||||
if arg_name in func_info['incorrect_args']:
|
||||
if func_info['keyword_arg']:
|
||||
corrected_args.append(f"{func_info['keyword_arg']}={arg_value}")
|
||||
else:
|
||||
corrected_args.append(arg_value)
|
||||
else:
|
||||
corrected_args.append(f'{arg_name}={arg_value}')
|
||||
else:
|
||||
corrected_args.append(arg)
|
||||
|
||||
corrected_args_str = ', '.join(corrected_args)
|
||||
corrected_line = f'{full_func_call}({corrected_args_str})'
|
||||
corrected_lines.append(corrected_line)
|
||||
else:
|
||||
corrected_lines.append(line)
|
||||
else:
|
||||
corrected_lines.append(line)
|
||||
|
||||
corrected_code = '\n'.join(corrected_lines)
|
||||
return corrected_code
|
||||
|
||||
def split_args(args_str: str) -> List[str]:
|
||||
"""Split the arguments string into a list of arguments"""
|
||||
args = []
|
||||
current_arg = ''
|
||||
within_string = False
|
||||
string_char = ''
|
||||
prev_char = ''
|
||||
for char in args_str:
|
||||
if char in ['"', "'"]:
|
||||
if not within_string:
|
||||
within_string = True
|
||||
string_char = char
|
||||
elif within_string and prev_char != '\\' and char == string_char:
|
||||
within_string = False
|
||||
if char == ',' and not within_string:
|
||||
args.append(current_arg)
|
||||
current_arg = ''
|
||||
else:
|
||||
current_arg += char
|
||||
prev_char = char
|
||||
if current_arg:
|
||||
args.append(current_arg)
|
||||
return args
|
||||
|
||||
def smart_resize(
|
||||
height: int,
|
||||
width: int,
|
||||
factor: int,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
max_aspect_ratio_allowed: Optional[float] = None,
|
||||
size_can_be_smaller_than_factor: bool = False,
|
||||
):
|
||||
"""
|
||||
The function is modified from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
|
||||
Qwen2.5-VL based model need this function to resize screenshots.
|
||||
|
||||
Rescales the image so that the following conditions are met:
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
|
||||
"""
|
||||
if not size_can_be_smaller_than_factor and (height < factor or width < factor):
|
||||
raise ValueError(
|
||||
f"height:{height} or width:{width} must be larger than factor:{factor} "
|
||||
f"(when size_can_be_smaller_than_factor is False)"
|
||||
)
|
||||
elif max_aspect_ratio_allowed is not None and max(height, width) / min(height, width) > max_aspect_ratio_allowed:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {max_aspect_ratio_allowed}, "
|
||||
f"got {max(height, width) / min(height, width)}"
|
||||
f"(when max_aspect_ratio_allowed is not None)"
|
||||
)
|
||||
h_bar = max(1, round(height / factor)) * factor
|
||||
w_bar = max(1, round(width / factor)) * factor
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = max(1, math.floor(height / beta / factor)) * factor
|
||||
w_bar = max(1, math.floor(width / beta / factor)) * factor
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = math.ceil(height * beta / factor) * factor
|
||||
w_bar = math.ceil(width * beta / factor) * factor
|
||||
return h_bar, w_bar
|
||||
|
||||
def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
|
||||
"""Project the coordinates to the absolute scale"""
|
||||
if coordinate_type == "relative":
|
||||
return int(round(x * screen_width)), int(round(y * screen_height))
|
||||
elif coordinate_type == "absolute":
|
||||
return x, y
|
||||
elif coordinate_type == "qwen25":
|
||||
if 0 <= x <= 1 and 0 <= y <= 1:
|
||||
# If already normalized, treat like "relative"
|
||||
return int(round(x * screen_width)), int(round(y * screen_height))
|
||||
|
||||
height, width = smart_resize(
|
||||
height=screen_height,
|
||||
width=screen_width,
|
||||
factor=28,
|
||||
min_pixels=3136,
|
||||
max_pixels=12845056 # We use this max_pixels setting in our training data
|
||||
)
|
||||
return int(x / width * screen_width), int(y / height * screen_height)
|
||||
else:
|
||||
raise ValueError(f"Unsupported coordinate type: {coordinate_type}")
|
||||
|
||||
def project_coordinate_to_absolute_scale(pyautogui_code_relative_coordinates, screen_width, screen_height, coordinate_type="relative"):
|
||||
"""Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size."""
|
||||
if coordinate_type not in ["relative", "relative1000", "absolute", "qwen25"]:
|
||||
raise ValueError(f"Invalid coordinate type: {coordinate_type}. Expected one of ['relative', 'relative1000', 'absolute', 'qwen25'].")
|
||||
|
||||
pattern = r'(pyautogui\.\w+\([^\)]*\))'
|
||||
matches = re.findall(pattern, pyautogui_code_relative_coordinates)
|
||||
|
||||
new_code = pyautogui_code_relative_coordinates
|
||||
|
||||
for full_call in matches:
|
||||
func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
|
||||
func_match = re.match(func_name_pattern, full_call, re.DOTALL)
|
||||
if not func_match:
|
||||
continue
|
||||
|
||||
func_name = func_match.group(1)
|
||||
args_str = func_match.group(2)
|
||||
|
||||
try:
|
||||
parsed = ast.parse(f"func({args_str})").body[0].value
|
||||
parsed_args = parsed.args
|
||||
parsed_keywords = parsed.keywords
|
||||
except SyntaxError:
|
||||
return pyautogui_code_relative_coordinates
|
||||
|
||||
function_parameters = {
|
||||
'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
|
||||
'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
|
||||
'moveRel': ['xOffset', 'yOffset', 'duration', 'tween', 'pause'],
|
||||
'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
|
||||
'dragRel': ['xOffset', 'yOffset', 'duration', 'button', 'mouseDownUp', 'pause'],
|
||||
'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
|
||||
}
|
||||
|
||||
func_base_name = func_name.split('.')[-1]
|
||||
|
||||
param_names = function_parameters.get(func_base_name, [])
|
||||
|
||||
args = {}
|
||||
for idx, arg in enumerate(parsed_args):
|
||||
if idx < len(param_names):
|
||||
param_name = param_names[idx]
|
||||
arg_value = ast.literal_eval(arg)
|
||||
args[param_name] = arg_value
|
||||
|
||||
try:
|
||||
for kw in parsed_keywords:
|
||||
param_name = kw.arg
|
||||
arg_value = ast.literal_eval(kw.value)
|
||||
args[param_name] = arg_value
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing keyword arguments: {e}")
|
||||
return pyautogui_code_relative_coordinates
|
||||
|
||||
updated = False
|
||||
if 'x' in args and 'y' in args:
|
||||
try:
|
||||
x_rel = float(args['x'])
|
||||
y_rel = float(args['y'])
|
||||
x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
|
||||
logger.warning(f"Projecting coordinates: ({x_rel}, {y_rel}) to ({x_abs}, {y_abs}) using {coordinate_type} projection.")
|
||||
args['x'] = x_abs
|
||||
args['y'] = y_abs
|
||||
updated = True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if 'xOffset' in args and 'yOffset' in args:
|
||||
try:
|
||||
x_rel = float(args['xOffset'])
|
||||
y_rel = float(args['yOffset'])
|
||||
x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
|
||||
args['xOffset'] = x_abs
|
||||
args['yOffset'] = y_abs
|
||||
updated = True
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if updated:
|
||||
reconstructed_args = []
|
||||
for idx, param_name in enumerate(param_names):
|
||||
if param_name in args:
|
||||
arg_value = args[param_name]
|
||||
if isinstance(arg_value, str):
|
||||
arg_repr = f"'{arg_value}'"
|
||||
else:
|
||||
arg_repr = str(arg_value)
|
||||
reconstructed_args.append(arg_repr)
|
||||
else:
|
||||
break
|
||||
|
||||
used_params = set(param_names[:len(reconstructed_args)])
|
||||
for kw in parsed_keywords:
|
||||
if kw.arg not in used_params:
|
||||
arg_value = args[kw.arg]
|
||||
if isinstance(arg_value, str):
|
||||
arg_repr = f"{kw.arg}='{arg_value}'"
|
||||
else:
|
||||
arg_repr = f"{kw.arg}={arg_value}"
|
||||
reconstructed_args.append(arg_repr)
|
||||
|
||||
new_args_str = ', '.join(reconstructed_args)
|
||||
new_full_call = f"{func_name}({new_args_str})"
|
||||
new_code = new_code.replace(full_call, new_full_call)
|
||||
|
||||
return new_code
|
||||
|
||||
def extract_positions_and_instructions(code, action) -> list[dict]:
|
||||
"""
|
||||
Extracts all `(x, y)` coordinates (both positional and keyword arguments)
|
||||
and their associated preceding comments as instructions from Python code.
|
||||
If there are no comments, use the corresponding action instead.
|
||||
|
||||
Args:
|
||||
code (str): The Python code as a string.
|
||||
action (str): The low-level action as a string.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of dictionaries with extracted positions and instructions.
|
||||
- function (str): The pyautogui function name.
|
||||
- x (int or float): The x-coordinate.
|
||||
- y (int or float): The y-coordinate.
|
||||
- instruction (str): The preceding comment as an instruction.
|
||||
"""
|
||||
lines = code.splitlines()
|
||||
extracted = []
|
||||
preceding_comment = action # To store the preceding comment
|
||||
|
||||
for line in lines:
|
||||
|
||||
# Check if the line is a comment and store it
|
||||
if line.strip().startswith("#"):
|
||||
preceding_comment = line.strip().lstrip("#").strip() # Clean the comment
|
||||
|
||||
# Match pyautogui functions with positional arguments
|
||||
match_positional = re.match(r"(pyautogui\.\w+)\((\d+(\.\d+)?),\s*(\d+(\.\d+)?).*?\)", line)
|
||||
if match_positional:
|
||||
extracted.append({
|
||||
"function": match_positional.group(1), # pyautogui function name
|
||||
"x": float(match_positional.group(2)) if '.' in match_positional.group(2)\
|
||||
else int(match_positional.group(2)), # x-coordinate
|
||||
"y": float(match_positional.group(4)) if '.' in match_positional.group(4)\
|
||||
else int(match_positional.group(4)), # y-coordinate
|
||||
"instruction": preceding_comment, # Use the preceding comment
|
||||
})
|
||||
preceding_comment = action # Fall back to the action after consuming the comment
|
||||
continue
|
||||
|
||||
# Match pyautogui functions with keyword arguments
|
||||
match_keyword = re.match(r"(pyautogui\.\w+)\(.*?x=(\d+(\.\d+)?),\s*y=(\d+(\.\d+)?).*?\)", line)
|
||||
if match_keyword:
|
||||
extracted.append({
|
||||
"function": match_keyword.group(1), # pyautogui function name
|
||||
"x": float(match_keyword.group(2)) if '.' in match_keyword.group(2)\
|
||||
else int(match_keyword.group(2)), # x-coordinate
|
||||
"y": float(match_keyword.group(4)) if '.' in match_keyword.group(4)\
|
||||
else int(match_keyword.group(4)), # y-coordinate
|
||||
"instruction": preceding_comment, # Use the preceding comment
|
||||
})
|
||||
preceding_comment = action # Fall back to the action after consuming the comment
|
||||
|
||||
logger.info(f"Grounding extracted:\n{extracted}")
|
||||
return extracted
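# Editor's example (values assume the group-index and comment-handling fixes
# above): a comment directly above a click becomes its instruction, with the
# low-level action string as the fallback.
_example_grounding = extract_positions_and_instructions(
    "# Open the Files app\npyautogui.click(x=100, y=200)", "Click the Files icon"
)
# -> [{"function": "pyautogui.click", "x": 100, "y": 200,
#      "instruction": "Open the Files app"}]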
|
||||
|
||||
def update_code_with_new_coordinates(code, updated_positions):
|
||||
"""
|
||||
Replaces old `(x, y)` coordinates (both positional and keyword arguments)
|
||||
with updated ones in the code, handling multiple occurrences correctly.
|
||||
|
||||
Args:
|
||||
code (str): The original Python code as a string.
|
||||
updated_positions (list): A list of dictionaries with updated positions.
|
||||
|
||||
Returns:
|
||||
str: The updated Python code.
|
||||
"""
|
||||
|
||||
lines = code.splitlines()
|
||||
updated_code_lines = []
|
||||
position_index = 0 # Tracks which position update to use
|
||||
|
||||
for line in lines:
|
||||
if position_index < len(updated_positions):
|
||||
# Get the next update position
|
||||
update = updated_positions[position_index]
|
||||
function_pattern_positional = rf"{update['function']}\(\d+(\.\d+)?, \d+(\.\d+)?"
|
||||
function_pattern_keyword = rf"{update['function']}\(.*?x=\d+(\.\d+)?, y=\d+(\.\d+)?"
|
||||
|
||||
if re.search(function_pattern_positional, line):
|
||||
# Replace positional arguments
|
||||
line = re.sub(
|
||||
function_pattern_positional,
|
||||
f"{update['function']}({update['x']}, {update['y']}",
|
||||
line,
|
||||
count=1
|
||||
)
|
||||
position_index += 1 # Move to the next update
|
||||
elif re.search(function_pattern_keyword, line):
|
||||
# Replace keyword arguments
|
||||
line = re.sub(
|
||||
function_pattern_keyword,
|
||||
f"{update['function']}(x={update['x']}, y={update['y']}",
|
||||
line,
|
||||
count=1
|
||||
)
|
||||
position_index += 1 # Move to the next update
|
||||
|
||||
updated_code_lines.append(line)
|
||||
|
||||
return "\n".join(updated_code_lines)
|
||||

def transform_agnet_action_to_code_block(action):
    """Transform the agent action to a code block; not used by the agent, for logging only."""
    if "computer.terminate" in action or "browser.select_option" in action or "browser.clear" in action:
        return f"```code\n{action}\n```"
    else:
        return f"```python\n{action}\n```"


class OpenCUAAgent:
    """
    OpenCUA Agent for desktop automation tasks.

    This class implements an OpenCUA-model-based agent that can observe
    desktop environments through screenshots and execute mouse/keyboard actions
    via PyAutoGUI to complete automation tasks.

    Attributes:
        model (str): Name of the language model being used
        history_type (str): Type of history recording mechanism
        actions (list): History of executed actions
        observations (list): History of environment observations
        cots (list): Chain-of-thought reasoning records
    """
    def __init__(
        self,
        model: str,                                    # OpenCUA model name
        history_type: str,                             # History step type: action_history, thought_history, observation_history
        max_image_history_length: int = 3,             # Max number of images kept in the history
        platform: str = "ubuntu",                      # Platform of the computer
        max_tokens: int = 1500,                        # Max number of tokens in the response
        top_p: float = 0.9,                            # Top-p value used for sampling
        temperature: float = 0,                        # Temperature used for sampling
        action_space: str = "pyautogui",               # Action space: pyautogui
        observation_type: str = "screenshot",          # Observation type: screenshot
        cot_level: str = "l2",                         # CoT level: l1, l2, l3
        screen_size: Tuple[int, int] = (1920, 1080),   # Screen size
        coordinate_type: str = "relative",             # Coordinate type: relative, absolute, qwen25
        **kwargs
    ):
        assert coordinate_type in ["relative", "absolute", "qwen25"]
        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"
        assert history_type in ["action_history", "thought_history", "observation_history"]
        assert model is not None, "Model cannot be None"

        self.model = model
        self.platform = platform
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.history_type = history_type
        self.coordinate_type = coordinate_type
        self.cot_level = cot_level
        self.screen_size = screen_size
        self.max_image_history_length = max_image_history_length

        if history_type == "action_history":
            self.HISTORY_TEMPLATE = ACTION_HISTORY_TEMPLATE
        elif history_type == "thought_history":
            self.HISTORY_TEMPLATE = THOUGHT_HISTORY_TEMPLATE
        elif history_type == "observation_history":
            self.HISTORY_TEMPLATE = OBSERVATION_HISTORY_TEMPLATE
        else:
            raise ValueError(f"Invalid history type: {history_type}")

        if cot_level == "l3":
            self.SYSTEM_PROMPT = AGNET_SYS_PROMPT_L3
        elif cot_level == "l2":
            self.SYSTEM_PROMPT = AGNET_SYS_PROMPT_L2
        elif cot_level == "l1":
            self.SYSTEM_PROMPT = AGNET_SYS_PROMPT_L1
        else:
            raise ValueError(f"Invalid CoT level: {cot_level}")

        self.actions = []
        self.observations = []
        self.cots = []

    def reset(self, _logger=None):
        global logger
        logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")

        self.observations = []
        self.cots = []
        self.actions = []

    def _scale_scroll_for_windows(self, code: str, factor: int = 50) -> str:
        """pyautogui.scroll uses a different scale on Ubuntu and Windows; multiply by 'factor' when scrolling on a Windows system."""
        if self.platform.lower() != "windows":
            return code

        pattern_pos = re.compile(r'(pyautogui\.scroll\()\s*([-+]?\d+)\s*\)')
        code = pattern_pos.sub(lambda m: f"{m.group(1)}{int(m.group(2))*factor})", code)
        return code
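
    # Example (illustrative): on Windows, with the default factor of 50,
    # _scale_scroll_for_windows("pyautogui.scroll(-3)") returns
    # "pyautogui.scroll(-150)"; on Ubuntu the code is returned unchanged.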

    def predict(self, instruction: str, obs: Dict, **kwargs) -> Tuple[str, List[str], Dict]:
        """
        Predict the next action(s) based on the current observation.
        """
        if "step_idx" in kwargs:
            logger.info(f"========= {self.model} Step {kwargs['step_idx']} =======")
        else:
            logger.info(f"========================== {self.model} ===================================")
        logger.info(f"Instruction: \n{instruction}")

        messages = []
        messages.append({
            "role": "system",
            "content": self.SYSTEM_PROMPT
        })

        history_step_texts = []
        for i in range(len(self.actions)):
            if i > len(self.actions) - self.max_image_history_length:
                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}"}
                        }
                    ]
                })

                history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
                    observation=self.cots[i].get('observation'),
                    thought=self.cots[i].get('thought'),
                    action=self.cots[i].get('action')
                )

                messages.append({
                    "role": "assistant",
                    "content": history_content
                })
            else:
                history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
                    observation=self.cots[i].get('observation'),
                    thought=self.cots[i].get('thought'),
                    action=self.cots[i].get('action')
                )
                history_step_texts.append(history_content)
                if i == len(self.actions) - self.max_image_history_length:
                    messages.append({
                        "role": "assistant",
                        "content": "\n".join(history_step_texts)
                    })

        messages.append({
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}"}
                },
                {
                    "type": "text",
                    "text": INSTRUTION_TEMPLATE.format(instruction=instruction)
                }
            ]
        })

        response = self.call_llm({
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature
        }, self.model)

        logger.info(f"Model Output: \n{response}")
        if not response:
            logger.error("Empty response from the model.")
            return "ERROR", ["DONE"], {}

        low_level_instruction, pyautogui_actions, other_cot = parse_response_to_cot_and_action(response, self.screen_size, self.coordinate_type)
        if not pyautogui_actions or len(pyautogui_actions) == 0:
            logger.error("No pyautogui actions found in the response.")
            return response, ["FAIL"], {}

        pyautogui_actions = [
            self._scale_scroll_for_windows(code) for code in pyautogui_actions
        ]

        self.observations.append(obs)
        logger.info(f"Parsed Low-level Action: \n{low_level_instruction}")
        logger.info(f"Parsed pyautogui Action: \n{pyautogui_actions}")

        self.actions.append(low_level_instruction)
        if 'action' not in other_cot or not other_cot['action'] or 'thought' not in other_cot or not other_cot['thought']:
            logger.error("Error! No action/thought found in the CoT")
            logger.error(f"response: {response}")
            logger.error(f"cot: {other_cot}")
        self.cots.append(other_cot)

        # Print the message structure if needed
        # messages_to_print = []
        # current_image = 1
        # for msg in messages:
        #     msg_copy = copy.deepcopy(msg)
        #     if isinstance(msg_copy['content'], list):
        #         for content in msg_copy['content']:
        #             if content['type'] == 'image_url':
        #                 content['image_url']['url'] = f'Image {current_image}'
        #                 current_image += 1
        #     messages_to_print.append(msg_copy)

        # messages_to_print.append({
        #     "new_step_cot": other_cot,
        #     "response": response
        # })
        # logger.info(json.dumps(messages_to_print, indent=2))
        logger.info(f"New step cot: {other_cot}")

        return response, pyautogui_actions, {}
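
    # Message layout produced by predict(), for reference: a system prompt,
    # then each history step as an assistant message; only the last
    # max_image_history_length steps carry their screenshots as user image
    # messages, all older steps are collapsed into a single assistant text
    # message, and the current screenshot plus instruction come last.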

    @backoff.on_exception(
        backoff.constant,
        # here you should add more model exceptions as you want, but you are
        # forbidden to add "Exception", i.e. the common exception type, because
        # we want to catch that kind of exception outside to ensure each
        # example won't exceed the time limit
        (
            Exception
        ),
        interval=30,
        max_tries=10
    )
    def call_llm(self, payload, model):
        """Call the LLM API"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ['OPENCUA_API_KEY']}"
        }

        for _ in range(30):
            response = httpx.post(
                os.environ['OPENCUA_URL'],
                headers=headers,
                json=payload,
                timeout=500,
                verify=False
            )

            if response.status_code != 200:
                logger.error("Failed to call LLM: " + response.text)
                logger.error("Retrying...")
                time.sleep(5)
            else:
                response = response.json()
                finish_reason = response["choices"][0].get("finish_reason")
                if finish_reason is not None and finish_reason == "stop":  # most of the time, the length will not exceed max_tokens
                    return response['choices'][0]['message']['content']
                else:
                    logger.error("LLM did not finish properly, retrying...")
                    time.sleep(5)
@ -0,0 +1,350 @@
import logging
from typing import Dict, List, Tuple, Optional

from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import call_llm_safe, parse_code_from_string
from mm_agents.os_symphony.core.mllm import LMMAgent

logger = logging.getLogger("desktopenv.coder_agent")


def extract_code_block(action: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract code and determine its type from an action string."""
    if "```python" in action:
        code_type = "python"
        code = action.split("```python")[1].split("```")[0].strip()
    elif "```bash" in action:
        code_type = "bash"
        code = action.split("```bash")[1].split("```")[0].strip()
    elif "```" in action:
        code_type = None
        code = action.split("```")[1].split("```")[0].strip()
    else:
        code_type = None
        code = None

    logger.debug(
        f"Extracted code block: type={code_type}, length={len(code) if code else 0}"
    )
    return code_type, code
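
# Example (illustrative): for an action string containing a fenced block such as
#     "I will list the files.\n```bash\nls -la\n```"
# extract_code_block returns ("bash", "ls -la"); with no fence it returns (None, None).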

def execute_code(code_type: str, code: str, env_controller) -> Dict:
    """Execute code based on its type."""
    # Log the full code being executed (untruncated)
    logger.info(f"CODING_AGENT_CODE_EXECUTION - Type: {code_type}\nCode:\n{code}")

    try:
        if code_type == "bash":
            result = env_controller.run_bash_script(code, timeout=30)
        elif code_type == "python":
            result = env_controller.run_python_script(code)
        else:
            result = {"status": "error", "error": f"Unknown code type: {code_type}"}

        return result

    except Exception as e:
        logger.error(f"Error executing {code_type} code: {e}")
        return {"status": "error", "error": str(e)}


def format_result(result: Dict, step_count: int) -> str:
    """Format an execution result into a context string."""
    if not result:
        logger.warning(f"Step {step_count + 1}: No result returned from execution")
        return f"""
Step {step_count + 1} Error:
Error: No result returned from execution
"""

    status = result.get("status", "unknown")
    return_code = result.get("returncode", result.get("return_code", -1))

    # Handle the different response structures for bash vs python
    if "returncode" in result:
        # Bash script response
        output = result.get("output", "")  # Contains both stdout and stderr, merged
        error = result.get("error", "")  # Always empty for bash
    else:
        # Python script response
        output = result.get("output", "")  # stdout only
        error = result.get("error", "")  # stderr only

    logger.debug(f"Step {step_count + 1}: Status={status}, Return Code={return_code}")

    # Format with better structure for multi-line outputs
    result_text = f"Step {step_count + 1} Result:\n"
    result_text += f"Status: {status}\n"
    result_text += f"Return Code: {return_code}\n"

    if output:
        result_text += f"Output:\n{output}\n"

    if error:
        result_text += f"Error:\n{error}\n"

    return result_text
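
# Example (illustrative) of the context string format_result builds for
# step_count=2 and a bash result like
# {"status": "success", "returncode": 0, "output": "done\n", "error": ""}:
#     Step 3 Result:
#     Status: success
#     Return Code: 0
#     Output:
#     done
# The exact "status" string depends on what the controller returns; "success"
# here is only an assumed placeholder.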

class CoderAgent:
    """A dedicated agent for executing code with a budget of steps."""

    def __init__(self, engine_params: Dict, client_password: str, platform: str = "linux"):
        """Initialize the CoderAgent."""
        if not engine_params:
            raise ValueError("engine_params cannot be None or empty")

        self.engine_params = engine_params
        self.budget = engine_params.get("budget", 20)
        self.temperature = engine_params.get("temperature", 0.1)
        self.agent = None
        self.platform = platform
        self.client_password = client_password

        logger.info(f"CoderAgent initialized with budget={self.budget} and platform={self.platform}")
        self.reset()

    def reset(self):
        """Reset the coder agent state."""
        logger.debug("Resetting CoderAgent state")
        self.agent = LMMAgent(
            engine_params=self.engine_params,
            system_prompt=PROCEDURAL_MEMORY.construct_coder_procedural_memory(platform=self.platform, client_password=self.client_password)
        )

    def execute(self, task_instruction: str, screenshot: str, env_controller) -> Dict:
        """Execute code for the given task with a budget of steps."""
        if env_controller is None:
            raise ValueError("env_controller is required for code execution")

        print("\n🚀 STARTING CODE EXECUTION")
        print("=" * 60)
        print(f"Task: {task_instruction}")
        print(f"Budget: {self.budget} steps")
        print("=" * 60)

        logger.info(f"Starting code execution for task: {task_instruction}")
        logger.info(f"Budget: {self.budget} steps")

        self.reset()

        # Add the initial task instruction and screenshot context as a user message
        context = (
            f"Task: {task_instruction}\n\nCurrent screenshot is provided for context."
        )
        self.agent.add_message(context, image_content=screenshot, role="user")

        step_count = 0
        execution_history = []
        execution_result_history = []
        while step_count < self.budget:
            logger.info(f"Step {step_count + 1}/{self.budget}")

            # Get the assistant response (thoughts and code)
            response = call_llm_safe(self.agent, temperature=self.temperature)

            # Print to the terminal for immediate visibility
            # print(f"\n🤖 CODING AGENT RESPONSE - Step {step_count + 1}/{self.budget}")
            # print("=" * 60)
            # print(response)
            # print("=" * 60)

            # Log the latest message from the coding agent (untruncated)
            logger.info(
                f"CODING_AGENT_LATEST_MESSAGE - Step {step_count + 1}:\n{response}"
            )

            # Check whether the response is None or empty
            if not response or response.strip() == "":
                error_msg = f"Step {step_count + 1}: LLM returned empty response"
                logger.error(error_msg)
                raise RuntimeError(error_msg)

            # Parse the response to extract the action
            action = parse_code_from_string(response)
            thoughts = response

            execution_history.append(
                {"step": step_count + 1, "action": action, "thoughts": thoughts}
            )

            # Check for completion signals
            action_upper = action.upper().strip()
            if action_upper == "DONE":
                print(f"\n✅ TASK COMPLETED - Step {step_count + 1}")
                print("=" * 60)
                print("Agent signaled task completion")
                print("=" * 60)
                logger.info(f"Step {step_count + 1}: Task completed successfully")
                completion_reason = "DONE"
                break
            elif action_upper == "FAIL":
                print(f"\n❌ TASK FAILED - Step {step_count + 1}")
                print("=" * 60)
                print("Agent signaled task failure")
                print("=" * 60)
                logger.info(f"Step {step_count + 1}: Task failed by agent request")
                completion_reason = "FAIL"
                break
            elif action_upper == "INFEASIBLE":
                print(f"\n❌ TASK INFEASIBLE - Step {step_count + 1}")
                print("=" * 60)
                print("Agent signaled task infeasible")
                print("=" * 60)
                logger.info(f"Step {step_count + 1}: Task marked infeasible by agent request")
                completion_reason = "INFEASIBLE"
                break

            # Extract and execute code
            code_type, code = extract_code_block(response.split("(Answer)")[-1])

            if code:
                result = execute_code(code_type, code, env_controller)
                execution_result_history.append(
                    {"step": step_count + 1, "result": result}
                )
                # Prepare formatted output and error for logging
                output = result.get("output", "")
                error = result.get("error", "")
                message = result.get("message", "")
                status = result.get("status", "")

                # Print the execution result to the terminal for immediate visibility
                print(f"\n⚡ CODE EXECUTION RESULT - Step {step_count + 1}")
                print("-" * 50)
                print(f"Status: {status}")
                if output:
                    print(f"Output:\n{output}")
                if error:
                    print(f"Error:\n{error}")
                if message and not output and not error:
                    print(f"Message:\n{message}")
                print("-" * 50)

                log_lines = [
                    f"CODING_AGENT_EXECUTION_RESULT - Step {step_count + 1}:",
                    f"Status: {status}" if status else None,
                ]

                if output:
                    log_lines.append(
                        "Output:\n" + ("-" * 40) + f"\n{output}\n" + ("-" * 40)
                    )
                if error:
                    log_lines.append(
                        "Error:\n" + ("!" * 40) + f"\n{error}\n" + ("!" * 40)
                    )
                if message and not output and not error:
                    log_lines.append(
                        "Message:\n" + ("-" * 40) + f"\n{message}\n" + ("-" * 40)
                    )

                # Remove None entries and join
                formatted_log = "\n".join([line for line in log_lines if line])
                logger.info(formatted_log)
            else:
                print(f"\n⚠️ NO CODE BLOCK FOUND - Step {step_count + 1}")
                print("-" * 50)
                print("Action did not contain executable code")
                print("-" * 50)

                logger.warning(f"Step {step_count + 1}: No code block found in action")
                result = {"status": "skipped", "message": "No code block found"}
                logger.info(
                    f"CODING_AGENT_EXECUTION_RESULT - Step {step_count + 1}:\n"
                    f"Status: skipped\n"
                    f"Message:\n{'-' * 40}\n{result['message']}\n{'-' * 40}"
                )
            # Add the assistant's thoughts and code to the message history
            self.agent.add_message(response, role="assistant")

            # Process the result and add the formatted environment results as a user message
            result_context = format_result(result, step_count)
            self.agent.add_message(result_context, role="user")

            step_count += 1

        # Handle budget exhaustion
        if "completion_reason" not in locals():
            print(f"\n⏰ BUDGET EXHAUSTED - {step_count} steps completed")
            print("=" * 60)
            print(f"Maximum budget of {self.budget} steps reached")
            print("=" * 60)
            logger.info(f"Budget exhausted after {step_count} steps")
            completion_reason = f"BUDGET_EXHAUSTED_AFTER_{step_count}_STEPS"

        # Generate the final summary
        logger.info("Generating execution summary")
        summary = self._generate_summary(execution_history, task_instruction)

        result = {
            "task_instruction": task_instruction,
            "completion_reason": completion_reason,
            "summary": summary,
            "execution_history": execution_history,
            "execution_result_history": execution_result_history,
            "steps_executed": step_count,
            "budget": self.budget
        }

        logger.info(f"Code execution completed: steps={step_count}")
        return result
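
    # Usage sketch (illustrative; the engine_params keys follow the constructor
    # above, and the concrete values are assumptions for the example):
    #
    #     coder = CoderAgent(
    #         engine_params={"engine_type": "openai", "model": "gpt-4o",
    #                        "budget": 10, "temperature": 0.1},
    #         client_password="password",
    #         platform="linux",
    #     )
    #     outcome = coder.execute(task_instruction="Count the files in ~/Desktop",
    #                             screenshot=screenshot_bytes,
    #                             env_controller=env.controller)
    #     print(outcome["completion_reason"], outcome["summary"])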

    def _generate_summary(
        self, execution_history: List[Dict], task_instruction: str
    ) -> str:
        """Generate a summary of the code execution session."""
        if not execution_history:
            logger.info("No execution history to summarize")
            return "No actions were executed."

        logger.info(f"Generating summary for {len(execution_history)} steps")

        # Build detailed execution context for the summary agent
        execution_context = f"Task: {task_instruction}\n\nExecution Steps:\n"

        for step in execution_history:
            step_num = step["step"]
            thoughts = step.get("thoughts", "")
            action = step.get("action", "")

            execution_context += f"\nStep {step_num}:\n"
            if thoughts:
                execution_context += f"Thoughts: {thoughts}\n"
            execution_context += f"Code: {action}\n"

        # Create the summary prompt with the same context as the coding agent
        summary_prompt = f"""
{execution_context}

Please provide a concise summary of the code execution session. Focus on:

1. The code logic implemented at each step
2. The outputs and results produced by each code execution
3. The progression of the solution approach

Do not make judgments about success or failure. Simply describe what was attempted and what resulted.

Keep the summary under 150 words and use clear, factual language.
"""

        # Generate the summary using an LLM with a dedicated summary system prompt
        try:
            summary_agent = LMMAgent(
                engine_params=self.engine_params,
                system_prompt=PROCEDURAL_MEMORY.CODE_SUMMARY_AGENT_PROMPT,
            )
            summary_agent.add_message(summary_prompt, role="user")
            summary = call_llm_safe(summary_agent, temperature=self.temperature)

            if not summary or summary.strip() == "":
                summary = "Summary generation failed - no response from LLM"
                logger.warning("Summary generation failed - empty response from LLM")

        except Exception as e:
            summary = f"Summary generation failed: {str(e)}"
            logger.error(f"Error generating summary: {e}")

        return summary
@ -0,0 +1,109 @@
import re
from typing import Any, Dict, List

import pytesseract
from PIL import Image
import io
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.utils.common_utils import call_llm_safe, smart_resize
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
import logging

logger = logging.getLogger("desktopenv.agent")


class GrounderAgent:
    """
    Class designed for interacting with the GUI, serving the Grounding Agent and the VLMSearcher
    """
    def __init__(self, engine_params: Dict, screen_width: int, screen_height: int):
        self.engine_params_for_grounder = engine_params  # grounder_params
        system_prompt, self.user_message = PROCEDURAL_MEMORY.construct_grounder_procedural_memory(model_name=engine_params["model"])
        self.grounding_model = LMMAgent(engine_params, system_prompt=system_prompt)
        # Width and height for the Grounding Agent
        self.width = engine_params['grounding_width']
        self.height = engine_params['grounding_height']
        print(f"[Grounder]: initialized width is {self.width}, height is {self.height}")
        # Width and height of the actual screen
        self.screen_width = screen_width
        self.screen_height = screen_height

    # Given the state and the worker's referring expression, use the grounding model to generate (x, y)
    def generate_coords(self, ref_expr: str, obs: Dict, detail=False, expansion_pixels=400, **kwargs) -> List:
        cur_screenshot = obs["screenshot"]

        # store the global offset
        global_offset_x = 0
        global_offset_y = 0

        # final coordinates for output
        final_global_x = 0
        final_global_y = 0

        cur_width, cur_height = self.screen_width, self.screen_height

        print("[Grounder] start to ground!")
        self.grounding_model.reset()

        # Configure the context
        prompt = self.user_message.replace("REF_EXPR", ref_expr)

        # consistent with the system prompt presented in the GTA-1 paper
        if 'gta' in self.engine_params_for_grounder['model']:
            self.grounding_model.add_system_prompt("You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.")

        self.grounding_model.add_message(
            text_content=prompt, image_content=cur_screenshot, put_text_last=True, role="user"
        )

        # Generate and parse the coordinates
        response = call_llm_safe(self.grounding_model, temperature=0.05, **kwargs)
        print(f"[Grounder] prompt: {prompt}\nmodel: {self.engine_params_for_grounder['model']}, \nresponse: {response}")

        # 1. highest priority: (x1="...", y1="...", x="...", y="...")
        numericals = re.findall(r'(?:x1|y1|x|y)=["\']?(\d+)["\']?', response)
        # 2. second-highest priority: formats like <points>653 42</points> or [653, 42]
        if len(numericals) < 2:
            clean_response = re.sub(r'[xXyY]\d', '', response)
            numericals = re.findall(r'\d+', clean_response)
        assert len(numericals) >= 2

        print(f"[Grounder] the parsed coordinates: {numericals}")

        local_x, local_y = self._resize_coordinates([int(numericals[0]), int(numericals[1])], width=cur_width, height=cur_height)

        # current global coordinates = local coordinates + global offset
        final_global_x = local_x + global_offset_x
        final_global_y = local_y + global_offset_y

        if detail:
            return [cur_screenshot, global_offset_x, global_offset_y]
        else:
            return [final_global_x, final_global_y]

    def dynamic_set_width_height(self, width: int, height: int):
        self.width = width
        self.height = height

    # Resize from the grounding model's dimensions into the OSWorld dimensions (1920 x 1080)
    def _resize_coordinates(self, coordinates: List[int], width: int, height: int) -> List[int]:
        """
        width, height: dimensions of the current observation
        grounding_width, grounding_height: width and height for the grounding model (1000x1000 or 1280x800)
        """
        grounding_width = self.engine_params_for_grounder["grounding_width"]
        grounding_height = self.engine_params_for_grounder["grounding_height"]
        grounding_smart_resize = self.engine_params_for_grounder["grounding_smart_resize"]

        if not grounding_smart_resize:
            return [
                round(coordinates[0] * width / grounding_width),
                round(coordinates[1] * height / grounding_height),
            ]
        else:
            smart_height, smart_width = smart_resize(height, width)
            return [
                round(coordinates[0] * width / smart_width),
                round(coordinates[1] * height / smart_height)
            ]
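
# Worked example (illustrative) of _resize_coordinates without smart resize:
# a grounding model answering in a 1000x1000 space that points at (500, 400)
# maps onto a 1920x1080 screen as
#     x = round(500 * 1920 / 1000) = 960
#     y = round(400 * 1080 / 1000) = 432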
@ -0,0 +1,428 @@
from ast import parse
import logging
import json
from typing import List, Dict, Any, Optional, Tuple
from mm_agents.os_symphony.utils.common_utils import (
    call_llm_formatted,
    enhance_observation,
    parse_code_from_string
)
from functools import partial
from mm_agents.os_symphony.utils.formatters import JSON_ANSWER_FORMATTER
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
import imagehash
import io
import os
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as ssim

logger = logging.getLogger("desktopenv.agent")


class StepBehavior:
    """
    Narrative step behavior.
    Describes each step; consists of the generative agent's (main agent's) output, a screenshot (if the step is a milestone), and a textual description.
    The textual description records how the agent thought and acted, and how the state changed.
    """
    def __init__(self, is_milestone: bool, gen_output: str, summary: str, obs: Dict, action_dict: Dict):
        self.is_milestone = is_milestone
        self.gen_output = gen_output
        self.obs = obs
        self.summary = summary
        self.action_dict = action_dict
        # Variables for optimizing the time complexity of loop detection
        # --- 1. pHash ---
        self.phash = None
        # --- 2. SSIM ---
        self.ssim_list = []

    def _update_phash_ssim(self, history: List):
        # Calculate the ssim_list of the current obs
        # Update the pHash
        cur_img = Image.open(io.BytesIO(self.obs["screenshot"]))
        cur_img_gray = cur_img.convert('L')
        cur_img_np = np.array(cur_img_gray)
        self.phash = imagehash.phash(cur_img)
        # Update the ssim_list
        for hs in history:
            compare_img = Image.open(io.BytesIO(hs.obs["screenshot"]))
            compare_img_gray = compare_img.convert('L')
            compare_img_np = np.array(compare_img_gray)
            self.ssim_list.append(ssim(cur_img_np, compare_img_np, data_range=cur_img_np.max() - compare_img_np.min()))
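
# Note (illustrative): phash and ssim_list are precomputed here so that loop
# detection can compare the new step against every earlier screenshot without
# re-decoding images each time; e.g. two near-identical screens give a small
# pHash Hamming distance (step.phash - other.phash) and an SSIM close to 1.0.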

class ReflectionMemoryAgent:
    """
    Reflection Memory Agent (RMA).
    Responsible for maintaining long-term memory, extracting narratives from trajectories,
    providing reflections to the Main Agent, and validating task completion status.
    """
    def __init__(self, engine_params: Dict):
        """
        Initialize the RMA.

        Args:
            - engine_params:
                {
                    "engine_type": args.provider,
                    "model": args.model,
                    "base_url": args.model_url,
                    "api_key": args.model_api_key,
                    "temperature": getattr(args, "model_temperature", None),
                }
            - max_img_len: max number of images to use in the reflection process
        """
        self.engine_params = engine_params

        self.max_images = engine_params.get('max_images', 8)

        self.memoryer_level = engine_params.get('memoryer_level', 3)

        self.reset()

        logger.info(f"ReflectionMemoryAgent initialized with:\n {self.engine_params}")

    def reset(self):
        """Reset the RMA state."""
        logger.debug("Resetting RMA state")

        self.instruction = None

        self.trajectory: List[StepBehavior] = []

        self.knowledge_base: List[str] = []

        self.last_code_step_idx = -1

        '''
        Control the number of images; we use at most max_img_len images.
        Update logic: the 0-th screenshot is always retained. If the total number of screenshots is less than max_img_len, all are kept; otherwise, starting from index 1, milestone screenshots are managed via FIFO.
        '''
        self.active_img_idx = []

        self.reflection_agent = LMMAgent(
            engine_params=self.engine_params,
            system_prompt=PROCEDURAL_MEMORY.REFLECTION_SYSTEM_PROMPT,
        )
        self.behavior_agent = LMMAgent(
            engine_params=self.engine_params,
            system_prompt=PROCEDURAL_MEMORY.SUMMARIZE_STEP_SYSTEM_PROMPT
        )

    def add_instruction(self, instruction):
        """
        [Interface] Main -> RMA
        The Main Agent sets the instruction for the RMA.
        """
        self.instruction = instruction

    def _update_trajectory(self, step_behavior):
        self.trajectory.append(step_behavior)
        if len(self.active_img_idx) >= self.max_images:
            if step_behavior.is_milestone:
                self.active_img_idx.append(len(self.trajectory) - 1)  # over max_img_len: only milestone images are added
                del self.active_img_idx[1]  # FIFO starts from index 1
        else:
            self.active_img_idx.append(len(self.trajectory) - 1)  # under max_img_len: all images are kept

        assert len(self.active_img_idx) <= self.max_images, "[RMA] Errors in updating StepBehavior!!"
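
    # Walkthrough (illustrative), with max_images = 3: screenshots 0, 1, 2 are
    # all kept (active_img_idx == [0, 1, 2]); once the window is full, a new
    # milestone step at trajectory index 5 is appended and the oldest
    # non-initial entry is evicted, giving [0, 2, 5] -- index 0 is always
    # retained, eviction is FIFO from index 1, and non-milestone steps no
    # longer add images.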

    def _summarize_step_behavior(
        self,
        generator_output: str,
        cur_obs: Dict,
        enhanced_obs: bytes | None,
        is_milestone: bool,
        mode: str = "gui",
        code_exec_summary: str = "",
        action_dict: Dict = {}
    ) -> Tuple[StepBehavior, bool]:
        """
        [Interface] Main -> RMA
        The Main Agent (MA) calls this method to "feed" the information of the just-completed step to the RMA.
        The RMA will internally process and store this step.
        """
        if mode == "search":
            is_success = "successful"
            # the summary is fixed
            step_behavior = StepBehavior(
                False,
                generator_output,
                "Search Agent was called last step, and a tutorial has been generated.",
                cur_obs,
                action_dict
            )
        elif mode == "code":
            self.last_code_step_idx = len(self.trajectory)

            is_success = "successful"
            # the summary returned by the code agent
            step_behavior = StepBehavior(
                False,
                generator_output,
                f"Code Agent was called last step, and the summary of its trajectory is: \n---\n{code_exec_summary}\n---",
                cur_obs,
                action_dict
            )
        else:  # a common GUI operation; use the LLM to summarize
            prev_obs = self.trajectory[-1].obs

            text_content = f"""Computer Use Agent's Output: \n{generator_output}"""

            self.behavior_agent.reset()  # history messages are not needed

            updated_sys_prompt = (
                self.behavior_agent.system_prompt + "\n" + text_content
            )
            self.behavior_agent.add_system_prompt(updated_sys_prompt)

            self.behavior_agent.add_message(
                text_content="This is the observation before executing the action (attached below).",
                image_content=prev_obs['screenshot'],
                role="user",
                put_text_last=False
            )
            self.behavior_agent.add_message(
                text_content="This is the zoom-in view, which may help you to identify the operational region (attached below).",
                image_content=enhanced_obs,
                role="user",
                put_text_last=False
            )
            self.behavior_agent.add_message(
                text_content="This is the observation after executing the action (attached below).",
                image_content=cur_obs['screenshot'],
                role="user",
                put_text_last=False
            )

            required_fields = ["summary", "evaluation"]
            format_checkers = [
                partial(JSON_ANSWER_FORMATTER, required_fields)
            ]

            full_response = call_llm_formatted(
                self.behavior_agent,
                format_checkers,
                temperature=self.engine_params.get("temperature", 0.1),
            )

            response = parse_code_from_string(full_response)

            try:
                data = json.loads(response)
                behavior_summary = data['summary']
                is_success = data["evaluation"]
            except Exception as e:
                print("[RMA] Errors in generating the step summary: ", e)
                logger.info("Response is not a JSON object or misses required keys!")
                behavior_summary = response
                is_success = "successful"

            step_behavior = StepBehavior(is_milestone, generator_output, behavior_summary, cur_obs, action_dict)

        return step_behavior, is_success == "successful"
    def get_reflection(
        self,
        cur_obs: Dict,
        generator_output: str,
        coordinates: List,
        mode: str = "gui",
        code_exec_summary: str = "",
        action_dict: Dict = {}
    ) -> Dict:
        """
        [Interface] RMA -> Main
        The Main Agent (MA) calls this method to get the RMA's reflection before deciding the next action.

        Args:
            - cur_obs (Dict): the Main Agent's current observation (o_k).
            - generator_output (str): the thoughts, screen analysis, and action of the Main Agent.
            - coordinates (List): coordinates in the last operation step of the Main Agent.
            - mode (str): one of [gui, code, search]; indicates which agent the Main Agent called last step.
            - code_exec_summary: execution summary for the code agent.
            - action_dict: the action extracted from the generator output.

        Returns:
            - reflection_info (Dict): all the info related to the reflection
        """
        if self.memoryer_level == 0:
            return {
                "reflection": None,
                "reflection_thoughts": None,
                "existing_knowledge": None,
                "is_milestone": False,
                "new_knowledge": None,
                "step_summary": None,
                "hint": {
                    "gui_operation_error": False,
                    "lack_of_tutorial": False,
                    "code_error": False,
                    "loop_detection": None,
                }
            }

        reflection = None
        reflection_thought = None
        if len(self.trajectory) == 0:
            step_behavior = StepBehavior(
                True,
                "The initial screen is provided. No action has been taken yet.",
                "The initial screen is provided. No action has been taken yet.",
                cur_obs,
                action_dict
            )
            step_behavior._update_phash_ssim(self.trajectory)
            self._update_trajectory(step_behavior)
            reflection_info = {
                "reflection": reflection,
                "reflection_thoughts": reflection_thought,
                "existing_knowledge": "\n".join(self.knowledge_base),
                "is_milestone": True,
                "new_knowledge": "",
                "step_summary": "",
                "loop_detection": None
            }
        else:
            ### Step Summary
            prev_obs = self.trajectory[-1].obs
            enhanced_obs = None
            if coordinates:
                enhanced_obs, _, _, _, _ = enhance_observation(
                    prev_obs["screenshot"],
                    coordinates,
                    draw=True
                )

            # generate the step behavior
            step_behavior, last_gui_check = self._summarize_step_behavior(
                generator_output,
                cur_obs,
                enhanced_obs,
                False,
                mode,
                code_exec_summary,
                action_dict
            )
            step_behavior._update_phash_ssim(self.trajectory)

            ### make additional hints
            additional_hints = []
            if not last_gui_check:
                additional_hints.append("\t- Warning: The last GUI operation was unsuccessful. Careful review is required to avoid a GUI Operation Error.")

            code_error_hint = False

            if self.last_code_step_idx != -1 and len(self.trajectory) - self.last_code_step_idx < 0:
                code_error_hint = True
                additional_hints.append("\t- Warning: The Computer Use Agent might be in the verification stage of the Code Agent. Careful review is required to avoid a Code Error.")

            # loop detection
            from mm_agents.os_symphony.utils.loop_detection import detect_loop
            is_loop, loop_details = detect_loop(full_trajectory=self.trajectory, N=3)
            if is_loop and loop_details:
                match_sequence_indices = loop_details["match_sequence_indices"]
                loop_hint_message = f"\t- Warning: A potential LOOP has been detected between Step {match_sequence_indices[0]} and Step {match_sequence_indices[-1]}. Careful review is required to avoid a Repetitive Behavior Error."
                additional_hints.append(loop_hint_message)

            self.reflection_agent.reset()

            updated_sys_prompt = (
                PROCEDURAL_MEMORY.REFLECTION_SYSTEM_PROMPT + "\n\n" +
                f"---\n- **user instruction**: {self.instruction}\n" +
                "- **existing knowledge**: \n" + "\n".join(self.knowledge_base) +
                "\n- **additional_hints**: " + "\n".join(additional_hints) + "\n---"
            )

            # update the system prompt
            self.reflection_agent.add_system_prompt(updated_sys_prompt)

            for i, step in enumerate(self.trajectory):
                text_content = f"""### (Step {i}) history:\nsummary: '''\n{step.summary}\n'''"""
                if i in self.active_img_idx:
                    if i == 0:
                        text_content += "\ninitial screenshot:"
                    else:
                        text_content += "\nscreenshot (after executing action): (attached below)"

                self.reflection_agent.add_message(
                    text_content=text_content,
                    image_content=step.obs['screenshot'] if i in self.active_img_idx else None,
                    role="user",
                    put_text_last=False
                )

            text_content = f"""### (Last Step) CUA's output (has been finished):\n---\n{generator_output}\n---\nStep Summary:\n---\n{step_behavior.summary}\n---\nlatest_screenshot: (attached below)"""
            self.reflection_agent.add_message(
                text_content=text_content,
                image_content=cur_obs['screenshot'],
                role="user",
                put_text_last=False
            )

            required_fields = ["is_milestone", "reflection", "knowledge"]

            format_checkers = [
                partial(JSON_ANSWER_FORMATTER, required_fields)
            ]

            full_response = call_llm_formatted(
                self.reflection_agent,
                format_checkers
            )

            reflection_thought = full_response

            response = parse_code_from_string(full_response)

            try:
                data = json.loads(response)
                reflection = data['reflection']
                is_milestone = data["is_milestone"]
                knowledge = data['knowledge']
            except Exception as e:
                print("[RMA] Errors in dealing with the reflection: ", e)
                logger.info("Response is not a JSON object or misses required keys!")
                reflection = response
                is_milestone = False
                knowledge = ""

            if len(knowledge) > 0:
                self.knowledge_base.append(knowledge)

            if isinstance(is_milestone, str):
                is_milestone = "true" in is_milestone.lower()

            # update the trajectory and is_milestone
            self._update_trajectory(step_behavior)
            if mode == "gui":  # only GUI operations can be considered milestones
                self.trajectory[-1].is_milestone = is_milestone

            reflection_info = {
                "reflection": reflection,
                "reflection_thoughts": reflection_thought,
                "existing_knowledge": "\n".join(self.knowledge_base),
                "is_milestone": is_milestone,
                "new_knowledge": knowledge,
                "step_summary": step_behavior.summary,
                "hint": {
                    "gui_operation_error": not last_gui_check,
                    "lack_of_tutorial": is_loop,
                    "code_error": code_error_hint,
                    "loop_detection": loop_details,
                }
            }

        return reflection_info
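
# Usage sketch (illustrative): the expected call pattern from the main loop.
#
#     rma = ReflectionMemoryAgent(engine_params)
#     rma.add_instruction(task_instruction)
#     info = rma.get_reflection(cur_obs=obs, generator_output=response,
#                               coordinates=[x, y], mode="gui")
#     if info.get("hint", {}).get("loop_detection"):
#         ...  # e.g. nudge the main agent off the repeated action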
@ -0,0 +1,178 @@
import re
from io import BytesIO
from typing import Tuple, List, Dict
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pytesseract
from pytesseract import Output
import easyocr


class OCRProcessor:
    """
    OCR processor supporting both Tesseract and EasyOCR
    """
    def __init__(self, use_gpu: bool = False, languages: List[str] = ['en']):
        """
        Initialize the processor.

        Args:
            use_gpu (bool): whether EasyOCR should use the GPU
            languages (List[str]): language list for EasyOCR, e.g. ['en', 'ch_sim'].
        """
        self.use_gpu = use_gpu
        self.languages = languages
        self.reader = None  # lazy-load the EasyOCR Reader

    def _get_easyocr_reader(self):
        if self.reader is None:
            print(f"Loading EasyOCR model (GPU={self.use_gpu})...")
            self.reader = easyocr.Reader(self.languages, gpu=self.use_gpu)
        return self.reader

    def get_ocr_elements(self, bytes_image_data: bytes, mode: str = 'tesseract') -> Tuple[str, List[Dict]]:
        """
        Run OCR recognition.

        Args:
            bytes_image_data (bytes): raw image bytes
            mode (str): 'tesseract' (faster) or 'easyocr' (more precise).

        Returns:
            Tuple[str, List]: (textual table string, list of element details)
        """
        try:
            image = Image.open(BytesIO(bytes_image_data))
        except Exception as e:
            print(f"Error decoding or opening image: {e}")
            return "", []

        if mode == 'tesseract':
            return self._process_tesseract(image)
        elif mode == 'easyocr':
            return self._process_easyocr(image)
        else:
            raise ValueError(f"Unknown mode: {mode}. Use 'tesseract' or 'easyocr'.")
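
    # Usage sketch (illustrative):
    #
    #     with open("screenshot.png", "rb") as f:
    #         table, elements = OCRProcessor().get_ocr_elements(f.read(), mode="easyocr")
    #     # table is a "Word id\tText" listing suitable for an LLM prompt;
    #     # elements carry the pixel boxes ("left", "top", "width", "height") per word.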

    def _process_tesseract(self, image: Image.Image) -> Tuple[str, List[Dict]]:
        """Tesseract processing"""
        data = pytesseract.image_to_data(image, output_type=Output.DICT)

        ocr_elements = []
        ocr_table = "Text Table (Tesseract):\nWord id\tText\n"
        ocr_id = 0

        num_boxes = len(data['text'])
        for i in range(num_boxes):
            # filter out text with low confidence
            if int(data['conf'][i]) > 0 and data['text'][i].strip():
                clean_text = re.sub(r"^[^a-zA-Z0-9\s.,!?;:\-\+]+|[^a-zA-Z0-9\s.,!?;:\-\+]+$", "", data['text'][i])
                if not clean_text:
                    continue

                ocr_table += f"{ocr_id}\t{clean_text}\n"

                ocr_elements.append({
                    "id": ocr_id,
                    "text": clean_text,
                    "mode": "tesseract",
                    "left": data["left"][i],
                    "top": data["top"][i],
                    "width": data["width"][i],
                    "height": data["height"][i],
                    "conf": data["conf"][i]
                })
                ocr_id += 1

        return ocr_table, ocr_elements

    def _process_easyocr(self, image: Image.Image) -> Tuple[str, List[Dict]]:
        """EasyOCR processing"""
        reader = self._get_easyocr_reader()

        image_np = np.array(image)

        # detail=1 means returning (bbox, text, conf)
        results = reader.readtext(image_np, detail=1, paragraph=False, width_ths=0.1)

        ocr_elements = []
        ocr_table = "Text Table (EasyOCR):\nWord id\tText\n"
        ocr_id = 0

        for (bbox, text, conf) in results:
            clean_text = re.sub(r"^[^a-zA-Z0-9\s.,!?;:\-\+]+|[^a-zA-Z0-9\s.,!?;:\-\+]+$", "", text)
            if not clean_text.strip():
                continue

            # EasyOCR returns [[x1, y1], [x2, y1], [x2, y2], [x1, y2]];
            # convert it into left, top, width, height
            (tl, tr, br, bl) = bbox
            tl = [int(v) for v in tl]
            br = [int(v) for v in br]

            left = min(tl[0], bl[0])
            top = min(tl[1], tr[1])
            right = max(tr[0], br[0])
            bottom = max(bl[1], br[1])

            width = right - left
            height = bottom - top

            ocr_table += f"{ocr_id}\t{clean_text}\n"

            ocr_elements.append({
                "id": ocr_id,
                "text": clean_text,
                "mode": "easyocr",
                "left": left,
                "top": top,
                "width": width,
                "height": height,
                "conf": float(conf)
            })
            ocr_id += 1

        return ocr_table, ocr_elements

    @staticmethod
    def visualize_ocr_results(image_path: str, ocr_elements: List[Dict], output_path: str):
        """
        Draw bounding boxes and IDs on the original image.
        """
        try:
            image = Image.open(image_path).convert("RGB")
            draw = ImageDraw.Draw(image)

            try:
                font = ImageFont.truetype("arial.ttf", 16)
            except IOError:
                font = ImageFont.load_default()

            for element in ocr_elements:
                left, top = element["left"], element["top"]
                width, height = element["width"], element["height"]

                color = "green" if element.get("mode") == "easyocr" else "red"

                draw.rectangle([(left, top), (left + width, top + height)], outline=color, width=2)

                text_str = str(element["id"])

                if hasattr(draw, "textbbox"):
                    bbox = draw.textbbox((0, 0), text_str, font=font)
                    text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
                else:
                    text_w, text_h = draw.textsize(text_str, font=font)

                label_bg = [left, top - text_h - 4, left + text_w + 4, top]
                draw.rectangle(label_bg, fill=color)

                draw.text((left + 2, top - text_h - 4), text_str, fill="white", font=font)

            image.save(output_path)
            print(f"Visualization saved to: {output_path}")

        except FileNotFoundError:
            print(f"Error: Image {image_path} not found.")
        except Exception as e:
            print(f"Visualization error: {e}")
@ -0,0 +1,575 @@
import re
from typing import Any, Dict, List, Optional, Tuple

from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.utils.common_utils import call_llm_safe
from mm_agents.os_symphony.agents.coder_agent import CoderAgent
from mm_agents.os_symphony.agents.grounder_agent import GrounderAgent
from mm_agents.os_symphony.agents.searcher_agent import SearcherAgent
import logging
from mm_agents.os_symphony.agents.ocr import OCRProcessor


logger = logging.getLogger("desktopenv.agent")


# Agent action decorator
def agent_action(func):
    func.is_agent_action = True
    return func


# GrounderAgent primitives are parameterized by description; coordinate generation uses a pretrained grounding model
class OSACI:
    def __init__(
        self,
        env,
        search_env,
        platform: str,
        client_password: str,
        engine_params_for_ocr: Dict,
        engine_params_for_grounder: Dict,
        engine_params_for_coder: Dict,
        engine_params_for_searcher: Dict,
        screen_width: int = 1920,
        screen_height: int = 1080
    ):
        self.env = env
        self.platform = platform
        self.client_password = client_password

        self.result_dir = ""

        self.grounder_agent = GrounderAgent(engine_params=engine_params_for_grounder, screen_width=screen_width, screen_height=screen_height)

        # Configure the text grounding agent
        self.ocr_processor = OCRProcessor()
        self.text_span_agent = LMMAgent(
            engine_params=engine_params_for_ocr,
            system_prompt=PROCEDURAL_MEMORY.PHRASE_TO_WORD_COORDS_PROMPT,
        )

        # Configure the code agent
        self.coder_agent = CoderAgent(
            engine_params=engine_params_for_coder,
            platform=self.platform,
            client_password=client_password
        )

        # Configure the search agent
        self.searcher_agent = SearcherAgent.create(
            engine_params=engine_params_for_searcher,
            search_env=search_env,
            grounder_agent=self.grounder_agent,
            platform=self.platform,
            client_password=self.client_password
        )

        # Store the task instruction for the code agent
        self.current_task_instruction = None
        self.last_code_agent_result = None
        self.last_search_agent_result = None
        self.notes: List[str] = []
        # Tutorials should be global info rather than local context; TODO: decide how to add them to the global info
        self.tutorials = []

    def assign_screenshot(self, obs):
        self.obs = obs

    # Given the state and the worker's text phrase, generate the coords of the first/last word in the phrase
    def generate_text_coords(
        self, phrase: str, obs: Dict, alignment: str = ""
    ) -> List[int]:
        screenshot, global_offset_x, global_offset_y = obs["screenshot"], 0, 0

        ocr_table, ocr_elements = self.ocr_processor.get_ocr_elements(screenshot, "easyocr")

        alignment_prompt = ""
        if alignment == "start":
            alignment_prompt = "**Important**: Output the word id of the FIRST word in the provided phrase.\n"
        elif alignment == "end":
            alignment_prompt = "**Important**: Output the word id of the LAST word in the provided phrase.\n"

        # Load the LLM prompt
        self.text_span_agent.reset()
        self.text_span_agent.add_message(
            alignment_prompt + "Phrase: " + phrase + "\n" + ocr_table, role="user"
        )
        self.text_span_agent.add_message(
            "Screenshot:\n", image_content=screenshot, role="user"
        )

        # Obtain the target element
        response = call_llm_safe(self.text_span_agent)
        print("TEXT SPAN AGENT RESPONSE:", response)
        numericals = re.findall(r"\d+", response)
        if len(numericals) > 0:
            text_id = int(numericals[-1])
        else:
            text_id = 0
        elem = ocr_elements[text_id]

        # Compute the element coordinates
        # Note: 0.15 * elem["height"] is used to adjust the coordinates so that the last character is selected more precisely.
        if alignment == "start":
            coords = [elem["left"], elem["top"] + (elem["height"] // 2)]
        elif alignment == "end":
            coords = [elem["left"] + elem["width"] + 0.15 * elem["height"], elem["top"] + (elem["height"] // 2)]

        print(f'[OCR] output coordinates: {[coords[0] + global_offset_x, coords[1] + global_offset_y]}')
        return [int(coords[0] + global_offset_x), int(coords[1] + global_offset_y)]
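
    # Example (illustrative): for an OCR word box {"left": 100, "top": 40,
    # "width": 60, "height": 20}, alignment="start" yields (100, 50) and
    # alignment="end" yields (100 + 60 + 0.15 * 20, 50) -> (163, 50).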

    def set_task_instruction(self, task_instruction: str):
        """Set the current task instruction for the code agent."""
        self.current_task_instruction = task_instruction

    @agent_action
    def click(
        self,
        element_description: str,
        num_clicks: int = 1,
        button_type: str = "left",
        hold_keys: List = []
    ):
        """Click on the element
        Args:
            element_description:str, a detailed description of which element to click on. This description needs to be VERY unambiguous. If the page contains many similar elements, ensure the description uniquely identifies the target element.
            num_clicks:int, number of times to click the element
            button_type:str, which mouse button to press; can be "left", "middle", or "right"
            hold_keys:List, list of keys to hold while clicking
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)

        command = "import pyautogui; "

        for k in hold_keys:
            command += f"pyautogui.keyDown({repr(k)}); "
        command += f"pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); "
        for k in hold_keys:
            command += f"pyautogui.keyUp({repr(k)}); "
        # Return the pyautogui code to click on the element

        action = {"function": "click", "args": {"x": x, "y": y, "button": button_type, "clicks": num_clicks}}
        return (command, action)
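
    # Example (illustrative): click("the OK button", num_clicks=2,
    # hold_keys=["shift"]) produces a command string like
    #     "import pyautogui; pyautogui.keyDown('shift'); "
    #     "pyautogui.click(512, 384, clicks=2, button='left'); "
    #     "pyautogui.keyUp('shift'); "
    # where (512, 384) stands in for whatever the grounder returns.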
|
||||
@agent_action
|
||||
def open(self, app_or_filename: str):
|
||||
"""Open any application or file with name app_or_filename. Use this action to open applications or files on the desktop, do not open manually.
|
||||
Args:
|
||||
app_or_filename:str, the name of the application or filename to open
|
||||
|
||||
**Important**:
|
||||
Provide only the name of the application or file. Do not include the full path (e.g., "/home/user/Desktop/my_report.docx"). The function works by searching for the name, not by accessing a file path directly.
|
||||
"""
|
||||
action = {"function": "open", "args": {"name": app_or_filename}}
|
||||
if self.platform == "linux":
|
||||
return (f"import pyautogui; pyautogui.hotkey('win'); time.sleep(1.0); pyautogui.write({repr(app_or_filename)}); time.sleep(1.0); pyautogui.hotkey('enter'); time.sleep(1.0)", action)
|
||||
elif self.platform == "macos":
|
||||
return (f"import pyautogui; import time; pyautogui.hotkey('command', 'space', interval=0.5); pyautogui.typewrite({repr(app_or_filename)}); pyautogui.press('enter'); time.sleep(1.0)", action)
|
||||
elif self.platform == "windows":
|
||||
return (f"import pyautogui; import time; pyautogui.hotkey('win'); time.sleep(0.5); pyautogui.write({repr(app_or_filename)}); time.sleep(1.0); pyautogui.press('enter'); time.sleep(0.5)", action)
|
||||
else:
|
||||
assert (
|
||||
False
|
||||
), f"Unsupported platform: {self.platform}. Supported platforms are: darwin, linux, windows."
|
||||
|
||||
def _paste(self, is_terminal):
|
||||
if self.platform == 'macos':
|
||||
return "pyautogui.hotkey('command', 'v');"
|
||||
|
||||
elif self.platform == 'linux':
|
||||
if is_terminal:
|
||||
return "pyautogui.hotkey('ctrl', 'shift', 'v');"
|
||||
else:
|
||||
return "pyautogui.hotkey('ctrl', 'v');"
|
||||
|
||||
elif self.platform == 'windows':
|
||||
return "pyautogui.hotkey('ctrl', 'v');"
|
||||
|
||||
return ""
|
||||
|
||||
    def _clear_all(self, is_terminal):
        """
        Clear the content of the current line or field
        """
        # common apps in GUI
        if not is_terminal:
            if self.platform == 'macos':
                # macOS GUI: Command + A -> Backspace
                return "pyautogui.hotkey('command', 'a'); pyautogui.press('backspace');"
            else:
                # Windows/Linux GUI: Ctrl + A -> Backspace
                return "pyautogui.hotkey('ctrl', 'a'); pyautogui.press('backspace');"

        # terminal
        else:
            if self.platform == 'windows':
                return "pyautogui.press('esc');"
            else:
                return "pyautogui.hotkey('ctrl', 'e'); pyautogui.hotkey('ctrl', 'u');"

    def _type(
        self,
        text: str,
        is_terminal: bool
    ):
        """
        Use copy-and-paste to input non-ASCII text (e.g., Chinese); otherwise type normally
        """
        commands = ""
        has_unicode = any(ord(char) > 127 for char in text)
        if has_unicode and self.platform != "macos":
            commands += (
                "original_clipboard = pyperclip.paste();"
                f"pyperclip.copy({repr(text)});"
                "time.sleep(0.1);"
            )
            commands += self._paste(is_terminal=is_terminal)
            commands += "pyperclip.copy(original_clipboard);"
        else:
            commands += f"pyautogui.write({repr(text)}, interval=0.1);"

        return commands

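    # Standalone illustration of the clipboard strategy used by _type (assumes a desktop
    # session with pyautogui and pyperclip installed; values are illustrative):
    #
    #   import pyautogui, pyperclip
    #   text = "你好, world"
    #   if any(ord(ch) > 127 for ch in text):      # non-ASCII present
    #       saved = pyperclip.paste()              # preserve the user's clipboard
    #       pyperclip.copy(text)
    #       pyautogui.hotkey('ctrl', 'v')          # pasting handles any unicode
    #       pyperclip.copy(saved)                  # restore the clipboard
    #   else:
    #       pyautogui.write(text, interval=0.1)    # plain ASCII can simply be typed
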
    @agent_action
    def type(
        self,
        element_description: str,
        text: str = "",
        overwrite: bool = False,
        enter: bool = False,
        is_terminal: bool = False
    ):
        """Type text/unicode into a specific element

        Args:
            element_description: str, a detailed description of which element to enter text in. If provided, the agent will click on this element before typing.
            text:str, the text to type
            overwrite:bool, Default is False; assign it to True if the text should overwrite the whole existing text. Using this argument clears all text in an element.
            enter:bool, Assign it to True if the enter key should be pressed after typing all the text, otherwise assign it to False.
            is_terminal:bool, (MANDATORY) You MUST set this to True whenever the target you will type into is a terminal.
        """
        commands = (
            "import os;"
            "import pyautogui;"
            "import pyperclip;"
            "import subprocess;"
            "import time;"
        )

        if self.platform == "linux":
            commands += (
                "p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
                "p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
                "proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
                f"subprocess.run(f'echo \"{self.client_password}\" | sudo -S {{proxy_prefix}}apt-get install -y xclip xsel', shell=True, check=True);"
            )

        x, y = None, None
        if element_description is not None:
            x, y = self.grounder_agent.generate_coords(element_description, self.obs)
            commands += (
                f"pyautogui.click({x}, {y}, clicks=2);"
                f"time.sleep(1.0);"
                f"pyautogui.click({x}, {y});"
            )

        if overwrite:
            commands += self._clear_all(is_terminal=is_terminal)

        commands += self._type(text=text, is_terminal=is_terminal)

        if enter:
            commands += "pyautogui.press('enter');"

        if element_description is not None:
            action = {"function": "type", "args": {"x": x, "y": y, "text": text}}
        else:
            action = {"function": "type", "args": {"text": text}}
        return (commands, action)

    @agent_action
    def drag_and_drop(
        self, starting_description: str, ending_description: str, hold_keys: List = []
    ):
        """Drag from the starting description to the ending description

        Args:
            starting_description:str, a very detailed description of where to start the drag action. This description should be at least a full sentence.
            ending_description:str, a very detailed description of where to end the drag action. This description should be at least a full sentence.
            hold_keys:List, list of keys to hold while dragging
        """
        x1, y1 = self.grounder_agent.generate_coords(starting_description, self.obs)
        x2, y2 = self.grounder_agent.generate_coords(ending_description, self.obs)

        command = "import pyautogui; "

        command += f"pyautogui.moveTo({x1}, {y1}); "
        # TODO: specified duration?
        for k in hold_keys:
            command += f"pyautogui.keyDown({repr(k)}); "
        command += f"pyautogui.dragTo({x2}, {y2}, duration=3., button='left'); pyautogui.mouseUp(); "
        for k in hold_keys:
            command += f"pyautogui.keyUp({repr(k)}); "

        # Return pyautogui code to drag and drop the elements
        action = {"function": "drag", "args": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}}
        return (command, action)

    @agent_action
    def highlight_text_span(
        self,
        starting_phrase: str,
        ending_phrase: str,
        button: str = "left",
        text: Optional[str] = None
    ):
        """Highlight a text span between a provided starting phrase and ending phrase. Use this to highlight words, lines, and paragraphs.

        Args:
            starting_phrase: str, the sequence of words that marks the beginning of the text span. Provide a unique sequence of 5 to 10 words.
            ending_phrase: str, the sequence of words that marks the end of the text span. Provide a unique sequence of 5 to 10 words.
            button:str, the button to use to highlight the text span. Defaults to "left". Can be "left", "right", or "middle".
            text: str | None, The text to overwrite the highlighted span with. Providing text here ensures the replacement happens immediately after selection, preventing focus loss.
        """
        x1, y1 = self.generate_text_coords(
            starting_phrase, self.obs, alignment="start"
        )
        x2, y2 = self.generate_text_coords(
            ending_phrase, self.obs, alignment="end"
        )

        command = "import pyautogui; import time; import subprocess; import pyperclip; "
        command += f"pyautogui.moveTo({x1}, {y1}); "
        # Click in advance to simulate selecting the text box.
        command += (
            f"pyautogui.click({x1}, {y1}, clicks=2);"
            f"time.sleep(1.0); pyautogui.click({x1}, {y1}); time.sleep(1.0);"
        )
        command += f"pyautogui.dragTo({x2}, {y2}, duration=5., button='{button}'); time.sleep(0.5); pyautogui.mouseUp(); "

        if text:
            if self.platform == "linux":
                command += "subprocess.run('echo \"password\" | sudo -S apt-get install -y xclip xsel', shell=True, check=True, env={\"http_proxy\": \"http://10.1.8.5:23128\", \"https_proxy\": \"http://10.1.8.5:23128\"});"

            command += (
                "original_clipboard = pyperclip.paste();"
                f"pyperclip.copy({repr(text)});"
            )
            command += self._paste(is_terminal=False)
            command += "pyperclip.copy(original_clipboard);"

        # Return pyautogui code to drag across the text span
        action = {"function": "drag", "args": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}}
        return (command, action)

    @agent_action
    def locate_cursor(
        self,
        phrase: str,
        start_or_end: str = "start",
        text: Optional[str] = None
    ):
        """Click at the beginning or end of a specific text phrase to precisely control cursor positioning. Please prefer using the "click" action in general situations, and use this action only in text-intensive software such as libreoffice_writer, impress, etc.

        Args:
            phrase: str, The text phrase where you want to position the cursor. Provide a unique sequence of 5 to 10 words. Do NOT use single words unless the total text is extremely short.
            start_or_end: str, Whether to click at the "start" (beginning) or "end" (trailing edge) of the identified text phrase. Use "start" to position before the text, "end" to position after it.
            text: str | None, The text to enter immediately after positioning the cursor. Use this parameter instead of a separate 'type' action to ensure precise input.
        """
        x, y = self.generate_text_coords(
            phrase, self.obs, alignment=start_or_end
        )
        command = (
            "import pyautogui;"
            "import time;"
            "import subprocess;"
            "import pyperclip;"
            f"pyautogui.click({x}, {y}, button='left', clicks=2);"
            "time.sleep(1.0);"
            f"pyautogui.click({x}, {y}, button='left');"
        )
        if text:
            if self.platform == "linux":
                command += "subprocess.run('echo \"password\" | sudo -S apt-get install -y xclip xsel', shell=True, check=True, env={\"http_proxy\": \"http://10.1.8.5:23128\", \"https_proxy\": \"http://10.1.8.5:23128\"});"

            command += self._type(text=text, is_terminal=False)

        if text:
            action = {"function": "type", "args": {"x": x, "y": y, "text": text}}
        else:
            action = {"function": "click", "args": {"x": x, "y": y, "clicks": 1, "button": "left"}}
        return (command, action)

    @agent_action
    def call_code_agent(self, task: str):
        """Calls the code agent to execute a well-defined, self-contained goal that can be completed with code.

        Args:
            task: str, A specific, self-contained goal that the code agent can work on until completion.

        **🚨 CRITICAL GUIDELINES:**

        **Decompose the Main Objective into Logical Goals:**
        - You **MUST** break down the overall mission into distinct, logical goals or stages.
        - Your role is to define *what* needs to be done for a specific stage. The code agent's role is to figure out *how* to do it with code.
        - Pass only one logical goal at a time. The `task` parameter is **REQUIRED**.

        **Define a Self-Contained, Continuous Goal:**
        - The `task` you provide should be a single, continuous goal. The code agent is capable of handling a multi-step process internally (e.g., opening a file, processing its data, and then saving it) to achieve this one goal.
        - **Crucially, do not pass a task that combines multiple distinct objectives.** For example, instead of passing "Analyze the sales data, AND email the result," you should first pass the self-contained goal: "Analyze the sales data." After that goal is complete, you can proceed with the next logical goal (e.g., emailing the result) in a subsequent step.
        - **If unsure, err on the side of caution.** If a task feels like it has two separate parts, break it down and pass only the first part.
        - Your instruction must describe the desired end-state, NOT the recipe to get there. Do not specify any solution!

        **Goal Purity is Essential:**
        - **NEVER** rephrase, paraphrase, or modify the subtask instruction you have decided on. Pass the exact, original wording of the subtask to prevent instruction drift and hallucination.

        Use this for tasks that can be fully accomplished through code execution, particularly for:
        - Spreadsheet applications: data processing, filtering, sorting, calculations, formulas, data analysis
        - Document editors: text processing, content editing, formatting, document manipulation
        - Code editors: code editing, file processing, text manipulation, configuration
        - Data analysis tools: statistical analysis, data transformation, reporting
        - File management: bulk operations, file processing, content extraction
        - System utilities: configuration, setup, automation
        """
        logger.info("=" * 50)
        logger.info("ACI: Calling Code Agent")
        logger.info("=" * 50)
        task_to_execute = task
        logger.info(f"Executing SUBTASK: {task_to_execute}")

        print("obs keys: ", self.obs.keys())
        screenshot = self.obs.get("screenshot", "") if self.obs else ""
        logger.info(f"Screenshot available: {'Yes' if screenshot else 'No'}")

        logger.info("Executing code agent...")

        result = self.coder_agent.execute(
            task_to_execute, screenshot, self.env.controller
        )

        # Store the result for the worker to access
        self.last_code_agent_result = result

        logger.info("Code agent execution completed")
        logger.info(f"Result - Completion reason: {result['completion_reason']}")
        logger.info(f"Steps executed: {result['steps_executed']}")
        logger.info(f"Summary: {result['summary']}")

        logger.info("=" * 50)
        logger.info("ACI: Code Agent Call Finished")
        logger.info("=" * 50)

        action = {"function": "call_code_agent", "args": {"query": task, "result": result["completion_reason"] == "DONE"}}
        # Return code to be executed in the environment
        return ("import time; time.sleep(2.222)", action)

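    # Illustration of the decomposition rule above (hypothetical task strings; `aci`
    # stands in for a configured OSACI instance):
    #
    #   # Bad: two distinct objectives fused into one task string.
    #   # aci.call_code_agent("Filter sales.xlsx to Q3 rows AND email the result to Bob")
    #
    #   # Good: pass the first self-contained goal; handle the email in a later step.
    #   aci.call_code_agent("Filter sales.xlsx so that only rows from Q3 2024 remain")
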
    @agent_action
    def scroll(self, element_description: str, clicks: int, shift: bool = False):
        """Scroll the element in the specified direction

        Args:
            element_description:str, a very detailed description of which element to scroll in. This description should be at least a full sentence.
            clicks:int, the number of clicks to scroll; can be positive (up) or negative (down).
            shift:bool, whether to use shift+scroll for horizontal scrolling
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)
        action = {"function": "scroll", "args": {"x": x, "y": y, "clicks": clicks, "shift": shift}}
        if shift:
            return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.hscroll({clicks})", action)
        else:
            return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.vscroll({clicks})", action)

    @agent_action
    def hotkey(self, keys: List):
        """Press a hotkey combination (can press a single key as well)

        Args:
            keys:List, the keys to press in combination in a list format (e.g. ['ctrl', 'c'], ['enter'])
        """
        # add quotes around the keys
        keys = [f"'{key}'" for key in keys]
        keys_string = " ".join(keys)
        action = {"function": "key", "args": {"keys": keys_string}}
        return (f"import pyautogui; pyautogui.hotkey({', '.join(keys)});", action)

    @agent_action
    def hold_and_press(self, hold_keys: List, press_keys: List):
        """Hold a list of keys and press a list of keys

        Args:
            hold_keys:List, list of keys to hold
            press_keys:List, list of keys to press in a sequence
        """
        press_keys_str = "[" + ", ".join([f"'{key}'" for key in press_keys]) + "]"
        command = "import pyautogui; "
        for k in hold_keys:
            command += f"pyautogui.keyDown({repr(k)}); "
        command += f"pyautogui.press({press_keys_str}); "
        for k in hold_keys:
            command += f"pyautogui.keyUp({repr(k)}); "

        hold_keys_string = " ".join(hold_keys)
        press_keys_string = " ".join(press_keys)
        action = {"function": "key", "args": {"keys": hold_keys_string + ";" + press_keys_string}}
        return (command, action)

    @agent_action
    def wait(self, time: float):
        """Wait for a specified amount of time

        Args:
            time:float, the amount of time to wait in seconds
        """
        return (f"""import time; time.sleep({time});""", {"function": "wait", "args": {}})

    @agent_action
    def done(
        self,
    ):
        """
        End the current task with a success. Use this when you believe the entire task has been fully completed. You must ensure all visual information aligns with the user's true intent.
        """
        return ("""DONE""", {"function": "done", "args": {}})

    @agent_action
    def fail(self):
        """End the current task with a failure. Use this when you believe the entire task is impossible to complete."""
        return ("""FAIL""", {"function": "fail", "args": {}})

    @agent_action
    def call_search_agent(
        self,
        query: str,
    ):
        """
        Calls a specialized 'Searcher Agent' to find a detailed, step-by-step tutorial on the internet for a specific GUI action.

        Args:
            query:str, the search phrase or question for the tutorial. The formulation of this query is critical for success and must follow the guidelines below.

        **Query Formulation Guidelines:**

        Your query must be a well-defined question targeting a **single, specific action** within a **specific application**. To get the best results, adhere to these rules:

        1. **Start with "How to":** Your query must begin with the phrase "How to" to frame it as a request for instructions.
        2. **Include the Application Name:** Always specify the name of the software you are working in (e.g., "GIMP", "Google Chrome", "Libreoffice Writer").
        3. **Focus on a Single Intent:** The query should represent one clear goal. Do not combine multiple steps or tasks into one query.
        4. **Be Specific, Not Abstract:** Ask a concrete question. Avoid repeating the user's high-level or abstract instructions.
        5. **Decompose Complex Tasks:** If the user's overall instruction involves multiple actions (e.g., "download a file and then email it"), and you are stuck on one part, search *only for that specific part*.

        **Examples:**

        * **User's Overall Instruction:** "Please help me download my latest bank statement and then send it to my accountant."
        * **Correct Query (if stuck on downloading):** "How to download a bank statement from the Bank of America website?"
        * **Correct Query (if stuck on attaching a file):** "How to attach a file to an email in Gmail?"
        * **Incorrect Query:** "Download my bank statement and email it to my accountant" *(This query is too broad, contains multiple sub-tasks, and does not start with "How to".)*
        """
        logger.info("=" * 50)
        logger.info(f"ACI: Calling Search Agent(query={query})")
        logger.info("=" * 50)
        self.searcher_agent.result_dir = self.result_dir
        result = self.searcher_agent.search(query=query, main_obs=self.obs)
        self.last_search_agent_result = result
        if result["completion_reason"] == "DONE":
            self.tutorials.append(result["final_answer"])
        action = {"function": "call_search_agent", "args": {"query": query, "result": result["completion_reason"] == "DONE"}}
        return ("import time; time.sleep(2.222)", action)

@@ -0,0 +1,61 @@
import logging
import platform
from typing import Dict, List, Tuple
from mm_agents.os_symphony.agents.os_aci import OSACI
from mm_agents.os_symphony.agents.searcher_agent import VLMSearcherAgent
from mm_agents.os_symphony.agents.worker import Worker

logger = logging.getLogger("desktopenv.agent")

class OSSymphony:
    def __init__(
        self,
        engine_params_for_orchestrator: Dict,
        engine_params_for_memoryer: Dict,
        os_aci: OSACI,
        platform: str = platform.system().lower(),
        client_password: str = "",
        max_trajectory_length: int = 8,
        enable_reflection: bool = True,
    ):
        """
        Args:
            engine_params_for_orchestrator: Configuration parameters for the orchestrator agent
            engine_params_for_memoryer: Configuration parameters for the memory/reflection agent
            os_aci: Instance of the OSACI class for UI interaction
            platform: Operating system platform (darwin, linux, windows)
            client_password: sudo password of the client VM
            max_trajectory_length: Maximum number of image turns to keep
            enable_reflection: Creates a reflection agent to assist the worker agent
        """
        self.engine_params_for_orchestrator = engine_params_for_orchestrator
        self.engine_params_for_memoryer = engine_params_for_memoryer
        self.os_aci: OSACI = os_aci
        self.platform = platform
        self.client_password = client_password
        self.max_trajectory_length = max_trajectory_length
        self.enable_reflection = enable_reflection

    def reset(self, result_dir) -> None:
        """Reset agent state and initialize components"""
        # Reset the search time per task
        self.os_aci.result_dir = result_dir
        self.executor = Worker(
            engine_params_for_orchestrator=self.engine_params_for_orchestrator,
            engine_params_for_memoryer=self.engine_params_for_memoryer,
            os_aci=self.os_aci,
            platform=self.platform,
            client_password=self.client_password,
            max_trajectory_length=self.max_trajectory_length,
            enable_reflection=self.enable_reflection,
        )

    def predict(self, instruction: str, observation: Dict, is_last_step: bool) -> Tuple[Dict, List[str]]:
        executor_info, actions = self.executor.generate_next_action(
            instruction=instruction, obs=observation, is_last_step=is_last_step
        )

        # copy the executor info so the caller receives a plain dict
        info = dict(executor_info or {})

        return info, actions

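# Sketch of the expected driver loop (illustrative; `env` stands in for a
# DesktopEnv-like environment and the engine params are placeholders):
#
#   agent = OSSymphony(
#       engine_params_for_orchestrator={"engine_type": "openai", "model": "gpt-4o"},
#       engine_params_for_memoryer={"engine_type": "openai", "model": "gpt-4o"},
#       os_aci=os_aci,
#   )
#   agent.reset(result_dir="./results/task_0")
#   obs = env.reset(task_config=task_config)
#   for step in range(max_steps):
#       info, actions = agent.predict(instruction, obs, is_last_step=(step == max_steps - 1))
#       for action in actions:
#           obs, reward, done, step_info = env.step(action)
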
@@ -0,0 +1,478 @@
import logging
import urllib.parse
from typing import Any, Dict, List, Optional
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import (
    draw_coordinates,
    call_llm_formatted,
    parse_code_from_string,
    create_pyautogui_code
)
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.agents.grounder_agent import GrounderAgent
import os
import time
import json


logger = logging.getLogger("desktopenv.searcher_agent")

# Agent action decorator
def searcher_agent_action(func):
    func.is_searcher_agent_action = True
    return func


# --- Abstract Base Class and Factory ---
class SearcherAgent:
    def __init__(self, engine_params: Dict, platform: str):
        self.engine_params = engine_params
        self.result_dir = ""
        self.tutorial_or_hint = ""
        self.tutorial_notes = []
        self.max_trajectory_length = 8
        self.platform = platform
        self.budget = engine_params.get("budget", 20)

    @staticmethod
    def create(engine_params: Dict, search_env, grounder_agent: GrounderAgent, platform: str, client_password: str = "password"):
        searcher_type = engine_params.get("type", "vlm")
        if searcher_type == "vlm":
            return VLMSearcherAgent(engine_params=engine_params, search_env=search_env, grounder_agent=grounder_agent, platform=platform, client_password=client_password)
        else:
            raise NotImplementedError

    def _get_search_time(self) -> int:
        """Return the next search index, used to name the result directory"""
        if not self.result_dir: return 1
        search_times: list[int] = []
        try:
            if not os.path.exists(self.result_dir): return 1
            for item_name in os.listdir(self.result_dir):
                full_path = os.path.join(self.result_dir, item_name)
                if os.path.isdir(full_path) and item_name.startswith("search_"):
                    try:
                        time_val = int(item_name.split('_', 1)[1])
                        search_times.append(time_val)
                    except (ValueError, IndexError):
                        continue
        except Exception:
            return 1
        if not search_times: return 1
        return max(search_times) + 1

    def search(self, query: str, obs) -> str:
        """
        Args:
            query: Format like "How to xxxx?", must be a detailed subtask
            obs: Current screenshot
        """
        raise NotImplementedError("Subclasses must implement the 'search' method")

class VLMSearcherAgent(SearcherAgent):
    """
    Start a new, isolated VM, and open Chrome in advance
    """
    def __init__(self, engine_params: Dict, search_env, grounder_agent: GrounderAgent, platform: str, client_password: str):
        SearcherAgent.__init__(self, engine_params=engine_params, platform=platform)

        self.grounder_agent = grounder_agent
        self.client_password = client_password
        self.env = search_env

        self.use_thinking = engine_params.get("model", "") in [
            "claude-opus-4-20250514",
            "claude-sonnet-4-20250514",
            "claude-3-7-sonnet-20250219",
            "claude-sonnet-4-5-20250929",
        ]

        self.engine = engine_params.get("engine", "google")

        # Reuse OSWorld's initialization script to set up Chrome, then directly perform a web
        # search using the query; the "GOOGLE_SEARCH_URL" entry below is a placeholder that is
        # substituted with the real search URL in reset().
        self.task_config = {
            "id": "searcher",
            "instruction": "searcher",
            "config": [
                {
                    "type": "launch",
                    "parameters": {
                        "command": [
                            "google-chrome",
                            "--remote-debugging-port=1337"
                        ]
                    }
                },
                {
                    "type": "launch",
                    "parameters": {
                        "command": [
                            "socat",
                            "tcp-listen:9222,fork",
                            "tcp:localhost:1337"
                        ]
                    }
                },
                {
                    "type": "chrome_open_tabs",
                    "parameters": {
                        "urls_to_open": [
                            "GOOGLE_SEARCH_URL"
                        ]
                    }
                },
                {
                    "type": "activate_window",
                    "parameters": {
                        "window_name": "Google Chrome"
                    }
                }
            ],
            "proxy": True
        }
        self.obs = None

    def reset(self, query):
        # When the search function is invoked, a new agent is created; the environment is
        # instantiated only upon the first call, but it must be reset on every invocation.
        self.tutorial_notes = []
        self.tutorial_or_hint = ""
        self.system_prompt = PROCEDURAL_MEMORY.construct_vlm_searcher_procedural_memory(
            agent_class=type(self)
        ).replace("CURRENT_OS", self.platform).replace("QUERY", query)
        self.searcher_agent = LMMAgent(
            engine_params=self.engine_params,
            system_prompt=self.system_prompt
        )
        self.env.start()
        # Configure the URL and initialize the search environment (google/duckduckgo)
        search_url = "https://www.google.com/search?q=" + urllib.parse.quote_plus(query) if self.engine == "google" else "https://www.duckduckgo.com/?q=" + urllib.parse.quote_plus(query)
        self.task_config["config"][2]["parameters"]["urls_to_open"][0] = search_url

        self.env.reset(task_config=self.task_config)
        print("[Searcher] sleeping...")
        time.sleep(5)

    def flush_messages(self):
        """Flush messages based on the model's context limits.

        This method ensures that the agent's message history does not exceed the maximum trajectory length.

        Side Effects:
            - Modifies the searcher agent's messages to fit within the context limits.
        """
        engine_type = self.engine_params.get("engine_type", "")

        # Flush strategy for long-context models: keep all text, only keep latest images
        if engine_type in ["anthropic", "openai", "gemini"]:
            max_images = self.max_trajectory_length
            for agent in [self.searcher_agent]:
                if agent is None:
                    continue
                # keep the latest k images
                # @Yang: keep the first main agent image
                img_count = 0
                for i in range(len(agent.messages) - 1, 1, -1):
                    for j in range(len(agent.messages[i]["content"]) - 1, -1, -1):
                        if "image" in agent.messages[i]["content"][j].get("type", ""):
                            img_count += 1
                            if img_count > max_images:
                                del agent.messages[i]["content"][j]

        # Flush strategy for non-long-context models: drop full turns
        else:
            # searcher msgs are alternating [user, assistant], so 2 per round
            if len(self.searcher_agent.messages) > 2 * self.max_trajectory_length + 1:
                self.searcher_agent.messages.pop(1)
                self.searcher_agent.messages.pop(1)

    def assign_screenshot(self, obs):
        self.obs = obs

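    # Quick illustration of the URL construction in reset() (the query string is
    # illustrative):
    #
    #   import urllib.parse
    #   query = "How to attach a file to an email in Gmail?"
    #   "https://www.google.com/search?q=" + urllib.parse.quote_plus(query)
    #   # -> 'https://www.google.com/search?q=How+to+attach+a+file+to+an+email+in+Gmail%3F'
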
    def search(self, query: str, main_obs):
        # only create the VM when search is called
        self.reset(query=query)
        search_result_dir = os.path.join(self.result_dir, f"search_{self._get_search_time()}")
        os.makedirs(search_result_dir, exist_ok=True)

        obs = self.env._get_obs()  # Get the initial observation
        step_idx = 0
        initial_state_text = (
            "This screenshot shows the current visual context of the main GUI Agent you are assisting. "
            "Use this image to understand the application, the current view, and the overall environment. "
            "Your primary goal is to find a tutorial that is highly relevant and well-aligned with this specific context, "
            "ensuring the instructions you find are applicable to what the main agent is currently seeing."
        )
        self.searcher_agent.add_message(
            text_content=initial_state_text,
            image_content=main_obs["screenshot"],
            role="user"
        )
        execution_history = []
        completion_reason = ""
        final_answer = ""

        while step_idx < self.budget:
            # update system_prompt dynamically
            tutorial_notes_str = ""
            if len(self.tutorial_notes) > 0:
                for i, note in enumerate(self.tutorial_notes, 1):
                    tutorial_notes_str += f"Tutorial Note {i}: {note}\n\n"

            if step_idx == self.budget - 1:
                # eager mode
                self.system_prompt = PROCEDURAL_MEMORY.construct_searcher_eager_mode_procedural_memory(
                    agent_class=type(self)
                ).replace("CURRENT_OS", self.platform).replace("QUERY", query)

            system_prompt = self.system_prompt.replace("TUTORIAL_PLACEHOLDER", tutorial_notes_str)
            self.searcher_agent.add_system_prompt(system_prompt=system_prompt)

            # start a new turn
            self.assign_screenshot(obs=obs)
            generator_message = ""

            self.searcher_agent.add_message(
                generator_message, image_content=obs["screenshot"], role="user"
            )
            format_checkers = []

            # predict action
            plan = call_llm_formatted(
                self.searcher_agent,
                format_checkers,
                temperature=self.engine_params.get("temperature", 0.1),
                use_thinking=self.use_thinking,
            )

            self.searcher_agent.add_message(plan, role="assistant")
            execution_history.append(plan)
            logger.info("SEARCHER PLAN:\n %s", plan)

            plan_code = parse_code_from_string(plan)
            coords = None
            try:
                assert plan_code, "Plan code should not be empty"
                # exec_code e.g. import pyautogui; pyautogui.click(1, 2);
                exec_code, coords = create_pyautogui_code(self, plan_code, obs)
            except Exception as e:
                logger.error(
                    f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}"
                )
                exec_code = self.wait(
                    1.333
                )  # Skip a turn if the code cannot be evaluated

            self.flush_messages()

            # execute action
            action = exec_code
            logger.info("Step %d: %s", step_idx + 1, action)

            # Save screenshot and trajectory information
            with open(os.path.join(search_result_dir, f"step_{step_idx + 1}.png"),
                      "wb") as _f:
                _f.write(obs['screenshot'])

            if coords is not None and isinstance(coords, list):
                draw_coordinates(
                    image_bytes=obs['screenshot'],
                    coordinates=coords,
                    save_path=os.path.join(search_result_dir, f"step_{step_idx + 1}_draw.png")
                )

            with open(os.path.join(search_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
                f.write(json.dumps({
                    "query": query,
                    "step_num": step_idx + 1,
                    "action": action,
                    "response": {
                        "plan": plan,
                        "plan_code": plan_code,
                        "coordinates": coords
                    },
                    "screenshot_file": f"step_{step_idx + 1}.png"
                }, ensure_ascii=False))
                f.write("\n")

            with open(os.path.join(search_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
                json.dump({
                    "query": query,
                    "step_num": step_idx + 1,
                    "action": action,
                    "response": {
                        "plan": plan,
                        "plan_code": plan_code,
                        "coordinates": coords
                    },
                    "screenshot_file": f"step_{step_idx + 1}.png"
                }, f, indent=4, ensure_ascii=False)

            if exec_code in ["DONE", "FAIL"]:
                # terminate loop
                completion_reason = exec_code
                final_answer = self.tutorial_or_hint
                break
            else:
                obs, _, _, _ = self.env.step(action, 5)

            step_idx += 1

        if completion_reason == "":
            completion_reason = "BUDGET_EXHAUSTED"
            final_answer = "Sorry, couldn't find a useful tutorial for the GUI task you provided."

        return {
            "query": query,
            "completion_reason": completion_reason,
            "tutorial_notes": self.tutorial_notes,
            "execution_history": execution_history,
            "steps_executed": step_idx,
            "budget": self.budget,
            "final_answer": final_answer,
        }

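    # For reference, one line of the traj.jsonl written above looks roughly like this
    # (values are illustrative):
    #
    #   {"query": "How to attach a file to an email in Gmail?", "step_num": 3,
    #    "action": "import pyautogui; pyautogui.click(512, 384, clicks=1, button='left'); ",
    #    "response": {"plan": "...", "plan_code": "agent.click('the paperclip Attach files icon', 1)",
    #                 "coordinates": [512, 384]},
    #    "screenshot_file": "step_3.png"}
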
    @searcher_agent_action
    def click(
        self,
        element_description: str,
        num_clicks: int = 1,
        button_type: str = "left",
    ):
        """Click on the element

        Args:
            element_description:str, a detailed description of which element to click on. This description should be at least a full sentence.
            num_clicks:int, number of times to click the element
            button_type:str, which mouse button to press; can be "left", "middle", or "right"
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)
        command = "import pyautogui; "
        command += f"pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); "

        # Return pyautogui code to click on the element
        return (command, [x, y])

    @searcher_agent_action
    def type(
        self,
        element_description: Optional[str] = None,
        text: str = "",
        overwrite: bool = True,
        enter: bool = False
    ):
        """Type text/unicode into a specific element

        Args:
            element_description:str, a detailed description of which element to enter text in. This description should be at least a full sentence.
            text:str, the text to type
            overwrite:bool, Default is True; assign it to False if the text should not overwrite the existing text. Using this argument clears all text in an element.
            enter:bool, Assign it to True if the enter key should be pressed after typing the text, otherwise assign it to False.
        """
        commands = (
            "import os;"
            "import pyautogui;"
            "import pyperclip;"
            "import subprocess;"
            "import time;"
            "p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
            "p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
            "proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
            f"subprocess.run(f'echo \"{self.client_password}\" | sudo -S {{proxy_prefix}}apt-get install -y xclip xsel', shell=True, check=True);"
        )

        click_coords = None
        if element_description is not None:
            x, y = self.grounder_agent.generate_coords(element_description, self.obs)
            click_coords = [x, y]

            commands += f"pyautogui.click({x}, {y});"

        if overwrite:
            commands += (
                "pyautogui.hotkey('ctrl', 'a');"
                "pyautogui.press('backspace');"
            )

        # use paste to input
        commands += (
            "original_clipboard = pyperclip.paste();"
            f"pyperclip.copy({repr(text)});"
            "pyautogui.hotkey('ctrl', 'v');"
            "pyperclip.copy(original_clipboard);"
        )

        if enter:
            commands += "pyautogui.press('enter');"

        if click_coords is not None:
            return (commands, click_coords)
        else:
            return commands

    @searcher_agent_action
    def scroll(self, element_description: str, clicks: int, shift: bool = False):
        """Scroll the element in the specified direction

        Args:
            element_description:str, a very detailed description of which element to scroll in. This description should be at least a full sentence.
            clicks:int, the number of clicks to scroll; can be positive (up) or negative (down).
            shift:bool, whether to use shift+scroll for horizontal scrolling
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)

        if shift:
            return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.hscroll({clicks})", [x, y])
        else:
            return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.vscroll({clicks})", [x, y])

    @searcher_agent_action
    def hotkey(self, keys: List):
        """Press a hotkey combination (can press a single key as well)

        Args:
            keys: List, the keys to press in combination in a list format (e.g. ['ctrl', 'c'], ['enter'])
        """
        # add quotes around the keys
        keys = [f"'{key}'" for key in keys]
        return f"import pyautogui; pyautogui.hotkey({', '.join(keys)})"

    @searcher_agent_action
    def save_to_tutorial_notes(self, text: str):
        """Save high quality and useful information to a long-term knowledge bank for reuse during this search task.

        Args:
            text:str, the text to save to the tutorial notes
        """
        self.tutorial_notes.append(text)
        return """WAIT"""

    @searcher_agent_action
    def wait(self, time: float):
        """Wait for a specified amount of time

        Args:
            time:float, the amount of time to wait in seconds
        """
        return f"""import time; time.sleep({time})"""

    @searcher_agent_action
    def done(
        self,
        tutorial: str
    ):
        """End the current task with a success. Use this when you believe the entire task has been fully completed.

        Args:
            tutorial:str, A detailed, step-by-step tutorial compiled from the search results to be passed to the main agent.
        """
        self.tutorial_or_hint = tutorial
        return """DONE"""

    @searcher_agent_action
    def fail(
        self,
        hint: str
    ):
        """End the current task with a failure. Use this when you believe the entire task is impossible to complete.

        Args:
            hint:str, A hint or reason explaining why the search failed, or what kind of information was missing.
        """
        self.tutorial_or_hint = hint
        return """FAIL"""

@@ -0,0 +1,340 @@
from functools import partial
import logging
from typing import Dict, List, Tuple

from mm_agents.os_symphony.agents.memoryer_agent import ReflectionMemoryAgent
from mm_agents.os_symphony.agents.os_aci import OSACI
from mm_agents.os_symphony.core.module import BaseModule
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import (
    call_llm_formatted,
    extract_coords_from_action_dict,
    parse_action_from_string,
    parse_code_from_string,
    create_pyautogui_code,
)
from mm_agents.os_symphony.utils.formatters import (
    SINGLE_ACTION_FORMATTER,
    CODE_VALID_FORMATTER,
)


logger = logging.getLogger("desktopenv.agent")


class Worker(BaseModule):
    def __init__(
        self,
        engine_params_for_orchestrator: Dict,
        engine_params_for_memoryer: Dict,
        os_aci: OSACI,
        platform: str,
        client_password: str,
        max_trajectory_length: int = 8,
        enable_reflection: bool = True,
    ):
        """
        Worker receives the main task and generates actions, without the need of hierarchical planning

        Args:
            engine_params_for_orchestrator: Dict
                Parameters for the orchestrator agent
            engine_params_for_memoryer: Dict
                Parameters for the memory/reflection agent
            os_aci: OSACI
                The grounding agent to use
            platform: str
                OS platform the agent runs on (darwin, linux, windows)
            client_password: str
                sudo password of the client VM
            max_trajectory_length: int
                The number of image turns to keep
            enable_reflection: bool
                Whether to enable reflection
        """
        super().__init__(platform=platform)
        self.client_password = client_password

        self.temperature = engine_params_for_orchestrator.get("temperature", 0.0)
        self.tool_config = engine_params_for_orchestrator.get("tool_config", "")
        self.use_thinking = engine_params_for_orchestrator.get("model", "") in [
            "claude-opus-4-20250514",
            "claude-sonnet-4-20250514",
            "claude-3-7-sonnet-20250219",
            "claude-sonnet-4-5-20250929",
        ]
        self.engine_params_for_orchestrator = engine_params_for_orchestrator
        self.engine_params_for_memoryer = engine_params_for_memoryer
        self.os_aci: OSACI = os_aci

        self.max_trajectory_length = max_trajectory_length if not self.engine_params_for_orchestrator.get("keep_first_image", False) else max_trajectory_length - 1
        self.enable_reflection = enable_reflection
        self.reset()

    def reset(self):
        # set_cell_values only occurs on linux; meanwhile there is no fail option in the other benchmarks
        if self.platform in ["windows", "macos"]:
            skipped_actions = ["set_cell_values", "fail"]
        else:
            skipped_actions = []

        # Hide the code agent action entirely if no env/controller is available
        if not getattr(self.os_aci, "env", None) or not getattr(
            getattr(self.os_aci, "env", None), "controller", None
        ):
            skipped_actions.append("call_code_agent")

        self.orchestrator_sys_prompt = PROCEDURAL_MEMORY.construct_simple_worker_procedural_memory(
            agent_class=type(self.os_aci),
            skipped_actions=skipped_actions,
            tool_config=self.tool_config,
            platform=self.platform
        ).replace("CURRENT_OS", self.platform).replace("CLIENT_PASSWORD", self.client_password)

        # Worker contains the orchestrator and reflection agents
        self.orchestrator_agent = self._create_agent(
            engine_params=self.engine_params_for_orchestrator,
            system_prompt=self.orchestrator_sys_prompt
        )
        self.memoryer_agent = ReflectionMemoryAgent(self.engine_params_for_memoryer)

        self.instruction = None
        self.turn_count = 0
        self.worker_history = []
        self.coords_history = []

        # For loop detection
        self.action_dict_history = []

    def flush_messages(self):
        """Flush messages based on the model's context limits.

        This method ensures that the agent's message history does not exceed the maximum trajectory length.

        Side Effects:
            - Modifies the orchestrator agent's messages to fit within the context limits.
        """
        engine_type = self.engine_params_for_orchestrator.get("engine_type", "")

        # Flush strategy for long-context models: keep all text, only keep latest images
        if engine_type in ["anthropic", "openai", "gemini", "vllm"]:
            max_images = self.max_trajectory_length
            # for agent in [self.generator_agent, self.reflection_agent]:
            for agent in [self.orchestrator_agent]:
                if agent is None:
                    continue
                # keep the latest k images
                img_count = 0
                stop_idx = 1 if self.engine_params_for_orchestrator.get("keep_first_image", False) else -1
                for i in range(len(agent.messages) - 1, stop_idx, -1):
                    for j in range(len(agent.messages[i]["content"]) - 1, -1, -1):
                        if "image" in agent.messages[i]["content"][j].get("type", ""):
                            img_count += 1
                            if img_count > max_images:
                                del agent.messages[i]["content"][j]

        # Flush strategy for non-long-context models: drop full turns
        else:
            # orchestrator msgs are alternating [user, assistant], so 2 per round
            if len(self.orchestrator_agent.messages) > 2 * self.max_trajectory_length + 1:
                self.orchestrator_agent.messages.pop(1)
                self.orchestrator_agent.messages.pop(1)

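    # Toy illustration of the image-pruning strategy above (OpenAI-style content lists;
    # the structure is illustrative):
    #
    #   messages = [
    #       {"role": "user", "content": [{"type": "text", "text": "t1"}, {"type": "image_url"}]},
    #       {"role": "user", "content": [{"type": "text", "text": "t2"}, {"type": "image_url"}]},
    #       {"role": "user", "content": [{"type": "text", "text": "t3"}, {"type": "image_url"}]},
    #   ]
    #   # With max_images = 2, walking from the newest message backwards counts the two
    #   # most recent images and deletes every older one -- here the image in the first
    #   # message -- while all text parts survive.
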
    def generate_next_action(self, instruction: str, obs: Dict, is_last_step: bool) -> Tuple[Dict, List]:
        """
        Predict the next action(s) based on the current observation.
        """
        print("=" * 30, f"Turn {self.turn_count + 1}", "=" * 30)

        print("=" * 10)
        print(instruction)
        print("=" * 10)

        self.os_aci.assign_screenshot(obs)
        self.os_aci.set_task_instruction(instruction)

        generator_message = (
            ""
            if self.turn_count > 0
            else "The initial screen is provided. No action has been taken yet."
        )

        # Load the task into the system prompt
        if is_last_step:
            # Eager mode: must decide done / fail
            prompt_with_instructions = PROCEDURAL_MEMORY.construct_eager_mode_procedural_memory(agent_class=type(self.os_aci)).replace(
                "TASK_DESCRIPTION", instruction
            ).replace(
                "CURRENT_OS", self.platform
            )
            print(f'Eager Mode Started, Instruction: {prompt_with_instructions}')
            self.orchestrator_agent.add_system_prompt(prompt_with_instructions)
            generator_message += "Note: 'EAGER MODE' is enabled. You must determine whether the task is done or failed in this step!!!"
        else:
            tutorials = ""
            for idx, t in enumerate(self.os_aci.tutorials, start=1):
                tutorials += f"### Tutorial {idx}:\n {t}\n"

            prompt_with_instructions = self.orchestrator_sys_prompt.replace(
                "TASK_DESCRIPTION", instruction
            ).replace(
                "TUTORIAL_PLACEHOLDER", tutorials
            )

            self.orchestrator_agent.add_system_prompt(prompt_with_instructions)

        # print(self.orchestrator_agent.system_prompt)

        ### Reflection Part
        reflection_info = {}
        if self.enable_reflection:
            # set instruction to memory agent
            self.memoryer_agent.add_instruction(instruction)
            reflection = None
            # Differentiate the operation mode of the last step
            last_code_summary = ""
            mode = "gui"
            if (
                hasattr(self.os_aci, "last_code_agent_result")
                and self.os_aci.last_code_agent_result is not None
            ):
                # If the code agent was called last step, we use its execution result as the step behavior.
                code_result = self.os_aci.last_code_agent_result
                mode = "code"
                last_code_summary += f"Subtask Instruction: {code_result['task_instruction']}\nSteps Completed: {code_result['steps_executed']}\nCompletion Reason: {code_result['completion_reason']}\nExec Summary: {code_result['summary']}\n"

            if (
                hasattr(self.os_aci, "last_search_agent_result")
                and self.os_aci.last_search_agent_result is not None
            ):
                mode = "search"
            # retrieve the reflection
            reflection_info = self.memoryer_agent.get_reflection(
                cur_obs=obs,
                # only use the string after "(next action)" in the orchestrator's output
                generator_output=parse_action_from_string(self.worker_history[-1]) if self.turn_count != 0 else "",
                coordinates=self.coords_history[-1] if self.turn_count != 0 else [],
                mode=mode,
                code_exec_summary=last_code_summary,
                action_dict=self.action_dict_history[-1] if self.turn_count != 0 else {}
            )
            reflection = reflection_info['reflection']
            logger.info(f'[Reflection]: {reflection}')
            if reflection:
                generator_message += f"REFLECTION: You MUST use this reflection on the latest action:\n{reflection}\n"
            else:
                generator_message += "You should go on with your plan.\n"
        else:
            generator_message += "You should go on with your plan.\n"

        # Add the code agent result from the previous step if available (from full task or subtask execution)
        if (
            hasattr(self.os_aci, "last_code_agent_result")
            and self.os_aci.last_code_agent_result is not None
        ):
            code_result = self.os_aci.last_code_agent_result
            generator_message += "\nCODE AGENT RESULT:\n"
            generator_message += (
                f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
            )
            generator_message += f"Steps Completed: {code_result['steps_executed']}\n"
            generator_message += f"Max Steps: {code_result['budget']}\n"
            generator_message += (
                f"Completion Reason: {code_result['completion_reason']}\n"
            )
            generator_message += f"Summary: {code_result['summary']}\n"
            generator_message += "\n"
            # Reset the code agent result after adding it to context
            self.os_aci.last_code_agent_result = None

        if (
            hasattr(self.os_aci, "last_search_agent_result")
            and self.os_aci.last_search_agent_result is not None
        ):
            # Retrieve the result dictionary
            search_result = self.os_aci.last_search_agent_result

            # Add a clear, distinct header for this section in the prompt
            generator_message += "\nSEARCH AGENT RESULT:\n"

            # Add contextual metadata from the search task
            generator_message += f"Search Query: {search_result['query']}\n"
            generator_message += f"Search Completion Reason: {search_result['completion_reason']}\n"
            generator_message += "Search Result: "
            # Add the most important part: the tutorial found by the agent.
            # This is given a prominent sub-header so the LLM knows to pay close attention.
            if search_result["completion_reason"] == "DONE":
                generator_message += 'Search completed; the tutorial it found has already been added to your system prompt.\n'
            elif search_result["completion_reason"] == "FAIL":
                generator_message += f"Search failed; the failure reason or hint is as follows: {search_result['final_answer']}\n"

            # CRITICAL: Reset the search agent result after adding it to the context.
            # This prevents it from being added to the prompt again in the next turn.
            self.os_aci.last_search_agent_result = None

        # Finalize the generator message
        self.orchestrator_agent.add_message(
            generator_message, image_content=obs["screenshot"], role="user", put_text_last=True
        )

        # Generate the plan and next action
        format_checkers = [
            SINGLE_ACTION_FORMATTER,
            partial(CODE_VALID_FORMATTER, self.tool_config),
        ]
        plan = call_llm_formatted(
            self.orchestrator_agent,
            format_checkers,
            temperature=self.engine_params_for_orchestrator.get("temperature", 0.1),
            use_thinking=self.use_thinking,
        )
        self.worker_history.append(plan)
        self.orchestrator_agent.add_message(plan, role="assistant")
        logger.info("PLAN:\n %s", plan)

        # Extract the next action from the plan
        # At this point the plan code looks like e.g. agent.click('xxxxx', 1)
        plan_code = parse_code_from_string(plan)
        action_dict, coordinates = None, None
        try:
            assert plan_code, "Plan code should not be empty"
            # exec_code e.g. import pyautogui; pyautogui.click(1, 2);
            exec_code, action_dict = create_pyautogui_code(self.os_aci, plan_code, obs)
            coordinates = extract_coords_from_action_dict(action_dict)
        except Exception as e:
            logger.error(
                f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}"
            )
            exec_code, action_dict = self.os_aci.wait(
                1.333
            )  # Skip a turn if the code cannot be evaluated

        self.action_dict_history.append(action_dict)

        executor_info = {
            "refined_instruction": self.instruction,
            "plan": plan,
            "plan_code": plan_code,
            "exec_code": exec_code,
            "coordinates": coordinates,
            "reflection": reflection_info,
            "code_agent_output": (
                self.os_aci.last_code_agent_result
                if hasattr(self.os_aci, "last_code_agent_result")
                and self.os_aci.last_code_agent_result is not None
                else None
            ),
            "search_agent_output": (
                self.os_aci.last_search_agent_result
                if hasattr(self.os_aci, "last_search_agent_result")
                and self.os_aci.last_search_agent_result is not None
                else None
            )
        }
        self.turn_count += 1
        self.coords_history.append(coordinates)
        self.flush_messages()
        return executor_info, [exec_code]

@@ -0,0 +1,480 @@
import copy
import json
import logging
import os
import base64
import backoff
from anthropic import Anthropic
from openai import (
    AzureOpenAI,
    APIConnectionError,
    APIError,
    OpenAI,
    RateLimitError,
)

logger = logging.getLogger("desktopenv.agents.engine")


class LMMEngine:
    pass


class LMMEngineOpenAI(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        model=None,
        rate_limit=-1,
        temperature=None,
        organization=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.organization = organization
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature  # Can force temperature to be the same (in the case of o3 requiring temperature to be 1)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY"
            )
        organization = self.organization or os.getenv("OPENAI_ORG_ID")

        # H-cluster authentication -- remove this before release!!!
        if self.model.lower().startswith("ui") or self.model.lower().startswith("qwen") or self.model.lower().startswith("scale") or self.model.lower().startswith("holo"):
            custom_headers = {
                "Authorization": "Basic NWFkMzQxMDBlZTA1NWE0YmFlNjYzNzBhNWU2ODNiYWM6NjA3ZGU4MjQ5NjU3YTNiM2JkMDM2ZGM5NmQ0YzBiMmY="
            }
        else:
            custom_headers = {}
        if not self.llm_client:
            if not self.base_url:
                self.llm_client = OpenAI(
                    api_key=api_key,
                    organization=organization,
                    default_headers=custom_headers
                )
            else:
                self.llm_client = OpenAI(
                    base_url=self.base_url,
                    api_key=api_key,
                    organization=organization,
                    default_headers=custom_headers
                )

        payload_size = len(json.dumps(messages)) / 1024 / 1024
        logger.info(f"Payload size: {payload_size:.2f} MB")
        if payload_size > 30:
            logger.info("Payload size exceeds 30MB!!!")

        result = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            # max_completion_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=(
                temperature if self.temperature is None else self.temperature
            ),
            **kwargs,
        )
        usage = result.usage
        response = result.choices[0].message.content
        return (response, usage)

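# Minimal usage sketch (assumes OPENAI_API_KEY is exported; the model name is a
# placeholder):
#
#   engine = LMMEngineOpenAI(model="gpt-4o")
#   response, usage = engine.generate(
#       messages=[
#           {"role": "system", "content": [{"type": "text", "text": "You are a GUI agent."}]},
#           {"role": "user", "content": [{"type": "text", "text": "Say READY."}]},
#       ],
#       temperature=0.0,
#   )
#   # `response` is the text completion, `usage` the token-accounting object.
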
class LMMEngineAnthropic(LMMEngine):
|
||||
def __init__(
|
||||
self,
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
model=None,
|
||||
thinking=False,
|
||||
temperature=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert model is not None, "model must be provided"
|
||||
self.model = model
|
||||
self.thinking = thinking
|
||||
self.api_key = api_key
|
||||
self.llm_client = None
|
||||
self.temperature = temperature
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
|
||||
api_key = self.api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
|
||||
)
|
||||
self.llm_client = Anthropic(api_key=api_key)
|
||||
# Use the instance temperature if not specified in the call
|
||||
temp = self.temperature if temperature is None else temperature
|
||||
if self.thinking:
|
||||
full_response = self.llm_client.messages.create(
|
||||
system=messages[0]["content"][0]["text"],
|
||||
model=self.model,
|
||||
messages=messages[1:],
|
||||
max_tokens=8192,
|
||||
thinking={"type": "enabled", "budget_tokens": 4096},
|
||||
**kwargs,
|
||||
)
|
||||
thoughts = full_response.content[0].thinking
|
||||
return full_response.content[1].text
|
||||
return (
|
||||
self.llm_client.messages.create(
|
||||
system=messages[0]["content"][0]["text"],
|
||||
model=self.model,
|
||||
messages=messages[1:],
|
||||
max_tokens=max_new_tokens if max_new_tokens else 4096,
|
||||
temperature=temp,
|
||||
**kwargs,
|
||||
)
|
||||
.content[0]
|
||||
.text
|
||||
)
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
|
||||
)
|
||||
# Compatible with Claude-3.7 Sonnet thinking mode
|
||||
def generate_with_thinking(
|
||||
self, messages, temperature=0.0, max_new_tokens=None, **kwargs
|
||||
):
|
||||
"""Generate the next message based on previous messages, and keeps the thinking tokens"""
|
||||
api_key = self.api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
|
||||
)
|
||||
self.llm_client = Anthropic(api_key=api_key)
|
||||
full_response = self.llm_client.messages.create(
|
||||
system=messages[0]["content"][0]["text"],
|
||||
model=self.model,
|
||||
messages=messages[1:],
|
||||
max_tokens=8192,
|
||||
thinking={"type": "enabled", "budget_tokens": 4096},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
thoughts = full_response.content[0].thinking
|
||||
answer = full_response.content[1].text
|
||||
full_response = (
|
||||
f"<thoughts>\n{thoughts}\n</thoughts>\n\n<answer>\n{answer}\n</answer>\n"
|
||||
)
|
||||
return full_response


class LMMEngineGemini(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        model=None,
        rate_limit=-1,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("GEMINI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named GEMINI_API_KEY"
            )
        base_url = self.base_url or os.getenv("GEMINI_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named GEMINI_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        # Use the temperature passed to generate; fall back to the instance temperature if it is None
        temp = self.temperature if temperature is None else temperature
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temp,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


class LMMEngineOpenRouter(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        model=None,
        rate_limit=-1,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("OPENROUTER_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENROUTER_API_KEY"
            )
        base_url = self.base_url or os.getenv("OPEN_ROUTER_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named OPEN_ROUTER_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temp,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


class LMMEngineAzureOpenAI(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        azure_endpoint=None,
        model=None,
        api_version=None,
        rate_limit=-1,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.api_version = api_version
        self.api_key = api_key
        self.azure_endpoint = azure_endpoint
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.cost = 0.0
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named AZURE_OPENAI_API_KEY"
            )
        api_version = self.api_version or os.getenv("OPENAI_API_VERSION")
        if api_version is None:
            raise ValueError(
                "api_version must be provided either as a parameter or as an environment variable named OPENAI_API_VERSION"
            )
        azure_endpoint = self.azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if azure_endpoint is None:
            raise ValueError(
                "An Azure API endpoint needs to be provided in either the azure_endpoint parameter or as an environment variable named AZURE_OPENAI_ENDPOINT"
            )
        if not self.llm_client:
            self.llm_client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        completion = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temp,
            **kwargs,
        )
        total_tokens = completion.usage.total_tokens
        # Rough running cost estimate: $0.02 per 1K tokens, padded by 500 tokens
        self.cost += 0.02 * ((total_tokens + 500) / 1000)
        return completion.choices[0].message.content


class LMMEnginevLLM(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        model=None,
        rate_limit=-1,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(
        self,
        messages,
        temperature=0.0,
        top_p=0.8,
        repetition_penalty=1.05,
        max_new_tokens=4096,
        **kwargs,
    ):
        api_key = self.api_key or os.getenv("vLLM_API_KEY")
        if api_key is None:
            raise ValueError(
                "A vLLM API key needs to be provided in either the api_key parameter or as an environment variable named vLLM_API_KEY"
            )
        base_url = self.base_url or os.getenv("vLLM_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named vLLM_ENDPOINT_URL"
            )
        if not self.llm_client:
            # Hard-coded H-cluster basic-auth credentials -- remove before release!
            USERNAME = "5ad34100ee055a4bae66370a5e683bac"
            PASSWORD = "607de8249657a3b3bd036dc96d4c0b2f"
            auth_string = f"{USERNAME}:{PASSWORD}".encode("utf-8")
            basic_auth_encoded = base64.b64encode(auth_string).decode("utf-8")
            basic_auth_header = f"Basic {basic_auth_encoded}"
            self.llm_client = OpenAI(
                base_url=base_url,
                api_key=api_key,
                default_headers={"Authorization": basic_auth_header},
            )
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        completion = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temp,
            top_p=top_p,
            # vLLM-specific sampling parameter passed through the OpenAI-compatible API
            extra_body={"repetition_penalty": repetition_penalty},
        )

        usage = completion.usage
        response = completion.choices[0].message.content
        return (response, usage)


class LMMEngineHuggingFace(LMMEngine):
    def __init__(self, base_url=None, api_key=None, rate_limit=-1, **kwargs):
        self.base_url = base_url
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("HF_TOKEN")
        if api_key is None:
            raise ValueError(
                "A HuggingFace token needs to be provided in either the api_key parameter or as an environment variable named HF_TOKEN"
            )
        base_url = self.base_url or os.getenv("HF_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "HuggingFace endpoint must be provided as base_url parameter or as an environment variable named HF_ENDPOINT_URL."
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        return (
            self.llm_client.chat.completions.create(
                model="tgi",
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temperature,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


class LMMEngineParasail(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "Parasail model id must be provided"
        self.base_url = base_url
        self.model = model
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("PARASAIL_API_KEY")
        if api_key is None:
            raise ValueError(
                "A Parasail API key needs to be provided in either the api_key parameter or as an environment variable named PARASAIL_API_KEY"
            )
        # Fall back to the environment variable named in the error message below
        base_url = self.base_url or os.getenv("PARASAIL_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "Parasail endpoint must be provided as base_url parameter or as an environment variable named PARASAIL_ENDPOINT_URL"
            )
        if not self.llm_client:
            # The public endpoint is https://api.parasail.io/v1
            self.llm_client = OpenAI(
                base_url=base_url,
                api_key=api_key,
            )
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temperature,
                **kwargs,
            )
            .choices[0]
            .message.content
        )
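All of the engines above expose the same informal `generate(messages, temperature, max_new_tokens)` surface, so callers can swap providers by changing only the constructor. Note one asymmetry: the OpenAI and vLLM engines return a `(response, usage)` tuple while the others return the response text alone. A minimal usage sketch, assuming `OPENAI_API_KEY` is set in the environment (the model id here is an illustrative assumption):

```python
from mm_agents.os_symphony.core.engine import LMMEngineOpenAI

engine = LMMEngineOpenAI(model="gpt-4o")  # hypothetical model id
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "text", "text": "Say hello."}]},
]
# The OpenAI engine returns (text, usage); most other engines return text only.
response, usage = engine.generate(messages, temperature=0.0)
```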

@@ -0,0 +1,308 @@
import base64

import numpy as np

from mm_agents.os_symphony.core.engine import (
    LMMEngineAnthropic,
    LMMEngineAzureOpenAI,
    LMMEngineHuggingFace,
    LMMEngineOpenAI,
    LMMEngineOpenRouter,
    LMMEngineParasail,
    LMMEnginevLLM,
    LMMEngineGemini,
)


class LMMAgent:
    def __init__(self, engine_params: dict, system_prompt=None, engine=None):
        if engine is None:
            if engine_params is not None:
                engine_type = engine_params.get("engine_type")
                if engine_type == "openai":
                    self.engine = LMMEngineOpenAI(**engine_params)
                elif engine_type == "anthropic":
                    self.engine = LMMEngineAnthropic(**engine_params)
                elif engine_type == "azure":
                    self.engine = LMMEngineAzureOpenAI(**engine_params)
                elif engine_type == "vllm":
                    self.engine = LMMEnginevLLM(**engine_params)
                elif engine_type == "huggingface":
                    self.engine = LMMEngineHuggingFace(**engine_params)
                elif engine_type == "gemini":
                    self.engine = LMMEngineGemini(**engine_params)
                elif engine_type == "open_router":
                    self.engine = LMMEngineOpenRouter(**engine_params)
                elif engine_type == "parasail":
                    self.engine = LMMEngineParasail(**engine_params)
                else:
                    raise ValueError(f"engine_type '{engine_type}' is not supported")
            else:
                raise ValueError("engine_params must be provided")
        else:
            self.engine = engine

        self.messages = []  # Empty message history
        self.agent_name = engine_params.get("agent_name")
        if system_prompt:
            self.add_system_prompt(system_prompt)
        else:
            self.add_system_prompt("You are a helpful assistant.")

    def encode_image(self, image_content):
        # image_content may be a path to an image file or raw image bytes
        if isinstance(image_content, str):
            with open(image_content, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        else:
            return base64.b64encode(image_content).decode("utf-8")

    def reset(
        self,
    ):

        self.messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_prompt}],
            }
        ]

    def add_system_prompt(self, system_prompt):
        self.system_prompt = system_prompt
        if len(self.messages) > 0:
            self.messages[0] = {
                "role": "system",
                "content": [{"type": "text", "text": self.system_prompt}],
            }
        else:
            self.messages.append(
                {
                    "role": "system",
                    "content": [{"type": "text", "text": self.system_prompt}],
                }
            )

    def remove_message_at(self, index):
        """Remove a message at a given index"""
        if index < len(self.messages):
            self.messages.pop(index)

    def replace_message_at(
        self, index, text_content, image_content=None, image_detail="high"
    ):
        """Replace a message at a given index"""
        if index < len(self.messages):
            self.messages[index] = {
                "role": self.messages[index]["role"],
                "content": [{"type": "text", "text": text_content}],
            }
            if image_content:
                base64_image = self.encode_image(image_content)
                self.messages[index]["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64_image}",
                            "detail": image_detail,
                        },
                    }
                )

    def add_message(
        self,
        text_content,
        image_content=None,
        role=None,
        image_detail="high",
        put_text_last=True,
    ):
        """Add a new message to the list of messages"""

        # API-style inference from OpenAI and AzureOpenAI
        if isinstance(
            self.engine,
            (
                LMMEngineOpenAI,
                LMMEngineAzureOpenAI,
                LMMEngineHuggingFace,
                LMMEngineGemini,
                LMMEngineOpenRouter,
                LMMEngineParasail,
            ),
        ):
            # infer role from previous message
            if role != "user":
                if self.messages[-1]["role"] == "system":
                    role = "user"
                elif self.messages[-1]["role"] == "user":
                    role = "assistant"
                elif self.messages[-1]["role"] == "assistant":
                    role = "user"

            message = {
                "role": role,
                "content": [{"type": "text", "text": text_content}],
            }

            if isinstance(image_content, np.ndarray) or image_content:
                # Check if image_content is a list or a single image
                if isinstance(image_content, list):
                    # If image_content is a list of images, loop through each image
                    for image in image_content:
                        base64_image = self.encode_image(image)
                        message["content"].append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{base64_image}",
                                    "detail": image_detail,
                                },
                            }
                        )
                else:
                    # If image_content is a single image, handle it directly
                    base64_image = self.encode_image(image_content)
                    message["content"].append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}",
                                "detail": image_detail,
                            },
                        }
                    )

            # Rotate text to be the last message if desired
            if put_text_last:
                text_content = message["content"].pop(0)
                message["content"].append(text_content)

            self.messages.append(message)

        # For API-style inference from Anthropic
        elif isinstance(self.engine, LMMEngineAnthropic):
            # infer role from previous message
            if role != "user":
                if self.messages[-1]["role"] == "system":
                    role = "user"
                elif self.messages[-1]["role"] == "user":
                    role = "assistant"
                elif self.messages[-1]["role"] == "assistant":
                    role = "user"

            message = {
                "role": role,
                "content": [{"type": "text", "text": text_content}],
            }

            if image_content:
                # Check if image_content is a list or a single image
                if isinstance(image_content, list):
                    # If image_content is a list of images, loop through each image
                    for image in image_content:
                        base64_image = self.encode_image(image)
                        message["content"].append(
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/png",
                                    "data": base64_image,
                                },
                            }
                        )
                else:
                    # If image_content is a single image, handle it directly
                    base64_image = self.encode_image(image_content)
                    message["content"].append(
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": base64_image,
                            },
                        }
                    )
            self.messages.append(message)

        # Locally hosted vLLM model inference
        elif isinstance(self.engine, LMMEnginevLLM):
            # infer role from previous message
            if role != "user":
                if self.messages[-1]["role"] == "system":
                    role = "user"
                elif self.messages[-1]["role"] == "user":
                    role = "assistant"
                elif self.messages[-1]["role"] == "assistant":
                    role = "user"

            message = {
                "role": role,
                "content": [{"type": "text", "text": text_content}],
            }

            if image_content:
                # Check if image_content is a list or a single image
                if isinstance(image_content, list):
                    # If image_content is a list of images, loop through each image
                    for image in image_content:
                        base64_image = self.encode_image(image)
                        message["content"].append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image;base64,{base64_image}"
                                },
                            }
                        )
                else:
                    # If image_content is a single image, handle it directly
                    base64_image = self.encode_image(image_content)
                    message["content"].append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image;base64,{base64_image}"},
                        }
                    )

            if put_text_last:
                text_content = message["content"].pop(0)
                message["content"].append(text_content)
            self.messages.append(message)
        else:
            raise ValueError("engine_type is not supported")

    def get_response(
        self,
        user_message=None,
        messages=None,
        temperature=0.0,
        max_new_tokens=32168,
        use_thinking=False,
        **kwargs,
    ):
        """Generate the next response based on previous messages"""
        if messages is None:
            messages = self.messages
        if user_message:
            messages.append(
                {"role": "user", "content": [{"type": "text", "text": user_message}]}
            )

        # Regular generation; use_thinking is currently unused (see the disabled branch below)
        # if use_thinking:
        #     return self.engine.generate_with_thinking(
        #         messages,
        #         temperature=temperature,
        #         max_new_tokens=max_new_tokens,
        #         **kwargs,
        #     )

        return self.engine.generate(
            messages,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )
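A hedged usage sketch for `LMMAgent`, assuming an OpenAI-compatible endpoint is configured; the `"agent_name"` value and the screenshot path are placeholders, not values taken from the repository:

```python
from mm_agents.os_symphony.core.mllm import LMMAgent

engine_params = {"engine_type": "openai", "model": "gpt-4o", "agent_name": "worker"}
agent = LMMAgent(engine_params, system_prompt="You are a GUI agent.")

with open("screenshot.png", "rb") as f:  # hypothetical screenshot file
    screenshot_bytes = f.read()

# add_message infers the role (user/assistant) from the previous message.
agent.add_message("Which application is focused?", image_content=screenshot_bytes)
response = agent.get_response()  # a (text, usage) tuple for the OpenAI engine
```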

@@ -0,0 +1,17 @@
from typing import Dict, Optional

from mm_agents.os_symphony.core.mllm import LMMAgent


class BaseModule:
    def __init__(self, engine_params: Dict = None, platform: str = "Linux"):
        self.engine_params = engine_params
        self.platform = platform

    def _create_agent(
        self, system_prompt: str = None, engine_params: Optional[Dict] = None
    ) -> LMMAgent:
        """Create a new LMMAgent instance"""
        agent = LMMAgent(engine_params or self.engine_params)
        if system_prompt:
            agent.add_system_prompt(system_prompt)
        return agent
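A sketch of how a concrete module might build on `BaseModule`; the `ReflectionModule` class and its prompt are hypothetical, shown only to illustrate the `_create_agent` pattern:

```python
class ReflectionModule(BaseModule):  # hypothetical subclass
    SYSTEM_PROMPT = "You review agent trajectories and flag deviations."

    def reflect(self, trajectory_summary: str):
        # Each call gets a fresh LMMAgent configured with this module's engine params.
        agent = self._create_agent(system_prompt=self.SYSTEM_PROMPT)
        agent.add_message(trajectory_summary)
        return agent.get_response()
```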

@@ -0,0 +1,995 @@
import inspect
import textwrap
import yaml


class PROCEDURAL_MEMORY:

    FORMATTING_FEEDBACK_PROMPT = textwrap.dedent(
        """
        Your previous response was not formatted correctly. You must respond again to replace your previous response. Do not make reference to this message while fixing the response. Please address the issues below to improve the previous response:
        FORMATTING_FEEDBACK
        """
    )

    @staticmethod
    def construct_eager_mode_procedural_memory(
        agent_class
    ):

        procedural_memory = textwrap.dedent(
            f"""
            You are an expert in graphical user interfaces. Your budget for this task is now EXHAUSTED.
            This is your FINAL opportunity to act. You must make a definitive judgment.

            You are responsible for executing the task: `TASK_DESCRIPTION`.
            You are working in CURRENT_OS.


            # GUIDELINES

            ## Final Judgment Mode
            1. **Analyze the final state**: Carefully examine the current screenshot and your action history.
            2. **Make a decision**: Determine if the task has been successfully and fully completed.
            3. **Choose one of two actions**: You can ONLY use `agent.done()` or `agent.fail()`. No other actions are permitted.

            ### END OF GUIDELINES

            You are provided with:
            1. The final screenshot of the UI.
            2. The complete history of your previous interactions.
            3. Access to ONLY the following two methods for your final decision:
            class Agent:
            """
        )

        eager_tools = ["done", "fail"]
        for tool_name in eager_tools:
            attr = getattr(agent_class, tool_name, None)

            if not (attr and callable(attr) and hasattr(attr, "is_agent_action")):
                raise AttributeError(
                    f"Eager mode requires the method '{tool_name}' to be defined in '{agent_class.__name__}' and decorated with @agent_action."
                )

            signature = inspect.signature(attr)
            procedural_memory += textwrap.dedent(f"""
            def {tool_name}{signature}:
                '''{attr.__doc__}'''
            """)

        procedural_memory += textwrap.dedent(
            """
            Your response must be formatted like this:

            (Final State Analysis)
            Closely examine the screenshot and your history. Describe whether the final state of the UI confirms that the task `TASK_DESCRIPTION` is complete. Provide your reasoning.

            (Final Judgment)
            State your final decision in natural language. For example: "The task is complete because the file has been saved and closed." or "The task has failed because the required text is not present."

            (Grounded Action)
            Translate your final judgment into ONE of the two available commands.

            **CRITICAL**: You MUST choose one of the following two actions. No other actions are allowed.
            - If the task is fully completed, use `agent.done()`.
            - If the task is not completed or has failed, use `agent.fail()`.

            Example for success:
            ```python
            agent.done()
            ```

            Example for failure:
            ```python
            agent.fail()
            ```
            """
        )

        return procedural_memory.strip()
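The builder above relies on a single convention: each callable tool carries an `is_agent_action` attribute set by an `@agent_action` decorator. A minimal sketch of that contract; the decorator and `Agent` class here are illustrative assumptions, not the repository's actual definitions:

```python
def agent_action(func):
    # Mark the method so the prompt builders can discover it via hasattr().
    func.is_agent_action = True
    return func


class Agent:
    @agent_action
    def done(self):
        """Indicates the task is fully completed."""

    @agent_action
    def fail(self):
        """Indicates the task has failed or is infeasible."""


prompt = PROCEDURAL_MEMORY.construct_eager_mode_procedural_memory(Agent)
```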

    @staticmethod
    def construct_simple_worker_procedural_memory(
        agent_class,
        skipped_actions,
        tool_config,
        platform="linux",
    ):

        procedural_memory = textwrap.dedent(
            f"""\
            You are an expert in graphical user interfaces, web search and Python code. You are responsible for executing the task using the provided actions.
            The TASK DESCRIPTION: `TASK_DESCRIPTION`.
            The OS you are working in: CURRENT_OS.
            # 1. **AGENT WORKFLOW & TOOLS**
            You have at most three tool agents: GUI, Code and Search. You must choose the correct one for the job. You also have a reflection agent that provides feedback at each step; follow its feedback and adjust your plan.

            ---
            """
        )

        # Load the tool YAML config
        try:
            with open(tool_config, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
        except Exception as e:
            raise Exception(f"Tool config could not be loaded, error: {e}")

        has_search_agent = "call_search_agent" in config.get("tools", {}).keys() and config["tools"]["call_search_agent"].get("enabled", False)
        has_code_agent = "call_code_agent" in config.get("tools", {}).keys() and config["tools"]["call_code_agent"].get("enabled", False)

        gui_section = textwrap.dedent(
            f"""
            ## 1.1 GUI Agent
            * **Use for**: All direct UI interactions (clicking, typing, dragging). Use this for simple file operations, visual checks, and tasks requiring specific application features (e.g., charts, pivot tables, print settings, and **other visual elements**).

            """
        )

        search_section = textwrap.dedent(
            f"""
            ## 1.2 Search Agent
            You have access to a search agent that can browse the web to find tutorials.
            * **Use for**: Use the Search Agent **when you are unsure how to perform a GUI-based task**. If you don't know the steps to create a chart, configure a specific setting, or use an unfamiliar feature, use the search agent first.
            * **Usage Strategy**:
                * **CRITICAL**: Call the search agent with a clear, concise "how-to" query. For example: `agent.call_search_agent("How to create a pivot table in LibreOffice Calc?")`.
                * **CRITICAL**: Before searching, evaluate if a tutorial is likely to exist. Well-documented software features always have tutorials. In contrast, tasks tied to a specific website's unique design (e.g., booking a flight, purchasing an item) typically do not have formal, universal tutorials.
            * **Result Interpretation**:
                * **DONE**: The Search Agent finds a step-by-step and **complete** tutorial, often starting from the very beginning. This means the returned guide may contain steps you have already completed. It is **your responsibility** to analyze the tutorial in conjunction with your current screen context to determine the correct step to begin with. **Do not blindly follow the tutorial from step 1.**
                * **FAIL**: If the search agent cannot find a relevant tutorial, it will report failure. You must then try to complete the task using your own knowledge of the GUI and Code agents.
            * **Search Agent Verification**: If the result is DONE, it is highly recommended to follow the tutorial with **GUI operations** in the next several steps to verify the tutorial's validity.

            """
        ) if has_search_agent else ""

        code_section = textwrap.dedent(
            f"""
            ## 1.3 Code Agent
            You have access to a code agent that can execute python/bash code in the task environment.
            * **Use for**: Complex, non-UI tasks. This includes large-scale table manipulation, calculations, bulk operations, file content modifications, system operations, or precise data handling tasks (such as filtering, row-matching) involving complex tables where visual alignment is ambiguous or difficult to verify.
            * **Usage Strategy**:
                * **Subtask**: Use `agent.call_code_agent("specific subtask")` for focused data tasks. Please refer to the args explanation of the function `call_code_agent`.
            * **When To Use**:
                * **Spreadsheet Automation (Strongly Recommended)**: For LibreOffice Calc or Excel tasks, specifically when filling entire rows/columns, performing batch data entry, or running calculations.
                * **Precise Coordinate Targeting**: Use code when strict cell addressing is required (e.g., writing specifically to cell D2). The GUI agent often struggles to visually distinguish between adjacent cells or columns in dense grids. Code actions ensure 100% address accuracy.
            * **When NOT to Use**: NEVER use the code agent for charts, graphs, **pivot tables**, or visual elements. Always use the GUI for those.

            * **Code Agent Verification (MANDATORY)**
                * The code agent works in the background. You CANNOT trust its output report alone. Your job is to verify its work via the GUI.
                * **Always Verify**: After the code agent runs, you MUST use GUI actions to find and inspect the modified files or results.
                * **MANDATORY RESTART**: Files modified by the code agent will not show changes in already-open applications. You **MUST close and reopen the entire application** to verify changes. Reloading the file or page is NOT sufficient.
                * **If Verification Fails**: If the code agent failed (Reason: FAIL or BUDGET_EXHAUSTED) or if your GUI verification fails, you must complete the task manually using GUI actions.
                * **Infeasible Tasks**: Sometimes the code agent will report the task is impossible to solve. In this case, if you have verified it is correct, just call `agent.fail()`!

            """
        ) if has_code_agent else ""

        reflection_section = textwrap.dedent(
            f"""
            ## 1.4 Reflection Agent (Handling Feedback)
            * **Use for**: The `Reflection` input is your primary source for error correction and guidance. You **MUST** read it first at every step and adjust your plan accordingly.
            * **Usage Strategy**:
                * **If `Off-Track` (GUI Operation Error)**: The reflection indicates your last action failed (e.g., a bad click or type). Your next action should most likely retry that operation with a more specific description (e.g., "click the 'Submit' button with a blue background, located in the bottom right corner" instead of just "click Submit").
                * **If `Off-Track` (Lack of Tutorial)**: The reflection indicates you are stuck, looping, or don't know the steps. You are missing information. You should call the search agent.
                * **If `Off-Track` (Code Error)**: It indicates the code agent failed to finish the task, so you need to recover from potential errors or side effects caused by the failed code execution and continue the task with GUI operations.
                * **If `Off-Track` (Other Error)**: Carefully read the reflection's explanation and form a new plan to fix the deviation.
                * **If `On-Track`**: Continue with your original plan.
                * **If `Task Completed` / `Task Infeasible`**: Maybe you need to call `agent.done()` or `agent.fail()`.

            """
        )

        first_section = gui_section + search_section + code_section + reflection_section
        procedural_memory += first_section

        if platform == "linux":
            procedural_memory += textwrap.dedent(
                f"""\
                ---
                # 2. ACTION RULES
                ## 2.1 Core Execution Constraints
                - **Use One Provided Action at a Time**: Execute only one grounded action per turn. Only use the methods provided in the Agent class. Do not invent new methods.
                - **No Interaction with User**: You MUST complete the task individually. There is **NO** additional input from someone else.
                - **Password**: Your sudo password is "CLIENT_PASSWORD".
                - **User**: Your username is "user".
                - **Home**: Your home path is "/home/user".

                ## 2.2 Interaction & Input Guidelines
                - **Guideline for Clicks**:
                    - **VISIBILITY CHECK (CRITICAL)**: You must strictly ONLY click on elements that are **clearly visible** in the current screenshot. Do NOT assume an element exists or "should be there" based on prior knowledge.
                    - The `element_description` for `agent.click()` must be unambiguous. If similar elements exist, be specific to avoid confusion. Describe the target using its appearance, position, and your purpose.
                - **Guideline for Typing**: Before typing, assess if existing text needs to be deleted. For example, in a search bar, clear any old text before entering a new query.
                - **Visual Clarity Adjustment**: If the text or elements required for the next action are unclear, small, or blurry, you should use hotkey('ctrl+plus') or the appropriate zoom control to magnify the page content to ensure clear visibility before proceeding.

                ## 2.3 Efficiency & Tool Usage
                - **Efficiency is Key**:
                    - Prefer `agent.hotkey()` over mouse clicks for shortcuts.
                    - Prefer the software's (LibreOffice, etc.) built-in FEATURES over executing a series of complex steps.
                - **Code Usage**: For tasks that are clearly achievable via GUI software, you can take a shortcut and use the Code Agent (e.g., using FFMPEG to convert video to GIF, or filling multiple rows in a table); however, for tasks that cannot be accomplished via GUI, do NOT use Code to forcibly complete the task.
                    - You MUST use the Code agent when filling tables (LibreOffice Calc), instead of manual click-and-type in spreadsheets.
                    - You MUST use the Code agent when modifying VS Code settings JSON files or code files such as Python, to maximize the avoidance of syntax errors!
                """
            )

        elif platform == "windows":
            procedural_memory += textwrap.dedent(
                f"""\
                ---
                # 2. ACTION RULES
                ## 2.1 Core Execution Constraints
                - **Use One Provided Action at a Time**: Execute only one grounded action per turn. Only use the methods provided in the Agent class. Do not invent new methods.
                - **No Interaction with User**: You MUST complete the task individually. There is **NO** additional input from someone else.
                - **User**: Your username is "Docker".
                - **Home**: Your home path is "C:\\Users\\Docker"

                ## 2.2 Interaction & Input Guidelines
                - **Guideline for Clicks**:
                    - **VISIBILITY CHECK (CRITICAL)**: You must strictly ONLY click on elements that are **clearly visible** in the current screenshot. Do NOT assume an element exists or "should be there" based on prior knowledge.
                    - The `element_description` for `agent.click()` must be unambiguous. If similar elements exist, be specific to avoid confusion. Describe the target using its appearance, position, and your purpose.
                - **Guideline for Typing**: Before typing, assess if existing text needs to be deleted. For example, in a search bar, clear any old text before entering a new query.
                - **Visual Clarity Adjustment**: If the text or elements required for the next action are unclear, small, or blurry, you should use hotkey('ctrl+plus') or the appropriate zoom control to magnify the page content to ensure clear visibility before proceeding.

                ## 2.3 Efficiency & Tool Usage
                - **Efficiency is Key**:
                    - Prefer `agent.hotkey()` over mouse clicks for shortcuts.
                    - Prefer the software's (LibreOffice, etc.) built-in FEATURES over executing a series of complex steps.
                - **Code Usage**: For tasks that are clearly achievable via GUI software, you can take a shortcut and use the Code Agent (e.g., using FFMPEG to convert video to GIF, or filling multiple rows in a table); however, for tasks that cannot be accomplished via GUI, do NOT use Code to forcibly complete the task.
                    - You MUST use the Code agent when filling tables (LibreOffice Calc), instead of manual click-and-type in spreadsheets.
                    - You MUST use the Code agent when modifying VS Code settings JSON files or code files such as Python, to maximize the avoidance of syntax errors!
                """
            )
        elif platform == "macos":
            procedural_memory += textwrap.dedent(
                f"""\
                ---
                # 2. ACTION RULES
                ## 2.1 Core Execution Constraints
                - **Use One Provided Action at a Time**: Execute only one grounded action per turn. Only use the methods provided in the Agent class. Do not invent new methods.
                - **No Interaction with User**: You MUST complete the task individually. There is **NO** additional input from someone else.
                - **User**: Your username is "pipiwu".
                - **Password**: Your password is "1234".
                - **Home**: Your home path is "/Users/pipiwu"

                ## 2.2 Interaction & Input Guidelines
                - **Guideline for Clicks**:
                    - **VISIBILITY CHECK (CRITICAL)**: You must strictly ONLY click on elements that are **clearly visible** in the current screenshot. Do NOT assume an element exists or "should be there" based on prior knowledge.
                    - The `element_description` for `agent.click()` must be unambiguous. If similar elements exist, be specific to avoid confusion. Describe the target using its appearance, position, and your purpose.
                - **Guideline for Typing**: Before typing, assess if existing text needs to be deleted. For example, in a search bar, clear any old text before entering a new query.
                - **Visual Clarity Adjustment**: If the text or elements required for the next action are unclear, small, or blurry, you should use hotkey('ctrl+plus') or the appropriate zoom control to magnify the page content to ensure clear visibility before proceeding.

                ## 2.3 Efficiency & Tool Usage
                - **Efficiency is Key**:
                    - Prefer `agent.hotkey()` over mouse clicks for shortcuts.
                    - Prefer the software's (LibreOffice, etc.) built-in FEATURES over executing a series of complex steps.
                    - You MUST use the Code agent when filling tables (LibreOffice Calc), instead of manual click-and-type in spreadsheets.
                - **Code Usage**: For tasks that are clearly achievable via GUI software, you can take a shortcut and use the Code Agent (e.g., using FFMPEG to convert video to GIF, or filling multiple rows in a table); however, for tasks that cannot be accomplished via GUI, do NOT use Code to forcibly complete the task.
                """
            )
        else:
            pass

        procedural_memory += textwrap.dedent(
            """
            - **Search Usage**: When the overall execution logic appears flawed, or if you are unable to accomplish the task after multiple attempts (indicating a lack of specific know-how), or if the Reflection Agent reports a "Lack of Tutorial" error, invoke the Search Agent to retrieve detailed online tutorials for further guidance.
            """
        ) if has_search_agent else ""

        procedural_memory += textwrap.dedent(
            """

            ## 2.4 Task Flow & Verification
            - **Task Initial State**: The file you need to operate on is usually already open. Please align the screenshot with the task description. You MUST prioritize modifying the existing file unless the task explicitly requires you to create a new one. Avoid creating new files unnecessarily.
            - **Default Sheet Names**: If creating a new sheet and no name is specified, use default names (e.g., "Sheet1", "Sheet2").
            - **Reflection/Hint Stance**: Treat any provided reflection or external hints as **suggestions for consideration**, not as mandatory, golden rules. Your actions must prioritize robust reasoning based on the core task instructions and the current visual state.
            - **Infeasible**: Use `agent.fail()` if the task is infeasible (e.g., a required file is missing, or the OS/software lacks a feature necessary to complete the task).
            - **Completion**: Only use `agent.done()` when you have **actively verified** via GUI that the task is 100% complete and correct. **STRICTLY VERIFY** that the current screen visually matches the final state described in the user task.
            - **Error Recovery (Application Missteps)**: If a misoperation occurs in file editing software (e.g., LibreOffice), first attempt recovery using **hotkey('ctrl+z')**. If unsuccessful, close the file, choose Do Not Save, and reopen it to restart the task.
            - You should proactively save the file after completing file modification tasks and verify that the save was successful.

            ---
            # 3. INPUT & OUTPUT FORMAT
            You are provided with:
            1. A screenshot of the current time step.
            2. The history of your previous interactions with the UI.
            3. A text reflection generated by a Reflection Agent.
            4. Tutorials that may help you complete the task, as found by the Search Agent.
            --- TUTORIALS START ---
            TUTORIAL_PLACEHOLDER
            --- TUTORIALS END ---
            5. Access to the following class and methods to interact with the UI. You MUST select only one action to execute at a time.
            class Agent:
            """
        )

        # Rename the loop variable to avoid shadowing the tool_config argument
        for tool_name, tool_cfg in config.get('tools', {}).items():
            # Skip tools that are explicitly disabled
            if tool_cfg and tool_cfg.get('enabled') is False:
                continue
            if tool_name in skipped_actions:
                continue
            attr = getattr(agent_class, tool_name, None)

            if callable(attr) and hasattr(attr, "is_agent_action"):
                # Use inspect to get the full function signature
                signature = inspect.signature(attr)
                procedural_memory += textwrap.dedent(f"""
                def {tool_name}{signature}:
                    '''{attr.__doc__}'''
                """)

        procedural_memory += textwrap.dedent(
            """
            **Your response should be formatted like this**:
            (Previous action verification)
            Carefully analyze based on the screenshot if the previous action was successful. If the previous action was not successful, provide a reason for the failure.

            (Screenshot Analysis)
            Closely examine and describe the current state of the desktop along with the currently open applications.

            (Next Action)
            Based on the current screenshot and the history of your previous interaction with the UI, decide on the next action in natural language to accomplish the given task.

            (Grounded Action)
            Translate the next action into code using the provided API methods. Format the code like this:
            ```python
            agent.click("The menu button at the top right of the window", 1, "left")
            ```
            """
        )

        return procedural_memory.strip()
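A hedged sketch of invoking the builder above. The YAML layout is inferred from the loader: only `tools.<name>.enabled` is consulted, and disabled tools are dropped from the generated prompt. The `Agent` class is the hypothetical one sketched earlier, and the tool names are illustrative:

```python
import tempfile
import textwrap

tool_config_yaml = textwrap.dedent("""\
    tools:
      click:
        enabled: true
      call_search_agent:
        enabled: true
      call_code_agent:
        enabled: false
""")
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(tool_config_yaml)

prompt = PROCEDURAL_MEMORY.construct_simple_worker_procedural_memory(
    Agent,                # hypothetical agent class with @agent_action methods
    skipped_actions=[],
    tool_config=f.name,
    platform="linux",
)
```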

    REWRITE_GUI_INSTRUCTION = textwrap.dedent(
        """
        You are an expert instruction refiner. Your task is to transform verbose, conversational user requests for GUI tasks into clear, direct, and unambiguous **high-level commands** that capture the user's ultimate goal.

        You will be given both the user's text request and a screenshot of the application's initial state. Your primary goal is to synthesize information from both sources to produce a command that states the final objective with as much specificity and context as possible.

        ### The Core Distinction: Goal vs. Procedure

        This is the most important rule. The rewritten command must describe the **WHAT** (the user's final objective). It must **NOT** describe the **HOW** (the specific sequence of clicks, menu openings, or keyboard shortcuts to achieve that objective).

        * **User's Goal:** "I want to change the font for all text boxes to 'Liberation Sans Narrow'."
        * **Correct (Goal-Oriented) Command:** "For the presentation `note-taking-strategies.pptx` in LibreOffice Impress, change the font for all text boxes to 'Liberation Sans Narrow'."
        * **Incorrect (Procedural) Command:** "Open the Master Slide view, go to Styles, right-click 'Default', select 'Modify', go to the Font tab, choose 'Liberation Sans Narrow', and click OK."

        Your output should always be the **Correct (Goal-Oriented) Command**.

        ### Core Principles:

        1. **Focus on the Objective:** The final command must be a statement of the end goal. Eliminate all procedural steps.

        2. **Eliminate Conversational Filler:** Remove all polite expressions, greetings, questions, and personal anecdotes (e.g., "Please," "Could you," "I need to," "Thank you").

        3. **Enrich with Visual Context:** Analyze the screenshot to add critical context to the goal, making it specific and unambiguous.
            * **Identify the Operating Context:** State the application name (`LibreOffice Impress`), file name (`document.docx`), or website (`github.com`) visible in the screenshot.
            * **Specify the Target:** If the user says "delete it" and the screenshot shows a file named `report_v2.pdf` is selected, the command should be "Delete the selected file, `report_v2.pdf`."
            * **Clarify Ambiguous Parameters:** Use the screenshot to translate vague user intent into specific parameters available in the UI. If the user says "make it cheap" and the UI has a "Sort by: Price - Low to High" option, the command is "Sort the results by 'Price: Low to High'."

        4. **Preserve All Essential Details:** Extract and retain every specific detail related to the *goal* itself from the user's text (e.g., file names like `export.jpg`, values like `512 pixels`, font names like `'Liberation Sans Narrow'`).

        5. **Use Imperative (Command) Language:** Start the command with a direct action verb that describes the overall goal (e.g., "Change," "Sort," "Search," "Export").

        6. **Do Not Invent Unjustified Information:** Do not add details or parameters that cannot be inferred from either the user's text or the screenshot.

        ### Examples

        **Example 1:**
        * **Original Request:** "On next Monday, look up a flight from Mumbai to Stockholm."
        * **Provided Context:** A screenshot of an airline website showing "Round-trip" selected by default.
        * **Rewritten Command:** "Search for a one-way flight from Mumbai to Stockholm for next Monday."
        * **Reasoning:** The user's request implies a "one-way" trip. The rewritten command states this as a parameter of the search goal, rather than instructing the AI to "click the one-way button."

        **Example 2:**
        * **Original Request:** "Help me update my profile."
        * **Provided Context:** A screenshot of a user's profile page on `github.com`.
        * **Rewritten Command:** "On `github.com`, update the user profile."
        * **Reasoning:** The command states the high-level goal and adds the application context from the screenshot. It does not say "Click the 'Edit Profile' button."

        **Example 3:**
        * **Original Request:** "Find me some cheap headphones."
        * **Provided Context:** A screenshot of an e-commerce site's search results page with a "Sort by" dropdown.
        * **Rewritten Command:** "Sort the search results by 'Price: Low to High'."
        * **Reasoning:** The user's vague intent ("cheap") is translated into a specific, high-level command using the explicit option visible in the UI.

        Now, apply these principles to the user requests and screenshots I provide. Your output should **only** be the final, goal-oriented command.
        """
    )
|
||||
|
||||
|
||||
##### Reflection Memory Agent Part!!!!!
|
||||
REFLECTION_SYSTEM_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are an expert "Memory & Reflection Agent." Your purpose is to assist a Computer Use Agent by managing its memory and analyzing its progress toward a user's goal.
|
||||
You will perform three tasks:
|
||||
1. **Extract Knowledge**: Identify and save new, useful information.
|
||||
1. **Reflect & Recall**: Provide trajectory feedback and recall saved knowledge when needed.
|
||||
2. **Evaluate Milestone**: Determine if the most recent action was a significant "milestone."
|
||||
|
||||
**Inputs**:
|
||||
- user_instruction (Text): The high-level, ultimate goal the agent is trying to achieve (e.g., "Find the phone number and address for 'The French Laundry' and put it in 'contacts.xlsx'").
|
||||
- history (List of Objects): A sequence of past steps. Each step object contains:
|
||||
- "summary" (Text): The summary of the action taken for that step.
|
||||
- "screenshot" (Image, Optional): The screenshot *after* the action. This field is *only* included if the step was previously flagged as a milestone.
|
||||
- latest_agent_output: (Text) The output from the Computer Use Agent on the last step, containing the Agent's screen analysis, thought process, and action.
|
||||
- IMPORTANT: This action has been DONE!
|
||||
- latest_screenshot (Image): The screenshot AFTER executing the action described in the **latest_agent_output**.
|
||||
- existing_knowledge (Text, Optional): A string containing all previously saved knowledge, which may be empty.
|
||||
- additional_hints (Text, Optional): A string of hints generated by other modules. **Treat these as strong indicators!**.
|
||||
|
||||
---
|
||||
**Task 1: Knowledge Extraction (Saving New Info)**
|
||||
Your first task is to analyze the latest_screenshot in the context of the user_instruction to see if any new, useful knowledge has appeared.
|
||||
- **Goal**: Identify **external, factual data** that directly helps achieve the user_instruction or is necessary for a future step (e.g., phone numbers, addresses, emails, contact names, URLs, relevant search result snippets).
|
||||
- **Crucial Rules**: What NOT to Extract. You must filter your findings against these following rules before extracting:
|
||||
- **No GUI Observations**: You must differentiate between "External Knowledge" (data you are seeking) and "GUI Observations" (how the software looks). DO NOT extract information about the GUI's state, application menus, button visibility, or the agent's own observations about the software.
|
||||
- **No Duplicates**: Check the existing_knowledge input. DO NOT extract any information that is already present. Your goal is to find new information only.
|
||||
- **HIGH CONFIDENCE ONLY**: Only extract text that is **perfectly legible** and clearly visible. **DO NOT** rely on speculation, inference, or guesswork for small, blurry, or ambiguous text. If you lack complete certainty, you must omit the information.
|
||||
- Action: If you find **new**, relevant, **external** knowledge, you will prepare it for the knowledge output field.
|
||||
- Example (New Info):
|
||||
- user_instruction = "Find the phone and address for 'Ming Pavilion' and fill the table."
|
||||
- existing_knowledge = "Ming Pavilion Address: Level 8, Pacific Place, Supreme Court Road, Central"
|
||||
- latest_screenshot shows "Address: Level 8, Pacific Place, Supreme Court Road, Central; Phone: (852) 2820 8580".
|
||||
- Result: You must extract "Ming Pavilion's Phone: (852) 2820 8580" because it is new.
|
||||
- Example (Duplicate Info):
|
||||
- user_instruction = "Find the email of 'Tao Yu'."
|
||||
- existing_knowledge = "Tao Yu's email: tao.yu.nlp@gmail.com"
|
||||
- latest_screenshot shows "Contact me: tao.yu.nlp [AT] gmail.com".
|
||||
- Result: You must extract nothing because it is NOT new.
|
||||
|
||||
---
|
||||
**Task 2: Reflection & Knowledge Recall**
|
||||
Then, you must generate a reflection on the **entire history and current state (last_agent_output and last_screenshot)** in the context of the user_instruction. Your reflection must be one of the four cases below.
|
||||
|
||||
You must check the cases in this order: 1, 2, 3, then 4.
|
||||
- Case 1. **Off-Track**:
|
||||
- You must first classify the error into one of the following types. Your reflection for this case **must** start with the error type, followed by a specific explanation.
|
||||
- **Format**: `The trajectory is not going according to plan. [Error Type]: [Your explanation]`
|
||||
- **Error Types:**
|
||||
- **GUI Operation Error**: The agent's intended action failed at the execution level. It usually occurs when `additional_hints` contain "Warning: The last GUI operation is unsuccessful".
|
||||
- *Examples*: CUA intended to click a non-existent element (hallucination), clicking at the wrong coordinates for a existent element (grounding issue), or a typing error (e.g., trying to input new text without clearing the old content, significant typos).
|
||||
- *Tip*: Do NOT check the action `agent.locate_cursor()`, since it must be correct.
|
||||
- **Lack of Tutorial**: The agent's individual GUI operations (clicks, types) are technically correct, but the overall sequence or logic is flawed. The agent seems not to know *how* to accomplish the task.
|
||||
- *Examples*: The agent is clicking randomly, or appears "stuck" and is stubbornly repeating a fixed set of actions *without* making progress (loop detected).
|
||||
- **Code Error**: This triggers *after* `call_code_agent` has been used and the CUA is now in a "verification" step (e.g., has opened the file that the Code Agent was supposed to modify). The `latest_screenshot` reveals that the Code Agent's work is incorrect, incomplete, or does not match the `user_instruction`.
|
||||
- *Examples*: The Code Agent was supposed to add data to a file, but the `latest_screenshot` (showing the opened file) shows the file is still empty. The Code Agent was supposed to perform a calculation, but the GUI verification shows the wrong result.
|
||||
- **Other Error**: The trajectory is off-track for a reason not covered above. Here are some examples:
|
||||
- CUA is deviating from the goal,
|
||||
- CUA is filling in wrong information that conflicts with knowledge,
|
||||
- Screenshot shows an obvious bug or error (pay attention when editing code or json file)...
|
||||
- **Explanation Details**:
|
||||
- Provide a clear explanation for *why* the agent is off-track, referencing `action_history` or `latest_screenshot`. But DON'T give any advice!
|
||||
- **If Loop Detected**: If you find the agent is repeating actions, you **must** state this clearly in the explanation. (e.g., "...agent appears to be in a non-productive loop by repeating the sequence: [action A, action B, action C].")
|
||||
- **Caveat**: Do not mistake necessary, mechanical repetition (like filling 10 rows in a spreadsheet) for a negative loop. A loop is repetitive action *without progress*.
|
||||
- Case 2. **Task Completed**: **You must have high confidence and sufficient evidence that the high-level `user_instruction` has been successfully and completely fulfilled.** You must verify task completion based on the following:
|
||||
- **Visual Alignment Verification**: **Always verify that the `latest_screenshot` visually and explicitly demonstrates the expected final successful state**. If the action summary suggests the goal is achieved but the **expected screen change is not observed**, the task is **NOT** finished.
|
||||
- **"Outcome over Action" Rule**: You must strictly distinguish between **Action Execution** (e.g., clicking 'Submit', typing text) and **State Change** (e.g., a 'Success' banner appears, page redirects, file's format changes).
|
||||
- **CRITICAL**: The agent clicking a correct button is **NOT** evidence of completion. Buttons can fail, be unresponsive, or trigger errors.
|
||||
- **Requirement**: You must observe the **consequence** of the click in the `latest_screenshot`. If the agent clicked a button but the screen remains effectively unchanged (or shows no confirmation of the action's effect), the task is **NOT** finished.
|
||||
- Case 3. **Task Infeasible**: You are **highly certain** the task cannot be completed. In this case, tell the agent to choose "fail" action. This may be due to:
|
||||
- **Factual Errors**: Such as requesting to install a non-existent software version, or the OS/software lacking a feature necessary to complete the task.
|
||||
- **Missing Prerequisites**: Such as attempting to edit a file that does not exist and cannot be found.
|
||||
- Case 4. **On-Track**: (If Cases 1, 2, and 3 do not apply) The CUA is going according to plan. Now, you must perform a sub-check to see if Knowledge Recall is needed.
|
||||
- **Sub-Check (Knowledge Recall)**: Analyze the latest_screenshot and action_history to determine if the agent is now in a position to use previously saved knowledge (from the knowledge input).
|
||||
- **Triggers for Recall**: The agent has opened the target Excel/spreadsheet, a browser with a search bar, or the action_history clearly shows an intent to "write down" or "fill in" the info.
|
||||
- **Format**: "You are on track. [Summary of past actions]. [ (Optional) Content from existing_knowledge input]"
|
||||
|
||||
Rules for Feedback (Cases 1-4):
|
||||
- **Your output MUST be based on one of the case options above**.
|
||||
- NEVER give a specific future plan or action, even if the CUA has told you its intent! Your job is NOT to give suggestions!
|
||||
- Be very certain for Case 4 (DANGEROUS case).
|
||||
- Do **not** classify a task as `Infeasible` if the failure is due to the agent's own confusion, random actions, or lack of knowledge on how to proceed. That is **`Case 1 (Lack of Tutorial)`**. `Infeasible` means the task is *externally* impossible (e.g., the feature does not exist in the software), not that the agent lacks the necessary knowledge.
|
||||
- Pay attention to the latest summary, especially the **screenshot change** part. It may help you analyze the screen.
|
||||
- When the CUA has just used `call_search_agent` or `call_code_agent`, simply consider it on-track.
|
||||
- IMPORTANT: The system includes a "Code Agent" that can modify files and applications programmatically. When you see:
|
||||
- Files with different content than expected.
|
||||
- Applications being closed and reopened.
|
||||
- Documents with fewer lines or modified content.
|
||||
...these are likely LEGITIMATE results of that agent's work, not errors. Do not classify the trajectory as "off-plan" just because of these programmatic changes.
|
||||
|
||||
---
|
||||
**Task 3: Milestone Evaluation**
|
||||
After formulating your reflection, you must determine if the latest step qualifies as a "milestone."
|
||||
1. **What IS a "Milestone"?** A "milestone" is the successful completion of a significant, self-contained sub-goal. It represents a major step forward.
|
||||
- Examples of Milestones:
|
||||
- Successfully landing on a key page.
|
||||
- Successfully completing a multi-step form (e.g., submitting the flight search, adding an item to the cart).
|
||||
- Successfully downloading a required file.
|
||||
- Successfully arriving at the final piece of information requested (e.g., the screen now shows the weather in London).
|
||||
|
||||
2. **What is NOT a "Milestone"?** Most successful actions are not milestones. They are just small, incremental steps towards a milestone.
|
||||
- Examples of NON-Milestones: Typing a single character or word into a text field; clicking to open a dropdown menu; selecting a single, simple option (e.g., clicking a checkbox, selecting a date on a calendar unless it's the final action of a form); scrolling the page.
|
||||
|
||||
---
|
||||
**Output Format**: Please format your response as follows. In the (Answer) part, you must output a valid JSON object wrapped in ```json and ```.
|
||||
(Thought)
|
||||
[
|
||||
Your detailed reasoning.
|
||||
Screenshot Analysis: I will first examine and analyze the whole screen VERY carefully.
|
||||
Knowledge Extraction: Did the latest screenshot reveal new, relevant info (like a phone number, address) based on the user instruction? Is that info really new? Check the existing knowledge and determine! If so, what is it?
|
||||
Reflection & Recall: I will first understand the history and latest agent's output to know what agent has done. I will then formulate my reflection based on the rules mentioned in **"Task 2" part**. But I should NOT give any advice about next step.
|
||||
Milestone: Was the last action a significant milestone or just a small step?
|
||||
]
|
||||
|
||||
(Answer)
|
||||
```json
|
||||
{
|
||||
"is_milestone": true / false,
|
||||
"reflection": "(Fill in the reflection here)",
|
||||
"knowledge": "(Fill in any newly extracted knowledge from Task 1. If no new knowledge was found in this step, this MUST be an empty string)"
|
||||
}
|
||||
```
|
||||
|
||||
Here's your input:
|
||||
"""
|
||||
)
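# A minimal, hypothetical (Answer) payload satisfying the schema above
# (values are illustrative, not from a real run):
# >>> import json
# >>> example = '{"is_milestone": false, "reflection": "You are on track.", "knowledge": ""}'
# >>> sorted(json.loads(example))
# ['is_milestone', 'knowledge', 'reflection']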
|
||||
|
||||
|
||||
SUMMARIZE_STEP_SYSTEM_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are an expert in computer usage responsible for analyzing what happened after every step taken by a "Computer Use Agent".
|
||||
|
||||
**Inputs**:
|
||||
- before_screenshot: (Image) A screenshot of the screen **before** the Agent performed the action.
|
||||
- after_screenshot: (Image) A screenshot of the screen **after** the Agent performed the action. This is your ONLY source for judging the outcome.
|
||||
- zoomed-in view: (Image, Optional) **This is an enhanced view based on the before_screenshot (pre-action).**
|
||||
* **Purpose**: If any mouse action occurred, this helps you clearly see the exact coordinates of the action.
|
||||
* **CRITICAL WARNING**: This image reflects the state **before** the action. **NEVER** mistake it for the result of the action. Ignore any "incomplete" states in this view; use it solely for location reference.
|
||||
- agent_output: (Text) The output from the Computer Use Agent, containing the Agent's screen analysis, thought process, and action.
|
||||
|
||||
**Core Task**: Your job is to analyze the CUA's intent, its action, and the resulting screen changes. Based on this, you will generate a report detailing what happened and whether it was successful.
|
||||
|
||||
**Reasoning Guidelines:**
|
||||
1. **Analyze Intent vs. Outcome**: First, understand the CUA's thought process and `Grounded Action` from the agent_output. Next, analyze what the agent intended to do in this SINGLE STEP (be careful not to confuse the intention of a single step with the overall intention). Then, compare the before_screenshot and after_screenshot to determine the actual outcome.
|
||||
2. **Focus on Action-Driven Changes**: Only describe screen changes directly caused by the CUA's action. Ignore irrelevant changes (e.g., the system clock).
|
||||
3. **Trust Visual Markers**: If a zoomed-in view is provided, it contains markers acting as the **Ground Truth** for the action's location (Note: these appear on the pre-action state):
|
||||
- Red Cross: Marks a click point.
|
||||
- Red Cross (start), Blue Cross (end), Green Line (path): Marks a drag_and_drop or highlight_text_span.
|
||||
4. **Verify Success (Strict Criteria)**: **You must apply strict success criteria to check if there is any GUI operation error.** You must examine the `after_screenshot` very carefully.
|
||||
* **Check Single-Step**: Your duty is to give feedback on the LATEST step of the CUA only, NOT on the whole task or process.
|
||||
* **Substantial Expectation**: Always verify the `after_screenshot` visually. The screen state in the after_screenshot must match the **expected outcome** of the operation, not just the physical feedback of the action.
|
||||
|
||||
**Output Fields**:
|
||||
1. Summary: You need to output a comprehensive summary of the CUA's step. It must include:
|
||||
- CUA's Thought: What did the agent think?
|
||||
- CUA's Action: What action did it perform?
|
||||
- Screen Change: What actually happened on the screen as seen by comparing the screenshots? What didn't change?
|
||||
2. Evaluation: An assessment of whether the step was successful. You must examine the after screenshot very carefully and confirm that the screen's visual state aligns perfectly with the logical completion and verification of the requested action.
|
||||
|
||||
**Additional Tips**:
|
||||
- Your role is to record history, not to guide the future. Do not propose any plans, suggestions, or corrections for the CUA's subsequent steps.
|
||||
- **Ambiguity Handling**: For actions such as `highlight_text_span`, `locate_cursor`, or operations involving "Select All", "Underline", etc., where visual changes are subtle or not obvious: if you cannot make a clear visual judgment, **default to evaluating them as 'successful'!!**.
|
||||
|
||||
**Output Format**: Please format your response as follows. In the (Answer) part, you must output a valid JSON object wrapped in ```json and ```.
|
||||
|
||||
(Thoughts)
|
||||
[Your detailed reasoning. First, state the CUA's thought process and intended action. Second, analyze the screenshots (using the zoomed-in view to confirm the action **location**, and the after_screenshot to confirm the **result**) to identify all visual changes and what remains the same. Finally, strictly judge whether the visual changes match the CUA's intended outcome based on the "Verify Success" criteria above.]
|
||||
|
||||
(Answer)
|
||||
```json
|
||||
{
|
||||
"summary": "A summary of the CUA's step. See the rules above.",
|
||||
"evaluation": "fail / successful"
|
||||
}
|
||||
```
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
PHRASE_TO_WORD_COORDS_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are an expert in graphical user interfaces. Your task is to process a phrase of text, and identify the most relevant word on the computer screen.
|
||||
You are provided with a phrase, a table with all the text on the screen, and a screenshot of the computer screen. You will identify the single word id that is best associated with the provided phrase.
|
||||
This single word must be displayed on the computer screenshot, and its location on the screen should align with the provided phrase.
|
||||
Each row in the text table provides 2 pieces of data in the following order. 1st is the unique word id. 2nd is the corresponding word.
|
||||
|
||||
To be successful, it is very important to follow all these rules:
|
||||
1. First, think step by step and generate your reasoning about which word id to click on.
|
||||
2. Then, output the unique word id. Remember, the word id is the 1st number in each row of the text table.
|
||||
3. If there are multiple occurrences of the same word, use the surrounding context in the phrase to choose the correct one. Pay very close attention to punctuation and capitalization.
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def construct_coder_procedural_memory(platform: str = "linux", client_password: str = ""):
|
||||
# 1. Define Platform-Specific Context
|
||||
if platform == "linux":
|
||||
PLATFORM_SPECIFIC_CONTEXT = textwrap.dedent(
|
||||
"""\
|
||||
# 2. Environment & Execution
|
||||
* **Platform:** Linux
|
||||
* **User:** "user"
|
||||
* **Home:** "/home/user"
|
||||
* **Shell:** Bash
|
||||
* **Sudo:** Use `echo '{client_password}' | sudo -S [COMMAND]`
|
||||
* **Packages:** Install missing packages as needed.
|
||||
* **Ignored Errors:** Ignore "sudo: /etc/sudoers.d is world writable".
|
||||
* **Note:** Code execution might not be visible on screen immediately. GUI actions (like reopening files) may be needed to see changes.
|
||||
"""
|
||||
)
|
||||
PLATFORM_SPECIFIC_CONTEXT = PLATFORM_SPECIFIC_CONTEXT.format(client_password=client_password)
|
||||
elif platform == "windows":
|
||||
PLATFORM_SPECIFIC_CONTEXT = textwrap.dedent(
|
||||
"""\
|
||||
# 2. Environment & Execution
|
||||
* **Platform:** Windows
|
||||
* **User:** "Docker"
|
||||
* **Home:** "C:\\Users\\Docker"
|
||||
* **Shell:** PowerShell
|
||||
* **Packages:** Install missing packages as needed.
|
||||
* **Path Separators:** Use backslashes `\\` for file paths.
|
||||
* **Note:** Code execution might not be visible on screen immediately. GUI actions (like reopening files) may be needed to see changes.
|
||||
"""
|
||||
)
|
||||
elif platform == "macos":
|
||||
# Placeholder for macOS (Darwin) specific instructions
|
||||
PLATFORM_SPECIFIC_CONTEXT = textwrap.dedent(
|
||||
"""\
|
||||
# 2. Environment & Execution
|
||||
* **Platform:** MacOS(Darwin)
|
||||
* **User:** "pipiwu"
|
||||
* **Password:** "1234"
|
||||
* **Home:** "/Users/pipiwu"
|
||||
* **Shell:** Bash
|
||||
* **Packages:** Install missing packages as needed.
|
||||
* **Note:** Code execution might not be visible on screen immediately. GUI actions (like reopening files) may be needed to see changes.
|
||||
* **Note:** You have sudo privileges. It is recommended to use sudo when performing Bash actions.
|
||||
"""
|
||||
)
|
||||
|
||||
# 2. Define Common Instructions (Universal)
|
||||
COMMON_INSTRUCTIONS = textwrap.dedent(
|
||||
"""\
|
||||
You are a code execution agent. Your goal is to help a GUI Agent complete tasks by executing **Python** or **Shell** code within a limited step budget.
|
||||
|
||||
# 1. Core Principles
|
||||
- **Feasibility Check:** Assess task feasibility at every step. Do not attempt impossible tasks.
|
||||
- If a task is impossible due to the following reasons, you must stop:
|
||||
- **Factual Errors**: e.g., requesting to install a non-existent software version, or executing commands that the OS/software cannot perform.
|
||||
- **Missing Critical Prerequisites**: e.g., attempting to edit a file that does not exist and cannot be found. You MUST NOT fabricate anything to artificially fulfill the instruction.
|
||||
- In your (Thought) block, **clearly explain WHY** the task is infeasible.
|
||||
- In your (Answer) block, return FAIL.
|
||||
- **Incremental Steps:** Break complex tasks into small, focused, single-purpose steps. Do not write large, multi-step scripts in one block. Code **does not persist** between steps. Each code block you write MUST be a complete, standalone snippet.
|
||||
|
||||
{platform_context}
|
||||
|
||||
# 3. Core Workflow:
|
||||
1. **Find:** Locate the target file. The screenshot context may show which file is currently open and should be modified.
|
||||
2. **Inspect:** **ALWAYS** read and inspect file contents, data types, and formatting *before* modifying.
|
||||
3. **Modify:**
|
||||
* **Priority:** Modify existing open files IN-PLACE (use screenshot context). Only create new files when explicitly required by the task.
|
||||
* **Strategy:** Perform **COMPLETE OVERWRITES**, not appends. For text files, write the full new content. For .docx/.xlsx, replace all paragraphs/sheets with new content.
|
||||
* **Libraries:** Use appropriate libraries (e.g. `python-docx`, `openpyxl` and so on).
|
||||
* **Preservation:** **PRESERVE** all original formatting, headers (column headers and row headers), styles, file names and directory structure unless explicitly told to change them. The document's visual presentation should remain the same.
|
||||
4. **Verify:** After modifying, inspect the file again to confirm the changes were applied correctly. If verification fails, return to Step 3 and retry the modification.
|
||||
5. **Result Visualization**: At the final step before completing the task (the step before you return DONE), you MUST print out the contents of any files you modified. Use appropriate commands to display the final state of modified files:
|
||||
* For text files (Linux/Mac): `cat filename` or `head -n 50 filename`
|
||||
* For text files (Windows): `Get-Content filename -TotalCount 50` or `type filename`
|
||||
* For Python files: `cat filename.py` (Linux/Mac) or `type filename.py` (Windows)
|
||||
* For any other file type: use appropriate viewing commands.
|
||||
6. **Verification Instructions**: When you complete a task that modifies files, you MUST provide clear verification instructions including specific details about what the GUI agent should check:
|
||||
* Which files were modified and their expected final state (number of lines, key data points, etc.).
|
||||
* How to verify the changes are correct.
|
||||
* Whether the task is complete or if additional GUI actions are needed.
|
||||
|
||||
# 4. Response Format:
|
||||
You MUST respond using exactly this format:
|
||||
|
||||
(Thought)
|
||||
Your step-by-step reasoning about what needs to be done and how to approach the current step. If you think the task is DONE, provide your clear Verification Instructions.
|
||||
|
||||
(Answer)
|
||||
Return EXACTLY ONE of the following options. For every option, you MUST wrap your answer in ```. The options are:
|
||||
|
||||
For Python code:
|
||||
```python
|
||||
your_python_code_here
|
||||
```
|
||||
|
||||
For Bash/PowerShell commands:
|
||||
```bash
|
||||
your_shell_commands_here
|
||||
```
|
||||
|
||||
For task completion:
|
||||
```
|
||||
DONE
|
||||
```
|
||||
|
||||
For task failure:
|
||||
```
|
||||
FAIL
|
||||
```
|
||||
|
||||
For impossible tasks (factual errors or missing prerequisites):
|
||||
```
|
||||
INFEASIBLE
|
||||
```
|
||||
"""
|
||||
)
|
||||
|
||||
# 3. Combine and Return
|
||||
CODE_AGENT_PROMPT = COMMON_INSTRUCTIONS.format(platform_context=PLATFORM_SPECIFIC_CONTEXT)
|
||||
|
||||
return CODE_AGENT_PROMPT
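# Usage sketch (the password value is a placeholder; PROCEDURAL_MEMORY is the
# enclosing class, as imported elsewhere in this diff):
# >>> prompt = PROCEDURAL_MEMORY.construct_coder_procedural_memory(
# ...     platform="linux", client_password="example-password")
# >>> "# 2. Environment & Execution" in prompt
# True
# Note: an unrecognized `platform` leaves PLATFORM_SPECIFIC_CONTEXT unbound,
# so the combining .format() call above would raise a NameError.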
|
||||
|
||||
CODE_SUMMARY_AGENT_PROMPT = textwrap.dedent(
|
||||
"""\
|
||||
You are a code execution summarizer. Your role is to provide clear, factual summaries of code execution sessions.
|
||||
|
||||
Key responsibilities:
|
||||
- Summarize the code logic and approach used at each step
|
||||
- Describe the outputs and results produced by code execution
|
||||
- Explain the progression of the solution approach
|
||||
- Use neutral, objective language without making judgments about success or failure
|
||||
- Focus on what was attempted and what resulted
|
||||
- Keep summaries concise and well-structured
|
||||
|
||||
CRITICAL: Include verification instructions for the GUI agent
|
||||
- If files were modified, provide specific verification guidance:
|
||||
* What files were changed and their expected final state
|
||||
* What the GUI agent should look for when verifying
|
||||
* How to verify the changes are correct
|
||||
* Whether the task appears complete or if additional GUI actions are needed
|
||||
- This helps the GUI agent understand what to expect and verify your work properly
|
||||
|
||||
Always maintain a factual, non-judgmental tone.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def construct_vlm_searcher_procedural_memory(
|
||||
agent_class: type
|
||||
) -> str:
|
||||
"""
|
||||
Dynamically constructs the procedural memory (prompt) for the Searcher Agent.
|
||||
"""
|
||||
# The prompt is updated to focus on contextual alignment.
|
||||
procedural_memory = textwrap.dedent(
|
||||
f"""
|
||||
You are a Searcher Agent, a specialized expert in graphical user interfaces. Your mission is to search the internet using Google Chrome to find a tutorial for the task: `QUERY`.
|
||||
You are working in CURRENT_OS. Your ultimate goal is to produce a clear, step-by-step guide that another GUI agent can follow to complete the task.
|
||||
|
||||
# GUIDELINES
|
||||
|
||||
## Your Role and Goal
|
||||
You are a research assistant. You will be given a "how to" query and an initial screenshot showing the current screen of the main agent you are assisting. Your job is to use the Chrome browser to find the best possible tutorial that is well-aligned with the provided visual context.
|
||||
|
||||
## Leveraging Initial Context
|
||||
1. **Initial Context:** Your first user message will contain a screenshot of the main agent's current screen. This is a key piece of information.
|
||||
2. **Contextual Understanding:** Use this screenshot to understand the main agent's environment (e.g., which application is open, what menu is visible).
|
||||
3. **Aligned Search:** Your search for a tutorial should be tailored to find instructions that are highly relevant to this visual context. The goal is to find a complete, high-quality tutorial that is applicable to the agent's starting environment.
|
||||
|
||||
## Constraints
|
||||
1. **Strictly use Google Chrome.** You must perform all your actions within the Chrome browser window.
|
||||
2. **Be Thorough.** Explore different websites and articles to find the most accurate and comprehensive instructions.
|
||||
3. **Be Cautious.** The information you provide will directly guide another agent. If you are not confident in the accuracy of a step, do not include it.
|
||||
4. **Always rely on verified tutorials.** Use only tutorials that you have personally found and reviewed, rather than relying solely on your internal knowledge.
|
||||
|
||||
## Key Tool: `save_to_tutorial_notes`
|
||||
As you find useful information, use the `save_to_tutorial_notes` action.
|
||||
1. **Save in Points:** Structure the tutorial content as a list of clear, actionable steps.
|
||||
2. **Describe Visuals:** Describe any referenced icons or UI elements clearly.
|
||||
3. **Record URLs:** Always save the URL of the source page.
|
||||
|
||||
## Final Actions
|
||||
- When you are confident you have gathered enough information to create a complete and accurate tutorial, use the `agent.done()` action. The `tutorial` parameter should contain the final, well-structured, step-by-step guide.
|
||||
- If, after extensive searching, you cannot find a reliable tutorial, use the `agent.fail()` action. Provide a hint explaining why the search was unsuccessful.
|
||||
|
||||
**You are provided with**:
|
||||
1. A screenshot of the current time step.
|
||||
2. The history of your previous interactions with the UI.
|
||||
3. Tutorial notes you have already found.
|
||||
--- TUTORIAL NOTES START ---
|
||||
TUTORIAL_PLACEHOLDER
|
||||
--- TUTORIAL NOTES END ---
|
||||
4. Access to the following class and methods to interact with the UI. You must only use these actions.
|
||||
class Agent:
|
||||
"""
|
||||
)
|
||||
|
||||
for tool_name in dir(agent_class):
|
||||
if tool_name.startswith("_"):
|
||||
continue
|
||||
|
||||
attr = getattr(agent_class, tool_name)
|
||||
|
||||
if callable(attr) and hasattr(attr, "is_searcher_agent_action"):
|
||||
signature = inspect.signature(attr)
|
||||
docstring = inspect.getdoc(attr) or "No description available."
|
||||
|
||||
procedural_memory += textwrap.dedent(f"""
|
||||
def {tool_name}{signature}:
|
||||
'''{docstring}'''
|
||||
""")
|
||||
|
||||
procedural_memory += textwrap.dedent(
|
||||
"""
|
||||
# RESPONSE FORMAT
|
||||
Your response must follow this exact format:
|
||||
|
||||
(Previous action verification)
|
||||
Carefully analyze the screenshot to verify if your last action was successful. If it failed, explain why.
|
||||
|
||||
(Screenshot Analysis)
|
||||
Examine the current state of the Chrome browser. Describe the current webpage, any open tabs, and visible UI elements relevant to your search.
|
||||
|
||||
(Next Action)
|
||||
In natural language, decide the next logical step to find the tutorial. This could be refining your search query, clicking a link, scrolling down, or saving a note.
|
||||
|
||||
(Grounded Action)
|
||||
Translate your "Next Action" into a single line of Python code using the `agent` methods provided above.
|
||||
```python
|
||||
agent.type(element_description="the search bar at the top of the Google page", text="how to create a pivot table in excel", enter=True)
|
||||
```
|
||||
|
||||
Note for the grounded action:
|
||||
1. Only perform one action at a time.
|
||||
2. You must use only the available methods provided above. Do not invent new methods.
|
||||
3. Return with `agent.done()` immediately after you have compiled the complete tutorial, or `agent.fail()` if it cannot be completed.
|
||||
4. Prefer hotkeys (`agent.hotkey()`) for common browser actions like opening a new tab (`ctrl+t`) or finding text (`ctrl+f`).
|
||||
5. Generate `agent.fail()` if you are exhaustively stuck and believe the task is impossible.
|
||||
6. Generate `agent.done()` when you believe the task is fully complete and you have a high-quality tutorial.
|
||||
"""
|
||||
)
|
||||
|
||||
return procedural_memory
|
||||
|
||||
@staticmethod
|
||||
def construct_searcher_eager_mode_procedural_memory(
|
||||
agent_class: type
|
||||
):
|
||||
"""
|
||||
Constructs the procedural memory for a Searcher Agent in "Eager Mode" (final attempt).
|
||||
|
||||
This prompt is designed for the scenario where the agent has exhausted its step budget.
|
||||
It restricts the agent to only two possible actions: `done()` or `fail()`, forcing a final,
|
||||
decisive judgment based on the information gathered so far.
|
||||
"""
|
||||
# 1. Set the specific "last chance" introductory text.
|
||||
# This combines the urgency of the planner's eager mode with the Searcher's specific mission.
|
||||
procedural_memory = textwrap.dedent(
|
||||
f"""
|
||||
You are a Searcher Agent, a specialized expert in graphical user interfaces. Your operational budget is now EXHAUSTED.
|
||||
This is your FINAL opportunity to act. You must make a definitive judgment on the task: `QUERY`.
|
||||
You are working in CURRENT_OS.
|
||||
|
||||
# GUIDELINES
|
||||
|
||||
## Final Judgment Mode
|
||||
1. **Analyze Your Notes:** Carefully review all the information you have gathered using `save_to_tutorial_notes`.
|
||||
2. **Make a Final Decision:** Based on your notes, decide if you have enough high-quality information to construct a complete and reliable step-by-step tutorial.
|
||||
3. **Choose One of Two Actions:** You can ONLY use `agent.done()` or `agent.fail()`. No other actions are permitted.
|
||||
|
||||
- **If you choose `agent.done()`:** You MUST provide the complete, well-structured tutorial in the `tutorial` parameter. Compile all your useful notes into a final guide. Do NOT use `done` unless you are highly confident in the tutorial's accuracy and completeness.
|
||||
- **If you choose `agent.fail()`:** Use this if you could not find enough information, or if the information you found is contradictory, unreliable, or incomplete. Provide a reason in the `hint` parameter.
|
||||
|
||||
**You are provided with**:
|
||||
1. A screenshot of the current time step.
|
||||
2. The history of your previous interactions with the UI.
|
||||
3. Tutorial notes you have already found.
|
||||
--- TUTORIAL NOTES START ---
|
||||
TUTORIAL_PLACEHOLDER
|
||||
--- TUTORIAL NOTES END ---
|
||||
4. Access to the following class and methods to interact with the UI. You must only use these two actions.
|
||||
class Agent:
|
||||
"""
|
||||
)
|
||||
|
||||
# 2. Strictly inject only the 'done' and 'fail' methods.
|
||||
# This logic is adapted from the planner's eager mode constructor.
|
||||
eager_tools = ["done", "fail"]
|
||||
for tool_name in eager_tools:
|
||||
attr = getattr(agent_class, tool_name, None)
|
||||
|
||||
# We check for 'is_searcher_agent_action' to be consistent with the SearcherAgent's decorators.
|
||||
if attr and callable(attr) and hasattr(attr, "is_searcher_agent_action"):
|
||||
signature = inspect.signature(attr)
|
||||
docstring = inspect.getdoc(attr) or "No description available."
|
||||
procedural_memory += textwrap.dedent(f"""
|
||||
def {tool_name}{signature}:
|
||||
'''{docstring}'''
|
||||
""")
|
||||
|
||||
# 3. Provide the specific response format for this final decision.
|
||||
procedural_memory += textwrap.dedent(
|
||||
"""
|
||||
# RESPONSE FORMAT
|
||||
Your response must follow this exact format:
|
||||
|
||||
(Final Analysis and Tutorial Compilation)
|
||||
Review your collected notes and the final screenshot. State whether you have sufficient information to create a definitive tutorial. Summarize your reasoning.
|
||||
|
||||
(Final Decision)
|
||||
In natural language, declare your final choice. For example: "The search is successful, and I have compiled a complete tutorial." or "The search has failed because no reliable sources were found for this specific software version."
|
||||
|
||||
(Grounded Action)
|
||||
Translate your final decision into a single line of Python code using the `agent` methods provided above.
|
||||
**Example**:
|
||||
```python
|
||||
agent.done(tutorial="xxxx")
|
||||
```
|
||||
```python
|
||||
agent.fail(hint="xxxx")
|
||||
```
|
||||
**CRITICAL**: You MUST choose one of the following two actions. No other actions are allowed.
|
||||
"""
|
||||
)
|
||||
|
||||
return procedural_memory.strip()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def construct_grounder_procedural_memory(model_name: str):
|
||||
system_prompt, user_message = None, "Query:REF_EXPR\nOutput only the coordinate of one point in your response.\n"
|
||||
if "scalecua" in model_name.lower():
|
||||
user_message = "REF_EXPR"
|
||||
system_prompt = textwrap.dedent(
|
||||
'''
|
||||
You are an autonomous GUI agent capable of operating on desktops, mobile devices, and web browsers. Your primary function is to analyze screen captures and perform appropriate UI actions to complete assigned tasks.
|
||||
|
||||
## Action Space
|
||||
def click(
|
||||
x: float | None = None,
|
||||
y: float | None = None,
|
||||
clicks: int = 1,
|
||||
button: str = "left",
|
||||
) -> None:
|
||||
"""Clicks on the screen at the specified coordinates. The `x` and `y` parameter specify where the mouse event occurs. If not provided, the current mouse position is used. The `clicks` parameter specifies how many times to click, and the `button` parameter specifies which mouse button to use ('left', 'right', or 'middle')."""
|
||||
pass
|
||||
|
||||
def doubleClick(
|
||||
x: float | None = None,
|
||||
y: float | None = None,
|
||||
button: str = "left",
|
||||
) -> None:
|
||||
"""Performs a double click. This is a wrapper function for click(x, y, 2, 'left')."""
|
||||
pass
|
||||
|
||||
def rightClick(x: float | None = None, y: float | None = None) -> None:
|
||||
"""Performs a right mouse button click. This is a wrapper function for click(x, y, 1, 'right')."""
|
||||
pass
|
||||
|
||||
def moveTo(x: float, y: float) -> None:
|
||||
"""Move the mouse to the specified coordinates."""
|
||||
pass
|
||||
|
||||
def dragTo(
|
||||
x: float | None = None, y: float | None = None, button: str = "left"
|
||||
) -> None:
|
||||
"""Performs a drag-to action with optional `x` and `y` coordinates and button."""
|
||||
pass
|
||||
|
||||
def swipe(
|
||||
from_coord: tuple[float, float] | None = None,
|
||||
to_coord: tuple[float, float] | None = None,
|
||||
direction: str = "up",
|
||||
amount: float = 0.5,
|
||||
) -> None:
|
||||
"""Performs a swipe action on the screen. The `from_coord` and `to_coord` specify the starting and ending coordinates of the swipe. If `to_coord` is not provided, the `direction` and `amount` parameters are used to determine the swipe direction and distance. The `direction` can be 'up', 'down', 'left', or 'right', and the `amount` specifies how far to swipe relative to the screen size (0 to 1)."""
|
||||
pass
|
||||
|
||||
def long_press(x: float, y: float, duration: int = 1) -> None:
|
||||
"""Long press on the screen at the specified coordinates. The `duration` specifies how long to hold the press in seconds."""
|
||||
pass
|
||||
|
||||
## Input Specification
|
||||
- Screenshot of the current screen + task description
|
||||
|
||||
## Output Format
|
||||
<action>
|
||||
[A set of executable action commands]
|
||||
</action>
|
||||
|
||||
## Note
|
||||
- Avoid action(s) that would lead to invalid states.
|
||||
- The generated action(s) must exist within the defined action space.
|
||||
- The generated action(s) should be enclosed within <action></action> tags.'''
|
||||
)
|
||||
return system_prompt, user_message
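# Usage sketch: for non-ScaleCUA grounders the system prompt stays None and
# only the query template is used (REF_EXPR is substituted by the caller;
# "gpt-4o" is an illustrative model name):
# >>> sp, um = PROCEDURAL_MEMORY.construct_grounder_procedural_memory("gpt-4o")
# >>> sp is None and "REF_EXPR" in um
# True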
|
||||
|
|
@@ -0,0 +1,48 @@
|
|||
description: "The config of all tools"
|
||||
|
||||
tools:
|
||||
click:
|
||||
enabled: true
|
||||
|
||||
type:
|
||||
enabled: true
|
||||
|
||||
scroll:
|
||||
enabled: true
|
||||
|
||||
drag_and_drop:
|
||||
enabled: true
|
||||
|
||||
highlight_text_span:
|
||||
enabled: true
|
||||
|
||||
locate_cursor:
|
||||
enabled: true
|
||||
|
||||
call_code_agent:
|
||||
enabled: true
|
||||
|
||||
call_search_agent:
|
||||
enabled: true
|
||||
|
||||
hotkey:
|
||||
enabled: true
|
||||
|
||||
hold_and_press:
|
||||
enabled: true
|
||||
|
||||
wait:
|
||||
enabled: true
|
||||
|
||||
done:
|
||||
enabled: true
|
||||
|
||||
fail:
|
||||
enabled: true
|
||||
|
||||
open:
|
||||
enabled: true
|
||||
|
||||
|
|
@@ -0,0 +1,448 @@
|
|||
import json
|
||||
import re
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Tuple, Dict, List, Union
|
||||
import io
|
||||
import os
|
||||
from PIL import Image, ImageDraw
|
||||
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||
from mm_agents.os_symphony.utils.process_context import get_current_result_dir
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
def create_pyautogui_code(agent, code: str, obs: Dict) -> Tuple[str, dict | None]:
|
||||
"""
|
||||
Attempts to evaluate the code into a pyautogui code snippet with grounded actions using the observation screenshot.
|
||||
|
||||
Args:
|
||||
agent (ACI): The grounding agent to use for evaluation.
|
||||
code (str): The code string to evaluate.
|
||||
obs (Dict): The current observation containing the screenshot.
|
||||
|
||||
Returns:
|
||||
exec_code (str): The pyautogui code to execute the grounded action.
|
||||
coordinate (List): The coordinates of the action, a flat list such as [x1, y1, x2, y2, x3, y3, ...], because more than one coordinate may appear in a single action.
|
||||
Modified by Yang.
|
||||
Raises:
|
||||
Exception: If there is an error in evaluating the code.
|
||||
"""
|
||||
agent.assign_screenshot(obs) # Necessary for grounding
|
||||
response = eval(code)  # Evaluates e.g. 'agent.click(...)' against the grounding agent
|
||||
if isinstance(response, Tuple):
|
||||
return response
|
||||
elif isinstance(response, str):
|
||||
return response, None
|
||||
else:
|
||||
return "", None
|
||||
|
||||
|
||||
def draw_coordinates(image_bytes: bytes, coordinates: List[Union[int, float]], save_path: str):
|
||||
"""
|
||||
Draw coordinates on the given image and save it to a new file.
|
||||
|
||||
This function receives an image as a byte stream, a list of coordinates in the format [x1, y1, x2, y2, ...],
|
||||
and draws a red 'X' at each (x, y) coordinate point. The resulting image is then saved to the specified path.
|
||||
|
||||
Args:
|
||||
- image_bytes (bytes): The raw byte data of the image (e.g., read from a PNG or JPEG file).
|
||||
- coordinates (List[Union[int, float]]): A flattened list of coordinates, must contain an even number of elements. For example: [x1, y1, x2, y2].
|
||||
- save_path (str): The path where the new image with markings will be saved.
|
||||
"""
|
||||
try:
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
return
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
cross_size = 15
|
||||
cross_color = "red"
|
||||
cross_width = 3
|
||||
|
||||
for i in range(0, len(coordinates) - 1, 2):
|
||||
x, y = coordinates[i], coordinates[i+1]
|
||||
|
||||
line1_start = (x - cross_size, y - cross_size)
|
||||
line1_end = (x + cross_size, y + cross_size)
|
||||
|
||||
line2_start = (x + cross_size, y - cross_size)
|
||||
line2_end = (x - cross_size, y + cross_size)
|
||||
|
||||
draw.line([line1_start, line1_end], fill=cross_color, width=cross_width)
|
||||
draw.line([line2_start, line2_end], fill=cross_color, width=cross_width)
|
||||
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
image.save(save_path)
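# Example: mark two points on a synthetic in-memory screenshot.
# >>> from io import BytesIO
# >>> buf = BytesIO()
# >>> Image.new("RGB", (200, 100), "white").save(buf, format="PNG")
# >>> draw_coordinates(buf.getvalue(), [50, 40, 150, 60], "logs/marked/example.png")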
|
||||
|
||||
|
||||
def parse_action_from_string(string):
|
||||
'''
|
||||
Return the substring starting at the "(Next Action)" marker, including the marker itself. If the marker is not found, return the whole string.
|
||||
'''
|
||||
marker = "(Next Action)"
|
||||
|
||||
start_index = string.find(marker)
|
||||
|
||||
if start_index != -1:
|
||||
return string[start_index:]
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def call_llm_safe(
|
||||
agent, temperature: float = 0.0, use_thinking: bool = False, **kwargs
|
||||
) -> str:
|
||||
|
||||
try:
|
||||
example_result_dir = get_current_result_dir()
|
||||
except Exception:
|
||||
example_result_dir = "logs/tokens"
|
||||
# Retry if fails
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
while attempt < max_retries:
|
||||
try:
|
||||
response = agent.get_response(
|
||||
temperature=temperature, use_thinking=use_thinking, **kwargs
|
||||
)
|
||||
assert response is not None, "Response from agent should not be None"
|
||||
# print("Response success!")
|
||||
break # If successful, break out of the loop
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
print(f"{agent.engine} Attempt {attempt} failed: {e}")
|
||||
if attempt == max_retries:
|
||||
print("Max retries reached. Handling failure.")
|
||||
time.sleep(1.0)
|
||||
# record token cost
|
||||
if isinstance(response, tuple):
|
||||
response, usage = response
|
||||
agent_name = agent.agent_name
|
||||
with open(os.path.join(example_result_dir, "token.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps({
|
||||
"agent_name": agent_name,
|
||||
"completion_tokens": usage.completion_tokens,
|
||||
"prompt_tokens": usage.prompt_tokens,
|
||||
"total_tokens": usage.total_tokens
|
||||
}))
|
||||
f.write("\n")
|
||||
|
||||
return response if response is not None else ""
|
||||
|
||||
|
||||
def call_func_safe(
|
||||
func, **kwargs
|
||||
) -> str:
|
||||
# Retry if fails
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
while attempt < max_retries:
|
||||
try:
|
||||
response = func(**kwargs)
|
||||
break
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
print(f"Attempt {attempt} failed: {e}")
|
||||
if attempt == max_retries:
|
||||
print("Max retries reached. Handling failure.")
|
||||
time.sleep(1.0)
|
||||
|
||||
return response if response is not None else ""
|
||||
|
||||
|
||||
def extract_coords_from_action_dict(action_dict: Dict | None) -> List:
|
||||
coords = []
|
||||
coords_num = 0
|
||||
if action_dict:
|
||||
for k, v in action_dict["args"].items():
|
||||
if (k == "x" and v) or (k == "y" and v) or (k == "x1" and v) or (k == "x2" and v) or (k == "y1" and v) or (k == "y2" and v):
|
||||
coords_num += 1
|
||||
if coords_num == 2:
|
||||
coords.append(action_dict["args"]["x"])
|
||||
coords.append(action_dict["args"]["y"])
|
||||
if coords_num == 4:
|
||||
coords.append(action_dict["args"]["x1"])
|
||||
coords.append(action_dict["args"]["y1"])
|
||||
coords.append(action_dict["args"]["x2"])
|
||||
coords.append(action_dict["args"]["y2"])
|
||||
return coords
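# Example: a grounded click contributes a flat [x, y] pair; drag-style
# actions with x1/y1/x2/y2 contribute four values instead.
# >>> extract_coords_from_action_dict(
# ...     {"function": "click", "args": {"x": 312, "y": 455, "button": "left"}})
# [312, 455]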
|
||||
|
||||
|
||||
def call_llm_formatted(generator, format_checkers, **kwargs):
|
||||
"""
|
||||
Calls the generator agent's LLM and ensures correct formatting.
|
||||
|
||||
Args:
|
||||
generator (ACI): The generator agent to call.
|
||||
format_checkers (List[Callable]): Functions that take the response and return a tuple of (success, feedback).
|
||||
**kwargs: Additional keyword arguments for the LLM call.
|
||||
|
||||
Returns:
|
||||
response (str): The formatted response from the generator agent.
|
||||
"""
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
if kwargs.get("messages") is None:
|
||||
messages = (
|
||||
generator.messages.copy()
|
||||
) # Copy messages to avoid modifying the original
|
||||
else:
|
||||
messages = kwargs["messages"]
|
||||
del kwargs["messages"] # Remove messages from kwargs to avoid passing it twice
|
||||
while attempt < max_retries:
|
||||
response = call_llm_safe(generator, messages=messages, **kwargs)
|
||||
# Prepare feedback messages for incorrect formatting
|
||||
feedback_msgs = []
|
||||
for format_checker in format_checkers:
|
||||
success, feedback = format_checker(response)
|
||||
if not success:
|
||||
feedback_msgs.append(feedback)
|
||||
if not feedback_msgs:
|
||||
# logger.info(f"Response formatted correctly on attempt {attempt} for {generator.engine.model}")
|
||||
break
|
||||
logger.error(
|
||||
f"Response formatting error on attempt {attempt} for {generator.engine.model}. Response: {response} {', '.join(feedback_msgs)}"
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": response}],
|
||||
}
|
||||
)
|
||||
logger.info(f"Bad response: {response}")
|
||||
delimiter = "\n- "
|
||||
formatting_feedback = f"- {delimiter.join(feedback_msgs)}"
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": PROCEDURAL_MEMORY.FORMATTING_FEEDBACK_PROMPT.replace(
|
||||
"FORMATTING_FEEDBACK", formatting_feedback
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
logger.info("Feedback:\n%s", formatting_feedback)
|
||||
|
||||
attempt += 1
|
||||
if attempt == max_retries:
|
||||
logger.error(
|
||||
"Max retries reached when formatting response. Handling failure."
|
||||
)
|
||||
time.sleep(1.0)
|
||||
return response
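# Usage sketch (reflector_agent is a stand-in for any wrapper exposing
# .messages and .get_response; THOUGHTS_ANSWER_TAG_FORMATTER is one of the
# format-check helpers defined later in this diff):
# >>> response = call_llm_formatted(
# ...     reflector_agent, [THOUGHTS_ANSWER_TAG_FORMATTER], temperature=0.0)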
|
||||
|
||||
|
||||
def split_thinking_response(full_response: str) -> Tuple[str, str]:
|
||||
try:
|
||||
# Extract thoughts section
|
||||
thoughts = full_response.split("<thoughts>")[-1].split("</thoughts>")[0].strip()
|
||||
|
||||
# Extract answer section
|
||||
answer = full_response.split("<answer>")[-1].split("</answer>")[0].strip()
|
||||
|
||||
return answer, thoughts
|
||||
except Exception as e:
|
||||
return full_response, ""
|
||||
|
||||
|
||||
def parse_code_from_string(input_string):
|
||||
"""Parses a string to extract each line of code enclosed in triple backticks (```)
|
||||
|
||||
Args:
|
||||
input_string (str): The input string containing code snippets.
|
||||
|
||||
Returns:
|
||||
str: The last code snippet found in the input string, or an empty string if no code is found.
|
||||
"""
|
||||
input_string = input_string.strip()
|
||||
|
||||
# This regular expression will match both ```code``` and ```python code```
|
||||
# and capture the `code` part. It uses a non-greedy match for the content inside.
|
||||
pattern = r"```(?:\w+\s+)?(.*?)```"
|
||||
# print(f'[parse_code_from_string].input_string: {input_string}')
|
||||
# Find all non-overlapping matches in the string
|
||||
matches = re.findall(pattern, input_string, re.DOTALL)
|
||||
if len(matches) == 0:
|
||||
# return []
|
||||
return ""
|
||||
relevant_code = matches[
|
||||
-1
|
||||
] # We only care about the last match given it is the grounded action
|
||||
# print(f'[parse_code_from_string].relevant_code: {relevant_code}')
|
||||
return relevant_code
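# Example: only the last fenced block is returned and the language tag is
# dropped (note the trailing newline from inside the fence is preserved):
# >>> parse_code_from_string('(Grounded Action)\n```python\nagent.wait(1)\n```')
# 'agent.wait(1)\n'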
|
||||
|
||||
|
||||
def extract_agent_functions(code):
|
||||
"""
|
||||
Extracts all agent function names from the given code.
|
||||
|
||||
Args:
|
||||
code (str): The code string to search.
|
||||
|
||||
Returns:
|
||||
list: A list of strings like ['agent.click', 'agent.type'].
|
||||
"""
|
||||
pattern = r"agent\.\w+"
|
||||
|
||||
return re.findall(pattern, code)
|
||||
|
||||
|
||||
def compress_image(image_bytes: bytes = None, image: Image.Image = None) -> bytes:
|
||||
"""Compresses an image represented as bytes.
|
||||
|
||||
Compression involves re-encoding the image in WEBP format (the image itself is not resized).
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): The image data to compress.
|
||||
|
||||
Returns:
|
||||
bytes: The compressed image data.
|
||||
"""
|
||||
if not image:
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
output = BytesIO()
|
||||
image.save(output, format="WEBP")
|
||||
compressed_image_bytes = output.getvalue()
|
||||
return compressed_image_bytes
|
||||
|
||||
import math
|
||||
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 100 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
|
||||
def smart_resize(
|
||||
height: int,
|
||||
width: int,
|
||||
factor: int = IMAGE_FACTOR,
|
||||
min_pixels: int = MIN_PIXELS,
|
||||
max_pixels: int = MAX_PIXELS,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
min_pixels = MIN_PIXELS if not min_pixels else min_pixels
|
||||
max_pixels = MAX_PIXELS if not max_pixels else max_pixels
|
||||
if max(height, width) / min(height, width) > MAX_RATIO:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
return h_bar, w_bar
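# Worked example: a 1920x1080 screenshot snaps to the nearest multiples of 28
# (1080 -> 39 * 28 = 1092, 1920 -> 69 * 28 = 1932); the resulting ~2.1M pixels
# already lie inside [MIN_PIXELS, MAX_PIXELS], so neither rescaling branch fires.
# >>> smart_resize(1080, 1920)
# (1092, 1932)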
|
||||
|
||||
|
||||
def enhance_observation(image_data: bytes, coordinates: List, expansion_pixels: int = 400, draw=True) -> Tuple[bytes, int, int, int, int]:
|
||||
"""
|
||||
According to the given coordinates, draw markers on the screenshot and crop a "focused" area.
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, int, int, int, int]:
|
||||
- new_image_data (bytes): Data of the cropped image
|
||||
- crop_left (int): X-axis offset
|
||||
- crop_top (int): Y-axis offset
|
||||
- new_width (int): Width of the cropped image
|
||||
- new_height (int): Height of the cropped image
|
||||
"""
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGBA")
|
||||
draw_ctx = ImageDraw.Draw(image)
|
||||
|
||||
img_width, img_height = image.size
|
||||
|
||||
X_MARKER_SIZE = 40
|
||||
X_MARKER_WIDTH = 5
|
||||
|
||||
def _draw_x(draw_context, center_x, center_y, size=X_MARKER_SIZE, color="red", width=X_MARKER_WIDTH):
|
||||
half_size = size // 2
|
||||
draw_context.line((center_x - half_size, center_y - half_size, center_x + half_size, center_y + half_size), fill=color, width=width)
|
||||
draw_context.line((center_x - half_size, center_y + half_size, center_x + half_size, center_y - half_size), fill=color, width=width)
|
||||
|
||||
crop_left, crop_top, crop_right, crop_bottom = 0, 0, img_width, img_height
|
||||
|
||||
if len(coordinates) == 2:
|
||||
x, y = coordinates[0], coordinates[1]
|
||||
if draw:
|
||||
_draw_x(draw_ctx, x, y)
|
||||
|
||||
crop_left = x - expansion_pixels
|
||||
crop_top = y - expansion_pixels
|
||||
crop_right = x + expansion_pixels
|
||||
crop_bottom = y + expansion_pixels
|
||||
|
||||
elif len(coordinates) >= 4:
|
||||
x1, y1 = coordinates[0], coordinates[1]
|
||||
x2, y2 = coordinates[2], coordinates[3]
|
||||
|
||||
if draw:
|
||||
_draw_x(draw_ctx, x1, y1, color="red")
|
||||
_draw_x(draw_ctx, x2, y2, color="blue")
|
||||
draw_ctx.line((x1, y1, x2, y2), fill="green", width=5)
|
||||
|
||||
box_left = min(x1, x2)
|
||||
box_top = min(y1, y2)
|
||||
box_right = max(x1, x2)
|
||||
box_bottom = max(y1, y2)
|
||||
|
||||
crop_left = box_left - expansion_pixels
|
||||
crop_top = box_top - expansion_pixels
|
||||
crop_right = box_right + expansion_pixels
|
||||
crop_bottom = box_bottom + expansion_pixels
|
||||
|
||||
# check boundary
|
||||
crop_left = max(0, int(crop_left))
|
||||
crop_top = max(0, int(crop_top))
|
||||
crop_right = min(img_width, int(crop_right))
|
||||
crop_bottom = min(img_height, int(crop_bottom))
|
||||
|
||||
crop_box = (crop_left, crop_top, crop_right, crop_bottom)
|
||||
cropped_image = image.crop(crop_box)
|
||||
|
||||
new_width, new_height = cropped_image.size
|
||||
|
||||
buffered = io.BytesIO()
|
||||
cropped_image.save(buffered, format="PNG")
|
||||
|
||||
return buffered.getvalue(), crop_left, crop_top, new_width, new_height
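# Example: focus on a click at (500, 300) in a synthetic 1920x1080 screenshot;
# the returned offsets map crop-local coordinates back to full-screen ones.
# >>> buf = io.BytesIO()
# >>> Image.new("RGB", (1920, 1080), "white").save(buf, format="PNG")
# >>> data, left, top, w, h = enhance_observation(buf.getvalue(), [500, 300])
# >>> (left, top, w, h)
# (100, 0, 800, 700)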
|
||||
|
||||
|
|
@@ -0,0 +1,106 @@
|
|||
"""This file contains various formatting checks used to reprompt an agent for correctly formatted responses."""
|
||||
from typing import List
|
||||
import json
|
||||
import yaml
|
||||
import re
|
||||
from mm_agents.os_symphony.utils.common_utils import (
|
||||
extract_agent_functions,
|
||||
parse_code_from_string,
|
||||
split_thinking_response,
|
||||
)
|
||||
|
||||
|
||||
single_action_check = (
|
||||
lambda response: len(extract_agent_functions(parse_code_from_string(response))) == 1
|
||||
)
|
||||
single_action_error_msg = (
|
||||
"Incorrect code: There must be a single agent action in the code response."
|
||||
)
|
||||
SINGLE_ACTION_FORMATTER = lambda response: (
|
||||
single_action_check(response),
|
||||
single_action_error_msg,
|
||||
)
|
||||
|
||||
|
||||
def code_valid_check(tool_config, response):
|
||||
code = parse_code_from_string(response)
|
||||
print(f'[code_valid_check] parsed code is: {code}')
|
||||
|
||||
# check if the action is pre-defined
|
||||
with open(tool_config, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
valid_methods = set(config['tools'].keys())
|
||||
|
||||
pattern = r"^agent\.(\w+)\(.*\)$"
|
||||
|
||||
match = re.match(pattern, code.strip(), re.DOTALL)
|
||||
|
||||
if match:
|
||||
method_name = match.group(1)
|
||||
print(f'[code_valid_check]: method is {method_name}')
|
||||
if method_name in valid_methods:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
code_valid_error_msg = "Incorrect code: The agent action must be a SINGLE and VALID function and use valid parameters from the docstring list."
|
||||
CODE_VALID_FORMATTER = lambda tool_config, response: (
|
||||
code_valid_check(tool_config, response),
|
||||
code_valid_error_msg,
|
||||
)
|
||||
|
||||
thoughts_answer_tag_check = lambda response: all(tag in response for tag in ("<thoughts>", "</thoughts>", "<answer>", "</answer>"))
|
||||
thoughts_answer_tag_error_msg = "Incorrect response: The response must contain both <thoughts>...</thoughts> and <answer>...</answer> tags."
|
||||
THOUGHTS_ANSWER_TAG_FORMATTER = lambda response: (
|
||||
thoughts_answer_tag_check(response),
|
||||
thoughts_answer_tag_error_msg,
|
||||
)
|
||||
|
||||
integer_answer_check = (
|
||||
lambda response: split_thinking_response(response)[0].strip().isdigit()
|
||||
)
|
||||
integer_answer_error_msg = (
|
||||
"Incorrect response: The <answer>...</answer> tag must contain a single integer."
|
||||
)
|
||||
INTEGER_ANSWER_FORMATTER = lambda response: (
|
||||
integer_answer_check(response),
|
||||
integer_answer_error_msg,
|
||||
)
|
||||
|
||||
|
||||
def json_answer_check(response: str, required_fields: List[str]) -> bool:
|
||||
"""
|
||||
A check function that returns only True/False.
|
||||
"""
|
||||
try:
|
||||
answer_str = parse_code_from_string(response)
|
||||
|
||||
if len(answer_str) == 0:
|
||||
return False
|
||||
|
||||
data = json.loads(answer_str)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
|
||||
if set(required_fields) - set(data.keys()):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
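# Example: the check passes only when the fenced JSON parses to a dict
# containing every required key.
# >>> ok = '```json\n{"summary": "Clicked OK.", "evaluation": "successful"}\n```'
# >>> json_answer_check(ok, ["summary", "evaluation"])
# True
# >>> json_answer_check('```json\n{"summary": "Clicked OK."}\n```', ["summary", "evaluation"])
# False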
|
||||
|
||||
|
||||
json_answer_error_msg = (
|
||||
"Incorrect response: The (Answer) part must contain a valid JSON object that includes ALL required keys and need to be wrapped by ```json and ```"
|
||||
)
|
||||
|
||||
|
||||
JSON_ANSWER_FORMATTER = lambda response, required_fields: (
|
||||
json_answer_check(response, required_fields),
|
||||
json_answer_error_msg,
|
||||
)
|
||||
|
|
@@ -0,0 +1,216 @@
|
|||
import io
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Dict, Any, Optional
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from rapidfuzz import fuzz
|
||||
import logging
|
||||
from mm_agents.os_symphony.agents.memoryer_agent import StepBehavior
|
||||
|
||||
logger = logging.getLogger("desktopenv.loop_detection")
|
||||
|
||||
def _are_actions_similar(
|
||||
action1: Dict[str, Any],
|
||||
action2: Dict[str, Any],
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
relative_coord_threshold: float,
|
||||
fuzzy_text_threshold: float,
|
||||
) -> bool:
|
||||
"""
|
||||
[Internal Auxiliary] Determine if two actions are similar based on detailed rules.
|
||||
|
||||
Args:
|
||||
action1: The first action.
|
||||
action2: The second action.
|
||||
image_width: The width of the screenshot.
|
||||
image_height: The height of the screenshot.
|
||||
relative_coord_threshold: A relative distance threshold for coordinate comparison.
|
||||
fuzzy_text_threshold: A similarity threshold (0-100) for fuzzy text matching.
|
||||
|
||||
Returns:
|
||||
Return True if the actions are similar, otherwise return False.
|
||||
"""
|
||||
# ensure same action
|
||||
if action1.get("function") != action2.get("function"):
|
||||
return False
|
||||
|
||||
func = action1.get("function")
|
||||
args1 = action1.get("args", {})
|
||||
args2 = action2.get("args", {})
|
||||
|
||||
diagonal = math.sqrt(image_width**2 + image_height**2)
|
||||
abs_coord_thresh = relative_coord_threshold * diagonal
|
||||
|
||||
def are_coords_close(x1, y1, x2, y2):
|
||||
if None in [x1, y1, x2, y2]: return False
|
||||
distance = math.sqrt((x1 - x2)**2 + (y1 - y2)**2)
|
||||
return distance < abs_coord_thresh
|
||||
|
||||
if func == "click":
|
||||
return (
|
||||
are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
|
||||
args1.get("button") == args2.get("button") and
|
||||
args1.get("clicks") == args2.get("clicks")
|
||||
)
|
||||
|
||||
elif func == "open":
|
||||
return args1.get("name") == args2.get("name")
|
||||
|
||||
elif func == "type":
|
||||
if args1.get("x") and args1.get("y") and args2.get("x") and args2.get("y"):
|
||||
return (
|
||||
are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
|
||||
args1.get("text") == args2.get("text")
|
||||
)
|
||||
else:
|
||||
return args1.get("text") == args2.get("text")
|
||||
|
||||
elif func == "drag":
|
||||
return (
|
||||
are_coords_close(args1.get("x1"), args1.get("y1"), args2.get("x1"), args2.get("y1")) and
|
||||
are_coords_close(args1.get("x2"), args1.get("y2"), args2.get("x2"), args2.get("y2"))
|
||||
)
|
||||
|
||||
elif func == "set_cell_values":
|
||||
return args1.get("text") == args2.get("text")
|
||||
|
||||
elif func == "scroll":
|
||||
clicks1 = args1.get("clicks", 0)
|
||||
clicks2 = args2.get("clicks", 0)
|
||||
if (clicks1 == 0 and clicks2 != 0) or (clicks1 != 0 and clicks2 == 0):
|
||||
same_direction = False
|
||||
else:
|
||||
same_direction = math.copysign(1, clicks1) == math.copysign(1, clicks2)
|
||||
|
||||
return (
|
||||
are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
|
||||
same_direction and
|
||||
args1.get("shift") == args2.get("shift")
|
||||
)
|
||||
|
||||
elif func == "key":
|
||||
return args1.get("keys") == args2.get("keys")
|
||||
|
||||
elif func == "wait":
|
||||
return True
|
||||
|
||||
elif func in ["call_code_agent", "call_search_agent"]:
|
||||
query1 = args1.get("query", "")
|
||||
query2 = args2.get("query", "")
|
||||
# use Levenshtein distance to calculate fuzzy similarity
|
||||
query_similarity = fuzz.token_set_ratio(query1, query2)
|
||||
# print(f'query_sim: {query_similarity}')
|
||||
return (
|
||||
query_similarity >= fuzzy_text_threshold and
|
||||
args1.get("result") == args2.get("result")
|
||||
)
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _are_steps_similar_optimized(
|
||||
step1: StepBehavior,
|
||||
step2: StepBehavior,
|
||||
idx1: int,
|
||||
idx2: int,
|
||||
full_trajectory: List[StepBehavior],
|
||||
phash_threshold: int,
|
||||
ssim_threshold: float,
|
||||
# 动作比较所需的参数
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
relative_coord_threshold: float,
|
||||
fuzzy_text_threshold: float,
|
||||
) -> bool:
|
||||
"""
|
||||
[Internal Auxiliary] use pre-calculated data to quickly determine if the two actions are similar/
|
||||
"""
|
||||
|
||||
if step1.phash is None or step2.phash is None:
|
||||
return False
|
||||
|
||||
if (step1.phash - step2.phash) > phash_threshold:
|
||||
return False
|
||||
|
||||
|
||||
later_step_idx = max(idx1, idx2)
|
||||
earlier_step_idx = min(idx1, idx2)
|
||||
|
||||
ssim_score = full_trajectory[later_step_idx].ssim_list[earlier_step_idx]
|
||||
|
||||
if ssim_score < ssim_threshold:
|
||||
return False
|
||||
|
||||
if not _are_actions_similar(
|
||||
step1.action_dict, step2.action_dict,
|
||||
image_width, image_height, relative_coord_threshold, fuzzy_text_threshold
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def detect_loop(
|
||||
full_trajectory: List[StepBehavior],
|
||||
image_width: int = 1920,
|
||||
image_height: int = 1080,
|
||||
N: int = 3,
|
||||
phash_threshold: int = 1,
|
||||
ssim_threshold: float = 0.99,
|
||||
relative_coord_threshold: float = 0.02,
|
||||
fuzzy_text_threshold: float = 85.0,
|
||||
) -> Tuple[bool, Optional[Dict[str, List[int]]]]:
|
||||
"""
|
||||
Efficiently detect the presence of looping patterns based on precomputed data.
|
||||
|
||||
Args:
|
||||
full_trajectory (List[StepBehavior]): Full history including the current step.
|
||||
image_width (int): Width of the screenshot.
|
||||
image_height (int): Height of the screenshot.
|
||||
N (int): Number of steps in the candidate loop (sequence length).
|
||||
phash_threshold (int): Hamming distance threshold for pHash similarity. Recommended: 0–2.
|
||||
ssim_threshold (float): SSIM similarity threshold for image comparison. Recommended: 0.95–0.99.
|
||||
relative_coord_threshold (float): Relative threshold for coordinate similarity. Recommended: 0.01–0.05.
|
||||
fuzzy_text_threshold (float): Fuzzy text matching similarity threshold (0–100) for agent queries.
|
||||
|
||||
Returns:
|
||||
A tuple (is_loop_detected, loop_info):
|
||||
- is_loop_detected (bool): Whether a loop is detected.
|
||||
- loop_info (Dict | None): If a loop is detected, contains the indices of the two matching sequences.
|
||||
"""
|
||||
L = len(full_trajectory)
|
||||
|
||||
if not isinstance(N, int) or N <= 0 or L < 2 * N:
|
||||
return False, None
|
||||
|
||||
max_start_index = L - 2 * N
|
||||
for i in range(max_start_index, -1, -1):
|
||||
is_potential_match = True
|
||||
|
||||
for j in range(N):
|
||||
idx_prev = i + j
|
||||
idx_curr = (L - N) + j
|
||||
|
||||
step_prev = full_trajectory[idx_prev]
|
||||
step_curr = full_trajectory[idx_curr]
|
||||
|
||||
if not _are_steps_similar_optimized(
|
||||
step_prev, step_curr, idx_prev, idx_curr, full_trajectory,
|
||||
phash_threshold, ssim_threshold,
|
||||
image_width, image_height, relative_coord_threshold, fuzzy_text_threshold
|
||||
):
|
||||
is_potential_match = False
|
||||
break
|
||||
|
||||
if is_potential_match:
|
||||
previous_sequence_indices = list(range(i, i + N))
|
||||
loop_info = {
|
||||
"match_sequence_indices": previous_sequence_indices
|
||||
}
|
||||
return True, loop_info
|
||||
|
||||
return False, None
|
||||
|
||||
|
|
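# Usage sketch for detect_loop. The FakeStep stub below is hypothetical and only
# mimics the StepBehavior fields used above (phash, ssim_list, action_dict); a
# plain int stands in for a perceptual hash, since only subtraction is needed.
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class FakeStep:
    phash: int
    ssim_list: List[float]  # SSIM against every earlier step in the trajectory
    action_dict: Dict[str, Any]

_action = {"function": "click", "args": {"x": 100, "y": 200, "button": "left", "clicks": 1}}
_traj = [FakeStep(phash=0, ssim_list=[1.0] * i, action_dict=_action) for i in range(6)]

found, info = detect_loop(_traj, N=3)
print(found, info)  # True {'match_sequence_indices': [0, 1, 2]}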
@ -0,0 +1,30 @@
# process_context.py
# This module provides an independent context storage for each process.

from multiprocessing import current_process

# We will store process-specific contexts here.
# Since each process has its own separate memory space, when accessing this variable,
# each process accesses its own copy, without conflicting with others.
_context_storage = {}

def set_context(key, value):
    """Set a value in the context of the current process."""
    _context_storage[key] = value
    # print(f"[{current_process().name}] Set context: {key} = {value}")  # For debugging

def get_context(key, default=None):
    """Retrieve a value from the context of the current process."""
    value = _context_storage.get(key, default)
    # print(f"[{current_process().name}] Get context: {key} -> {value}")  # For debugging
    if value is None and default is None:
        raise NameError(f"'{key}' not found in the current process context. Ensure it is set at the process entry point.")
    return value

# For convenience, we can create a specialized getter for result_dir
def get_current_result_dir():
    """Get the result_dir specific to the current process."""
    return get_context('current_result_dir')

def set_current_result_dir(example_result_dir):
    set_context("current_result_dir", example_result_dir)
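# Usage sketch (assumes this file is importable as process_context; the worker
# function below is hypothetical):
#
#   from multiprocessing import Process
#   import process_context
#
#   def worker(result_dir):
#       # Set once at the process entry point...
#       process_context.set_current_result_dir(result_dir)
#       # ...then any code running in this process can read it back without
#       # threading it through every function signature:
#       print(process_context.get_current_result_dir())
#
#   Process(target=worker, args=("./results/env-1",)).start()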
@ -6,6 +6,9 @@ import os
from io import BytesIO
from typing import Dict, List, Tuple

from http import HTTPStatus
import dashscope
from dashscope import MultiModalConversation
import backoff
import openai
from PIL import Image
@ -40,7 +43,7 @@ def process_image(image_bytes)
        height=height,
        width=width,
        factor=32,
        max_pixels=16 * 16 * 4 * 1280,
        max_pixels=16 * 16 * 4 * 12800,
    )

    image = image.resize((resized_width, resized_height))
@ -58,7 +61,7 @@ class Qwen3VLAgent:
        self,
        platform: str = "ubuntu",
        model: str = "qwen3-vl",
        max_tokens: int = 1500,
        max_tokens: int = 32768,
        top_p: float = 0.9,
        temperature: float = 0.0,
        action_space: str = "pyautogui",
@ -66,6 +69,9 @@ class Qwen3VLAgent:
        history_n: int = 4,
        add_thought_prefix: bool = False,
        coordinate_type: str = "relative",
        api_backend: str = "dashscope",  # "openai" or "dashscope"
        enable_thinking: bool = False,  # Enable thinking mode for DashScope
        thinking_budget: int = 32768,  # Token budget for reasoning
    ):
        self.platform = platform
        self.model = model
@ -77,9 +83,13 @@ class Qwen3VLAgent:
        self.history_n = history_n
        self.add_thought_prefix = add_thought_prefix
        self.coordinate_type = coordinate_type
        self.api_backend = api_backend
        self.enable_thinking = enable_thinking
        self.thinking_budget = thinking_budget

        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"
        assert api_backend in ["openai", "dashscope"], "Invalid API backend, must be 'openai' or 'dashscope'"

        self.thoughts = []
        self.actions = []
@ -527,6 +537,70 @@ Previous actions:

        return low_level_instruction, pyautogui_code

    @staticmethod
    def _to_dashscope_messages(messages):
        """
        Convert messages built for OpenAI compat into DashScope MultiModalConversation format.
        - "text" part -> {"text": "..."}
        - "image_url" -> {"image": "<url-or-data-uri>"}
        - "video_url" -> {"video": "<url-or-data-uri>"}
        """
        ds_msgs = []
        for m in messages:
            role = m.get("role", "")
            parts = m.get("content", [])
            ds_content = []
            for p in parts:
                # Guard: a raw string part has no .get(); fall through to the str branch below.
                ptype = p.get("type") if isinstance(p, dict) else None
                if ptype == "text":
                    ds_content.append({"text": p.get("text", "")})
                elif ptype == "image_url":
                    url = (p.get("image_url") or {}).get("url", "")
                    # DashScope accepts http(s), file://, or data:image/*; keep as-is
                    ds_content.append({"image": url})
                elif ptype == "video_url":
                    url = (p.get("video_url") or {}).get("url", "")
                    ds_content.append({"video": url})
                else:
                    # If raw assistant strings (no parts) are ever passed, tolerate them
                    if isinstance(p, str):
                        ds_content.append({"text": p})
            # Also tolerate plain-string content (rare)
            if not ds_content and isinstance(m.get("content"), str):
                ds_content = [{"text": m["content"]}]
            ds_msgs.append({"role": role, "content": ds_content})
        return ds_msgs

    @staticmethod
    def _extract_text_from_dashscope_response(resp):
        """Join all 'text' parts from the first choice, including reasoning if present."""
        if hasattr(resp, "output"):
            out = resp.output
        else:
            out = resp.get("output") if isinstance(resp, dict) else None
        if not out:
            return None
        choices = getattr(out, "choices", None) if not isinstance(out, dict) else out.get("choices")
        if not choices:
            return None
        msg = getattr(choices[0], "message", None) if not isinstance(choices[0], dict) else choices[0].get("message")
        if not msg:
            return None
        content = getattr(msg, "content", None) if not isinstance(msg, dict) else msg.get("content", [])
        if not content:
            return None

        # Extract reasoning content if present (for thinking models)
        reasoning_content = getattr(msg, "reasoning_content", None) if not isinstance(msg, dict) else msg.get("reasoning_content", None)

        content_text = "".join(part.get("text", "") for part in content if isinstance(part, dict) and "text" in part)

        # Format with thinking tags if reasoning exists
        if reasoning_content is not None:
            return f"<think>\n{reasoning_content}\n</think>\n\n{content_text}"
        else:
            return content_text

    @backoff.on_exception(
        backoff.constant,
        (
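    # Shape sketch for _to_dashscope_messages (input follows the OpenAI
    # chat-completions content-parts convention; the data URI is a placeholder):
    #
    #   openai_style = [{"role": "user", "content": [
    #       {"type": "text", "text": "What is on the screen?"},
    #       {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    #   ]}]
    #   Qwen3VLAgent._to_dashscope_messages(openai_style)
    #   # -> [{"role": "user", "content": [
    #   #        {"text": "What is on the screen?"},
    #   #        {"image": "data:image/png;base64,..."}]}]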
@ -545,25 +619,93 @@ Previous actions:
    def call_llm(self, payload, model):
        messages = payload["messages"]

        base_url = "https://poc-dashscope.aliyuncs.com/compatible-mode/v1"
        api_key = "sk-123"
        if self.api_backend == "openai":
            return self._call_llm_openai(messages, model)
        elif self.api_backend == "dashscope":
            return self._call_llm_dashscope(messages, model)
        else:
            raise ValueError(f"Unknown API backend: {self.api_backend}")

    def _call_llm_openai(self, messages, model):
        """Call LLM using OpenAI SDK (compatible with OpenAI-compatible endpoints)."""
        base_url = os.environ.get("OPENAI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        api_key = os.environ.get("OPENAI_API_KEY", "sk-123")
        client = openai.OpenAI(base_url=base_url, api_key=api_key)

        for _ in range(MAX_RETRY_TIMES):
            logger.info("Generating content with Qwen model: %s", model)
        for attempt in range(1, MAX_RETRY_TIMES + 1):
            logger.info(f"[OpenAI] Generating content with model: {model} (attempt {attempt}/{MAX_RETRY_TIMES})")
            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    # temperature=self.temperature,
                    # top_p=self.top_p,
                )
                return response.choices[0].message.content
            except Exception as e:
                logger.error(f"Error calling Qwen model: {e}")
                time.sleep(5)
                continue
                logger.error(f"[OpenAI] Error calling model: {e}")
                if attempt < MAX_RETRY_TIMES:
                    time.sleep(5)
                    continue
                break
        return ""

    def _call_llm_dashscope(self, messages, model):
        """Call LLM using DashScope SDK."""
        dashscope.base_http_api_url = os.environ.get("DASHSCOPE_BASE_URL", "https://dashscope.aliyuncs.com/api/v1")
        dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY", "sk-123")

        # Convert message schema
        ds_messages = self._to_dashscope_messages(messages)

        # Retry loop
        last_err = None
        for attempt in range(1, MAX_RETRY_TIMES + 1):
            thinking_status = f" (thinking={self.enable_thinking})" if self.enable_thinking else ""
            logger.info(f"[DashScope] Generating content with model: {model}, thinking_status: {thinking_status} (attempt {attempt}/{MAX_RETRY_TIMES})")
            try:
                # Build API call parameters
                call_params = {
                    "model": model,
                    "messages": ds_messages,
                    "max_tokens": self.max_tokens,
                    # "temperature": self.temperature,
                    # "top_p": self.top_p,
                    "vl_high_resolution_images": True,
                }

                # Add thinking parameters if enabled
                if self.enable_thinking:
                    call_params["enable_thinking"] = True
                    call_params["thinking_budget"] = self.thinking_budget

                resp = MultiModalConversation.call(**call_params)

                if getattr(resp, "status_code", None) not in (None, HTTPStatus.OK):
                    code = getattr(resp, "code", "")
                    msg = getattr(resp, "message", "")
                    reqid = getattr(resp, "request_id", "")
                    logger.warning(f"[DashScope] non-OK response (id={reqid}): {code} {msg}")
                    last_err = RuntimeError(f"DashScope status {resp.status_code}: {code} {msg}")
                    time.sleep(1.5 * attempt)
                    continue

                text = self._extract_text_from_dashscope_response(resp)
                if not text:
                    raise ValueError("DashScope response has no text content")
                return text

            except Exception as e:
                last_err = e
                logger.error(f"[DashScope] call failed: {e}")
                if attempt < MAX_RETRY_TIMES:
                    time.sleep(1.5 * attempt)
                    continue
                break

        if last_err:
            raise last_err
        return ""

    def reset(self, _logger=None):
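    # Configuration sketch for the two backends above (keys are placeholders;
    # these are the environment variables read by _call_llm_openai and
    # _call_llm_dashscope, with the defaults shown in the code):
    #
    #   import os
    #   os.environ["OPENAI_BASE_URL"] = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    #   os.environ["OPENAI_API_KEY"] = "sk-..."      # used when api_backend="openai"
    #   os.environ["DASHSCOPE_API_KEY"] = "sk-..."   # used when api_backend="dashscope"
    #
    #   agent = Qwen3VLAgent(api_backend="dashscope", enable_thinking=True)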
File diff suppressed because one or more lines are too long

monitor/.env
@ -2,13 +2,13 @@
# Do not write any secret keys or sensitive information here.

# Monitor configuration
TASK_CONFIG_PATH=../evaluation_examples/test_all.json
TASK_CONFIG_PATH=../evaluation_examples/test_50_random_proportional.json
EXAMPLES_BASE_PATH=../evaluation_examples/examples
RESULTS_BASE_PATH=../results
# ACTION_SPACE=pyautogui
# OBSERVATION_TYPE=screenshot
# MODEL_NAME=computer-use-preview
# MAX_STEPS=150
FLASK_PORT=80
RESULTS_BASE_PATH=../results_hosted_gbox_50
ACTION_SPACE=pyautogui
OBSERVATION_TYPE=screenshot
MODEL_NAME=us.anthropic.claude-sonnet-4-5-20250929-v1:0
MAX_STEPS=15
FLASK_PORT=8080
FLASK_HOST=0.0.0.0
FLASK_DEBUG=false
@ -69,4 +69,6 @@ alibabacloud_ecs20140526
alibabacloud_tea_openapi
alibabacloud_tea_util
json_minify
json_repair
json_repair
volcengine-python-sdk[ark]
ui-tars>=0.4.2.2
run.py
@ -218,7 +218,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
            f.write("\n")

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")
    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
@ -457,7 +457,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
            f.write("\n")

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")
    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
@ -485,7 +485,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
            f.write("\n")

    env.close()
    logger.info(f"Average score: {sum(scores) / len(scores)}")
    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
@ -0,0 +1,18 @@
# export HF_ENDPOINT=https://hf-mirror.com
python run_multienv_dart_gui.py \
    --dart_base_url http://0.0.0.0:6006/v1 \
    --provider_name docker \
    --test_all_meta_path evaluation_examples/test_nogdrive.json \
    --path_to_vm docker_vm_data/Ubuntu.qcow2 \
    --headless \
    --max_steps 30 \
    --domain all \
    --num_envs 2 \
    --log_level INFO \
    --temperature 1.0 \
    --save_complete_trajectory \
    --use_enhanced_runner \
    --model dart-gui \
    --model_type qwen25vl \
    --infer_mode dart_mode \
    --result_dir ./result_multi_apps_pengxiang_transformers12 | tee run_20251103_multi_apps_pengxiang_transformers12.log
@ -0,0 +1,542 @@
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager, Queue  # Queue is used in the run_env_tasks annotation
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.agi_agent import AGIAgent

# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False

# import wandb

# load the environment variables from .env file
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()

# Logger Configs {{{ #
def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # agent config
    parser.add_argument("--max_trajectory_length", type=int, default=3)
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # example config
    parser.add_argument("--domain", type=str, nargs='+', default=["all"],
                        help="Domain(s) to run. Use 'all' for all domains, or specify one or more domain names")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help="Set the logging level")
    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    parser.add_argument(
        "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
    )
    parser.add_argument(
        "--client_password", type=str, default="osworld-public-evaluation", help="Client password"
    )
    parser.add_argument(
        "--screen_width", type=int, default=1920, help="Screen width"
    )
    parser.add_argument(
        "--screen_height", type=int, default=1080, help="Screen height"
    )
    args = parser.parse_args()
    return args

args = config()  # Get command line arguments first

logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")


def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))
    return all_tasks


def process_signal_handler(signum, frame, env_idx):
    """Signal handler for child processes to gracefully shut down their environments."""
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")

    # Get the active_environments from the caller's frame
    local_vars = frame.f_locals
    active_environments = local_vars.get('active_environments', [])

    # Close environment in the current process context
    for env in active_environments:
        if env is not None:
            try:
                logger.info(f"Process {env_idx + 1} closing environment...")
                env.close()
                logger.info(f"Process {env_idx + 1} environment closed successfully")
            except Exception as e:
                logger.error(f"Process {env_idx + 1} error closing environment: {e}")

    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)


def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
    active_environments = []
    env = None
    try:
        from desktop_env.providers.aws.manager import IMAGE_ID_MAP
        REGION = args.region
        screen_size = (args.screen_width, args.screen_height)
        ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=args.action_space,
            provider_name=args.provider_name,
            region=REGION,
            snapshot_name=ami_id,
            screen_size=screen_size,
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            client_password=args.client_password
        )
        active_environments.append(env)
        agent = AGIAgent(
            env=env,
            # Contact the authors for access to a private deployment endpoint.
            server_url="https://your-private-agi-endpoint",
            action_space=args.action_space,
            observation_type=args.observation_type,
            max_trajectory_length=args.max_trajectory_length,
            client_password=args.client_password,
            provider_name=args.provider_name,
            screen_width=args.screen_width,
            screen_height=args.screen_height
        )
        logger.info(f"Process {current_process().name} started.")
        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                break
            domain, example_id = item
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    "agi-0",
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)
                try:
                    lib_run_single.run_single_example_agi(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        example["instruction"],
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback
                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())
                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")
                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(
                            json.dumps(
                                {"Error": f"{domain}/{example_id} - {e}"}
                            )
                        )
                        f.write("\n")
            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback
                logger.error(traceback.format_exc())
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
        except Exception as e:
            logger.error(f"{current_process().name} error during environment cleanup: {e}")


def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
    global is_terminating, active_environments, processes

    # Avoid duplicate handling
    if is_terminating:
        return

    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    # Close all registered environments in the main process
    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    # Send termination signal to all child processes first
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")

    # Allow a short time for processes to handle their own cleanup
    time.sleep(1)

    # Forcefully terminate any processes that didn't exit
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                import signal as sig
                os.kill(p.pid, sig.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")
    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"EnvProcess-{i+1}"
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
        try:
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"EnvProcess-Restart-{idx+1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        scores = list(shared_scores)
        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # Unfinished run: clear all files under example_id so it is retried
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        try:
                            all_result.append(
                                float(
                                    open(
                                        os.path.join(example_path, "result.txt"), "r"
                                    ).read()
                                )
                            )
                        except Exception:
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Register signal handlers for graceful termination
    signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal

    try:
        args = config()

        # save args to json in result_dir/action_space/observation_type/model/args.json
        path_to_args = os.path.join(
            args.result_dir,
            args.action_space,
            args.observation_type,
            "agi-0",
            "args.json",
        )
        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
        with open(path_to_args, "w", encoding="utf-8") as f:
            json.dump(vars(args), f, indent=4)

        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
            test_all_meta = json.load(f)

        # Handle multiple domains
        if "all" not in args.domain:
            # Filter test_all_meta to only include specified domains
            filtered_meta = {}
            for domain in args.domain:
                if domain in test_all_meta:
                    filtered_meta[domain] = test_all_meta[domain]
                else:
                    logger.warning(f"Domain '{domain}' not found in test_all_meta")
            test_all_meta = filtered_meta

        test_file_list = get_unfinished(
            args.action_space,
            "agi-0",
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        left_info = ""
        for domain in test_file_list:
            left_info += f"{domain}: {len(test_file_list[domain])}\n"
        logger.info(f"Left tasks:\n{left_info}")

        get_result(
            args.action_space,
            "agi-0",
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        test(args, test_file_list)
    except KeyboardInterrupt:
        logger.info("Main process received KeyboardInterrupt.")
        # Signal handler will take care of cleanup
    except Exception as e:
        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
        # Also trigger cleanup for unhandled exceptions
        signal_handler(signal.SIGTERM, None)
    finally:
        # Final cleanup in case any environments or processes remain
        logger.info("Main process final cleanup...")
        for env in active_environments:
            if env is not None:
                try:
                    logger.info("Closing environment in final cleanup...")
                    env.close()
                    logger.info("Environment closed successfully in final cleanup")
                except Exception as e:
                    logger.error(f"Error during final environment cleanup: {e}")

        # First try gentle termination
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Terminating process {p.name}...")
                    p.terminate()
                except Exception as e:
                    logger.error(f"Error terminating process: {e}")

        # Wait a moment for processes to terminate
        time.sleep(1)

        # Then force kill if needed
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Force killing process {p.name}...")
                    os.kill(p.pid, signal.SIGKILL)
                    logger.info(f"Process {p.name} force killed")
                except Exception as e:
                    logger.error(f"Error force killing process: {e}")
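# The runner above follows a resilient worker-pool pattern: tasks go into a
# Manager queue, each worker drains it until a get() timeout, and the main loop
# restarts dead workers until the queue is empty. A stripped-down, self-contained
# sketch of the same pattern (names here are illustrative, not the runner's API):
import time
from multiprocessing import Manager, Process

def _pool_worker(queue, results):
    while True:
        try:
            task = queue.get(timeout=5)  # empty queue -> timeout -> worker exits
        except Exception:
            break
        results.append(task * 2)  # stand-in for running one benchmark example

def _spawn(queue, results):
    p = Process(target=_pool_worker, args=(queue, results))
    p.start()
    return p

if __name__ == "__main__":
    with Manager() as manager:
        queue, results = manager.Queue(), manager.list()
        for task in range(10):
            queue.put(task)
        procs = [_spawn(queue, results) for _ in range(2)]
        while not queue.empty():  # supervision loop: replace any dead worker
            procs = [p if p.is_alive() else _spawn(queue, results) for p in procs]
            time.sleep(1)
        for p in procs:
            p.join()
        print(sorted(results))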
@ -13,6 +13,7 @@ import time
from typing import List
from multiprocessing import Process, Manager, current_process
import lib_run_single
from lib_results_logger import log_task_error
from desktop_env.desktop_env import DesktopEnv
from mm_agents.anthropic import AnthropicAgent
@ -67,17 +68,27 @@ def config() -> argparse.Namespace:
    )

    # lm config
    parser.add_argument("--model", type=str, default="claude-4-sonnet-20250514")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=1500)
    parser.add_argument("--model", type=str, default="")
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--max_tokens", type=int, default=3000)
    parser.add_argument("--stop_token", type=str, default=None)

    # thinking mode config
    parser.add_argument("--no-thinking", action="store_true",
                        help="Disable thinking mode (no scratchpad)")
    parser.add_argument("--use-isp", action="store_true",
                        help="Use interleaved scratchpad (ISP) mode")

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )
    parser.add_argument(
        "--specific_task_id", type=str, default=None,
        help="Run only a specific task ID (overrides domain filtering)"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
@ -95,6 +106,37 @@ def config() -> argparse.Namespace:

args = config()  # Get command line arguments first

# Validate that model is specified to prevent accidental usage with empty model
if not args.model or args.model.strip() == "":
    print("ERROR: Model must be specified. Use --model <model_name>")
    print("Example: --model claude-sonnet-4-5-20250929")
    sys.exit(1)

# Validate model support before proceeding
from mm_agents.anthropic.utils import validate_model_support

# Pass the same temperature/top_p and thinking parameters as will be used by the agent
validation_kwargs = {}
if args.temperature is not None:
    validation_kwargs['temperature'] = args.temperature
if args.top_p is not None:
    validation_kwargs['top_p'] = args.top_p
validation_kwargs['no_thinking'] = args.no_thinking
validation_kwargs['use_isp'] = args.use_isp

if not validate_model_support(args.model, **validation_kwargs):
    print(f"\n💥 Model '{args.model}' API sample call failed")
    sys.exit(1)

# Validate thinking mode options are mutually exclusive
if args.no_thinking and args.use_isp:
    print("ERROR: --no-thinking and --use-isp are mutually exclusive")
    print("Choose one of:")
    print("  (default): Regular scratchpad mode")
    print("  --no-thinking: Disable thinking/scratchpad")
    print("  --use-isp: Use interleaved scratchpad (ISP)")
    sys.exit(1)

logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
@ -182,7 +224,7 @@ def run_env_tasks(task_queue, args, shared_scores):
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            enable_proxy=False,
            client_password=args.client_password
        )
        active_environments.append(env)
@ -196,8 +238,9 @@ def run_env_tasks(task_queue, args, shared_scores):
            observation_type=args.observation_type,
            max_trajectory_length=args.max_trajectory_length,
            provider_name=args.provider_name,
            screen_width=args.screen_width,
            screen_height=args.screen_height,
            screen_size=(args.screen_width, args.screen_height),
            no_thinking=getattr(args, 'no_thinking', False),
            use_isp=getattr(args, 'use_isp', False),
        )
        logger.info(f"Process {current_process().name} started.")
        while True:
@ -239,6 +282,14 @@ def run_env_tasks(task_queue, args, shared_scores):
                import traceback
                logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
                logger.error(traceback.format_exc())

                # Log error to results.json
                try:
                    example = {"id": example_id}  # Create minimal example dict for error logging
                    log_task_error(example, str(e), example_result_dir, args)
                except Exception as log_e:
                    logger.error(f"Failed to log error to results.json: {log_e}")

                try:
                    env.controller.end_recording(
                        os.path.join(example_result_dir, "recording.mp4")
@ -479,7 +530,28 @@ if __name__ == "__main__":
    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    if args.domain != "all":
    # Filter for specific task ID if provided
    if args.specific_task_id:
        logger.info(f"Filtering for specific task ID: {args.specific_task_id}")
        filtered_meta = {}
        task_found = False

        for domain, task_ids in test_all_meta.items():
            for task_id in task_ids:
                if task_id == args.specific_task_id:
                    filtered_meta[domain] = [task_id]
                    task_found = True
                    logger.info(f"Found task {args.specific_task_id} in domain: {domain}")
                    break
            if task_found:
                break

        if not task_found:
            logger.error(f"Task ID {args.specific_task_id} not found in test file!")
            sys.exit(1)

        test_all_meta = filtered_meta
    elif args.domain != "all":
        test_all_meta = {args.domain: test_all_meta[args.domain]}

    test_file_list = get_unfinished(
@ -0,0 +1,916 @@
# -*- coding: utf-8 -*-

from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List
from multiprocessing import Process, Manager, Queue
from multiprocessing import current_process

import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.dart_gui_agent import DartAgent

# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False

# load the environment variables from .env file
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()

# Logger Configs {{{ #
def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark - Dart Version"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=5.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # evaluation config
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config - Dart specific configurations
    parser.add_argument("--model", type=str, default="dart-uitars", help="Model name for Dart")
    parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen25vl", "qwen2vl"])
    parser.add_argument("--infer_mode", type=str, default="dart_mode", choices=["dart_mode", "qwen2vl_user"])
    parser.add_argument("--prompt_style", type=str, default="dart_style")
    parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content")
    parser.add_argument("--language", type=str, default="English")
    parser.add_argument("--max_pixels", type=float, default=16384*28*28)
    parser.add_argument("--min_pixels", type=float, default=100*28*28)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=-1)
    parser.add_argument("--history_n", type=int, default=5)
    parser.add_argument("--max_tokens", type=int, default=500)
    parser.add_argument("--stop_token", type=str, default=None)

    parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
    parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.")

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help="Set the logging level")
    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    parser.add_argument(
        "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
    )
    parser.add_argument(
        "--client_password", type=str, default="password", help="Client password"
    )
    parser.add_argument(
        "--screen_width", type=int, default=1920, help="Screen width"
    )
    parser.add_argument(
        "--screen_height", type=int, default=1080, help="Screen height"
    )

    # Dart specific parameters
    parser.add_argument("--dart_api_key", type=str, default="", help="Dart API key")
    parser.add_argument("--dart_base_url", type=str, default="", help="Dart base URL")
    parser.add_argument("--max_images", type=int, default=5, help="Maximum number of images in prompt history")
    parser.add_argument("--max_texts", type=int, default=35, help="Maximum number of text responses in prompt history")

    # Enhanced trajectory saving
    parser.add_argument("--save_complete_trajectory", action="store_true", help="Save complete trajectory with images and detailed information")
    parser.add_argument("--use_enhanced_runner", action="store_true", help="Use enhanced Dart runner with complete trajectory saving")

    args = parser.parse_args()
    return args

args = config()  # Get command line arguments first

logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(
    os.path.join("logs", "dart-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "dart-debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")


def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))
    return all_tasks


def process_signal_handler(signum, frame, env_idx):
    """Signal handler for child processes to gracefully shut down their environments."""
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")

    # Get the active_environments from the caller's frame
    local_vars = frame.f_locals
    active_environments = local_vars.get('active_environments', [])

    # Close environment in the current process context
    for env in active_environments:
        if env is not None:
            try:
                logger.info(f"Process {env_idx + 1} closing environment...")
                env.close()
                logger.info(f"Process {env_idx + 1} environment closed successfully")
            except Exception as e:
                logger.error(f"Process {env_idx + 1} error closing environment: {e}")

    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)


def save_complete_trajectory_with_images(example_result_dir: str, task_info: dict, reward: float,
                                         messages: list, all_images: list = None):
    """
    Save the complete trajectory information, including image paths.

    Args:
        example_result_dir: Directory where results are saved.
        task_info: Task information.
        reward: Final reward score.
        messages: Full conversation messages.
        all_images: List of all image data (optional).
    """
    import datetime

    # Build the complete trajectory data
    complete_trajectory = {
        "task_info": {
            "domain": task_info.get("domain", "unknown"),
            "example_id": task_info.get("example_id", "unknown"),
            "instruction": task_info.get("instruction", ""),
            "timestamp": datetime.datetime.now().isoformat()
        },
        "evaluation": {
            "reward": reward,
            "success": reward > 0
        },
        "trajectory": {
            "messages": [],
            "image_paths": [],
            "step_count": 0
        }
    }

    # Process messages and image paths
    image_counter = 0
    step_counter = 0

    for msg_idx, message in enumerate(messages):
        processed_message = {
            "step": step_counter,
            "role": message.get("role", "unknown"),
            "content": message.get("content", []),
            "timestamp": message.get("timestamp", ""),
            "image_files": []
        }

        # Check for image content in the message
        if isinstance(message.get("content"), list):
            for content_item in message["content"]:
                if content_item.get("type") == "image_url":
                    # If there is corresponding image data, save the image file
                    if all_images and image_counter < len(all_images):
                        image_filename = f"step_{step_counter}_image_{image_counter}.png"
                        image_path = os.path.join(example_result_dir, image_filename)

                        try:
                            # Save the image
                            if hasattr(all_images[image_counter], 'save'):
                                # PIL Image object
                                all_images[image_counter].save(image_path)
                            elif isinstance(all_images[image_counter], bytes):
                                # Binary data
                                with open(image_path, 'wb') as f:
                                    f.write(all_images[image_counter])
                            else:
                                logger.warning(f"Unknown image format for image {image_counter}")
                                continue

                            processed_message["image_files"].append(image_filename)
                            complete_trajectory["trajectory"]["image_paths"].append(image_path)
                            logger.info(f"Saved image: {image_filename}")

                        except Exception as e:
                            logger.error(f"Failed to save image {image_counter}: {e}")

                        image_counter += 1

                    # Update the image reference in the content to the local path
                    if processed_message["image_files"]:
                        content_item["local_path"] = processed_message["image_files"][-1]

        complete_trajectory["trajectory"]["messages"].append(processed_message)

        # If this is an assistant reply, increment the step count
        if message.get("role") == "assistant":
            step_counter += 1

    complete_trajectory["trajectory"]["step_count"] = step_counter

    # Save the complete trajectory JSON file
    trajectory_file = os.path.join(example_result_dir, "complete_trajectory.json")
    try:
        with open(trajectory_file, 'w', encoding='utf-8') as f:
            json.dump(complete_trajectory, f, indent=2, ensure_ascii=False)
        logger.info(f"Complete trajectory saved to: {trajectory_file}")

        # Also save a simplified version for quick inspection
        summary_file = os.path.join(example_result_dir, "trajectory_summary.json")
        summary = {
            "task_id": task_info.get("example_id", "unknown"),
            "domain": task_info.get("domain", "unknown"),
            "instruction": task_info.get("instruction", ""),
            "reward": reward,
            "success": reward > 0,
            "total_steps": step_counter,
            "total_images": len(complete_trajectory["trajectory"]["image_paths"]),
            "image_files": [os.path.basename(path) for path in complete_trajectory["trajectory"]["image_paths"]],
            "timestamp": complete_trajectory["task_info"]["timestamp"]
        }

        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)
        logger.info(f"Trajectory summary saved to: {summary_file}")

    except Exception as e:
        logger.error(f"Failed to save complete trajectory: {e}")
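# Usage sketch for the trajectory saver above (all values are illustrative):
#
#   task_info = {"domain": "chrome", "example_id": "demo-task",
#                "instruction": "Open the settings page"}
#   messages = [
#       {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "..."}}]},
#       {"role": "assistant", "content": [{"type": "text", "text": "click(960, 540)"}]},
#   ]
#   save_complete_trajectory_with_images("./results/demo-task", task_info, reward=1.0,
#                                        messages=messages, all_images=[screenshot_bytes])
#
# This writes complete_trajectory.json, trajectory_summary.json, and one PNG per
# image part that has matching data in all_images.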
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
    active_environments = []
    env = None
    try:
        # Initialize proxy configuration if enabled
        # if hasattr(args, 'proxy_host') and args.proxy_host and args.proxy_port:
        #     from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool
        #     proxy_pool = get_global_proxy_pool()
        #     proxy_pool.add_proxy(
        #         host=args.proxy_host,
        #         port=args.proxy_port,
        #         protocol=args.proxy_protocol
        #     )
        #     logger.info(f"Added proxy: {args.proxy_host}:{args.proxy_port} ({args.proxy_protocol})")
        # elif hasattr(args, 'proxy_config') and args.proxy_config and os.path.exists(args.proxy_config):
        #     from desktop_env.providers.aws.proxy_pool import init_proxy_pool
        #     init_proxy_pool(args.proxy_config)
        #     logger.info(f"Initialized proxy pool from {args.proxy_config}")

        # Configure environment based on provider
        if args.provider_name == "aws":
            from desktop_env.providers.aws.manager import IMAGE_ID_MAP
            REGION = args.region
            screen_size = (args.screen_width, args.screen_height)
            ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
            env = DesktopEnv(
                path_to_vm=args.path_to_vm,
                action_space=args.action_space,
                provider_name=args.provider_name,
                region=REGION,
                snapshot_name=ami_id,
                screen_size=screen_size,
                headless=args.headless,
                os_type="Ubuntu",
                require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
            )
        else:
            # For non-AWS providers (docker, virtualbox, etc.)
            env = DesktopEnv(
                path_to_vm=args.path_to_vm,
                action_space=args.action_space,
                provider_name=args.provider_name,
                headless=args.headless,
                os_type="Ubuntu",
                require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
            )
        active_environments.append(env)
        args.max_trajectory_length = args.max_steps

        # Dart specific runtime configuration
        if args.infer_mode == "dart_mode":
            runtime_conf: dict = {
                "infer_mode": args.infer_mode,
                "prompt_style": args.prompt_style,
                "input_swap": args.input_swap,
                "language": args.language,
                "history_n": args.history_n,
                "max_pixels": args.max_pixels,
                "min_pixels": args.min_pixels,
                "temperature": args.temperature,
                "top_k": args.top_k,
                "top_p": args.top_p,
                "max_tokens": args.max_tokens,
                "max_images": args.max_images,
                "max_texts": args.max_texts,
                "dart_api_key": args.dart_api_key,
                "dart_base_url": args.dart_base_url
            }
        elif args.infer_mode == "qwen2vl_user":
            runtime_conf: dict = {
                "infer_mode": "qwen2vl_user",
                "prompt_style": "qwen2vl_user",
                "input_swap": args.input_swap,
                "language": args.language,
                "history_n": 5,
                "max_pixels": 2116800,
                "min_pixels": 3136,
                "temperature": 0.0,
                "top_k": -1,
                "top_p": 0.9,
                "max_tokens": 1000
            }
        else:
            raise ValueError(f"Unknown infer_mode: {args.infer_mode}")

        agent = DartAgent(
            model=args.model,
            action_space=args.action_space,
            observation_type=args.observation_type,
            max_trajectory_length=args.max_trajectory_length,
            model_type=args.model_type,
            runtime_conf=runtime_conf
        )

        logger.info(f"Process {current_process().name} started with Dart configuration.")
        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                break
            domain, example_id = item
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)
                try:
                    # Create a temporary list to capture the score
                    temp_scores = []

                    # Choose which runner to use based on the arguments
                    if args.use_enhanced_runner or args.save_complete_trajectory:
                        # Use the enhanced runner that supports saving the complete trajectory
                        logger.info(f"Using enhanced Dart runner for {domain}/{example_id}")
                        lib_run_single.run_single_example(
                            agent,
                            env,
                            example,
                            args.max_steps,
                            example["instruction"],
                            args,
                            example_result_dir,
                            temp_scores,
                        )
                    else:
                        # Use the standard runner
                        lib_run_single.run_single_example(
                            agent,
                            env,
                            example,
                            args.max_steps,
                            example["instruction"],
                            args,
                            example_result_dir,
                            temp_scores,
                        )
                    # Add domain info to the score
                    if temp_scores:
                        shared_scores.append({
                            'domain': domain,
                            'example_id': example_id,
                            'score': temp_scores[-1]
                        })
                except Exception as e:
                    import traceback
                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())
                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")
                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(
                            json.dumps(
                                {"Error": f"{domain}/{example_id} - {e}"}
                            )
                        )
                        f.write("\n")
            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback
                logger.error(traceback.format_exc())
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
        except Exception as e:
            logger.error(f"{current_process().name} error during environment cleanup: {e}")


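# Shutdown is two-phase: terminate() gives workers a grace period, then any
# survivors are sent SIGKILL so the run never hangs on a stuck environment.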
def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shut down environments."""
    global is_terminating, active_environments, processes

    # Avoid duplicate handling
    if is_terminating:
        return

    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    # Close all registered environments in the main process
    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    # Send termination signal to all child processes first
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")

    # Allow a short time for processes to handle their own cleanup
    time.sleep(1)

    # Forcefully terminate any processes that didn't exit
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                import signal as sig
                os.kill(p.pid, sig.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)


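# Supervisor loop: test() restarts any worker that dies mid-run and exits once
# the task queue is drained or every worker is gone.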
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")
    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"DartEnvProcess-{i+1}"
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started Dart process {p.name} with PID {p.pid}")
        try:
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"DartEnvProcess-Restart-{idx+1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        scores = list(shared_scores)

        # Detailed statistics reporting
        if scores:
            # Extract numeric scores for overall statistics
            numeric_scores = []
            domain_stats = {}

            for score_entry in scores:
                if isinstance(score_entry, dict):
                    domain = score_entry.get('domain', 'unknown')
                    example_id = score_entry.get('example_id', 'unknown')
                    score = score_entry.get('score', 0)
                else:
                    # Handle legacy numeric scores
                    domain = 'unknown'
                    example_id = 'unknown'
                    score = score_entry

                numeric_scores.append(score)

                # Domain statistics
                if domain not in domain_stats:
                    domain_stats[domain] = {'total': 0, 'success': 0, 'scores': []}

                domain_stats[domain]['total'] += 1
                domain_stats[domain]['scores'].append(score)
                if score > 0:
                    domain_stats[domain]['success'] += 1

            # Overall statistics
            total_tasks = len(numeric_scores)
            successful_tasks = sum(1 for score in numeric_scores if score > 0)
            average_score = sum(numeric_scores) / total_tasks
            success_rate = (successful_tasks / total_tasks) * 100

            logger.info("=" * 60)
            logger.info("📊 DART EVALUATION RESULTS SUMMARY")
            logger.info("=" * 60)
            logger.info("📈 Overall Statistics:")
            logger.info(f"  • Total tasks executed: {total_tasks}")
            logger.info(f"  • Successful tasks (score > 0): {successful_tasks}")
            logger.info(f"  • Success rate: {success_rate:.1f}%")
            logger.info(f"  • Average score: {average_score:.3f}")

            # Domain-specific statistics
            if domain_stats and len(domain_stats) > 1:  # Only show domain breakdown if multiple domains
                logger.info("\n🏷️ Domain-specific Results:")
                for domain, stats in sorted(domain_stats.items()):
                    domain_success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
                    domain_avg_score = sum(stats['scores']) / len(stats['scores']) if stats['scores'] else 0
                    logger.info(f"  • {domain}:")
                    logger.info(f"    - Tasks: {stats['total']}")
                    logger.info(f"    - Successful: {stats['success']}")
                    logger.info(f"    - Success rate: {domain_success_rate:.1f}%")
                    logger.info(f"    - Average score: {domain_avg_score:.3f}")

            # Score distribution
            score_ranges = {
                'Perfect (1.0)': sum(1 for s in numeric_scores if s == 1.0),
                'High (0.8-0.99)': sum(1 for s in numeric_scores if 0.8 <= s < 1.0),
                'Medium (0.5-0.79)': sum(1 for s in numeric_scores if 0.5 <= s < 0.8),
                'Low (0.1-0.49)': sum(1 for s in numeric_scores if 0.1 <= s < 0.5),
                'Failed (0.0)': sum(1 for s in numeric_scores if s == 0.0)
            }

            logger.info("\n📊 Score Distribution:")
            for range_name, count in score_ranges.items():
                if count > 0:
                    percentage = (count / total_tasks) * 100
                    logger.info(f"  • {range_name}: {count} tasks ({percentage:.1f}%)")

            logger.info("=" * 60)
        else:
            logger.warning("⚠️ No scores collected during evaluation!")


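# Resume support: a task directory without result.txt is treated as unfinished;
# its partial files are deleted so the task re-runs from scratch.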
def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # No result recorded: clear partial files so the task re-runs
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        # Read the recorded score; count unreadable results as 0.0
                        try:
                            all_result.append(
                                float(
                                    open(
                                        os.path.join(example_path, "result.txt"), "r"
                                    ).read()
                                )
                            )
                        except Exception:
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


def clear_cache_directory():
    """Remove everything under the cache directory."""
    cache_dir = "cache"
    if os.path.exists(cache_dir):
        logger.info(f"Clearing cache directory: {cache_dir}")
        try:
            import shutil
            # Delete the entire cache directory
            shutil.rmtree(cache_dir)
            # Recreate an empty cache directory
            os.makedirs(cache_dir, exist_ok=True)
            logger.info("Cache directory cleared successfully")
        except Exception as e:
            logger.error(f"Failed to clear cache directory: {e}")
    else:
        logger.info("Cache directory does not exist, creating it")
        os.makedirs(cache_dir, exist_ok=True)


def cleanup_docker_containers():
    """Remove Docker containers, keeping the monitor container."""
    logger.info("Cleaning up Docker containers...")
    try:
        import subprocess

        # List all container IDs, excluding monitor-monitor-1
        cmd = 'docker ps --format "{{.ID}} {{.Names}}" | grep -v "monitor-monitor-1" | awk \'{print $1}\''
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=30)

        if result.returncode == 0 and result.stdout.strip():
            container_ids = result.stdout.strip().split('\n')
            container_ids = [cid for cid in container_ids if cid.strip()]

            if container_ids:
                logger.info(f"Found {len(container_ids)} containers to remove: {container_ids}")

                # Force-remove the containers
                for container_id in container_ids:
                    try:
                        rm_result = subprocess.run(
                            f"docker rm -f {container_id}",
                            shell=True,
                            capture_output=True,
                            text=True,
                            timeout=10
                        )
                        if rm_result.returncode == 0:
                            logger.info(f"Successfully removed container: {container_id}")
                        else:
                            logger.warning(f"Failed to remove container {container_id}: {rm_result.stderr}")
                    except subprocess.TimeoutExpired:
                        logger.warning(f"Timeout removing container: {container_id}")
                    except Exception as e:
                        logger.error(f"Error removing container {container_id}: {e}")

                logger.info("Docker container cleanup completed")
            else:
                logger.info("No containers found to remove")
        else:
            logger.info("No containers found or error getting container list")

    except subprocess.TimeoutExpired:
        logger.error("Timeout during Docker container cleanup")
    except Exception as e:
        logger.error(f"Failed to cleanup Docker containers: {e}")


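# Entry point: clean leftover containers and cache from the previous run,
# persist the CLI arguments, then launch the multi-process evaluation.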
if __name__ == "__main__":
    ####### Dart Version - Complete evaluation runner #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Register signal handlers for graceful termination
    signal.signal(signal.SIGINT, signal_handler)   # Handle Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal

    try:
        args = config()

        # Remove Docker containers left over from the previous run
        # (keep them when running your own containers alongside)
        cleanup_docker_containers()

        # Clear the cache directory of files downloaded by the previous run
        clear_cache_directory()

        logger.info("Starting Dart evaluation runner...")

        # save args to json in result_dir/action_space/observation_type/model/args.json
        path_to_args = os.path.join(
            args.result_dir,
            args.action_space,
            args.observation_type,
            args.model,
            "args.json",
        )
        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
        with open(path_to_args, "w", encoding="utf-8") as f:
            json.dump(vars(args), f, indent=4)

        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
            test_all_meta = json.load(f)

        if args.domain != "all":
            test_all_meta = {args.domain: test_all_meta[args.domain]}

        test_file_list = get_unfinished(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        left_info = ""
        for domain in test_file_list:
            left_info += f"{domain}: {len(test_file_list[domain])}\n"
        logger.info(f"Left tasks:\n{left_info}")

        get_result(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        test(args, test_file_list)
    except KeyboardInterrupt:
        logger.info("Main process received KeyboardInterrupt.")
        # Signal handler will take care of cleanup
    except Exception as e:
        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
        # Also trigger cleanup for unhandled exceptions
        signal_handler(signal.SIGTERM, None)
    finally:
        # Final cleanup in case any environments or processes remain
        logger.info("Main process final cleanup...")
        for env in active_environments:
            if env is not None:
                try:
                    logger.info("Closing environment in final cleanup...")
                    env.close()
                    logger.info("Environment closed successfully in final cleanup")
                except Exception as e:
                    logger.error(f"Error during final environment cleanup: {e}")

        # First try gentle termination
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Terminating process {p.name}...")
                    p.terminate()
                except Exception as e:
                    logger.error(f"Error terminating process: {e}")

        # Wait a moment for processes to terminate
        time.sleep(1)

        # Then force kill if needed
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Force killing process {p.name}...")
                    os.kill(p.pid, signal.SIGKILL)
                    logger.info(f"Process {p.name} force killed")
                except Exception as e:
                    logger.error(f"Error force killing process: {e}")
@ -0,0 +1,603 @@
"""
|
||||
Script to run EvoCUA native agent model on OSWorld tasks.
|
||||
|
||||
export AWS_ACCESS_KEY_ID="xx"
|
||||
export AWS_SECRET_ACCESS_KEY="xx"
|
||||
export AWS_REGION="xx"
|
||||
export AWS_SECURITY_GROUP_ID="xx"
|
||||
export AWS_SUBNET_ID="xx"
|
||||
export OPENAI_API_KEY="xxxx"
|
||||
export OPENAI_BASE_URL="xxxx"
|
||||
|
||||
Example Usage (S2):
|
||||
python3 run_multienv_evocua.py \
|
||||
--headless \
|
||||
--provider_name aws \
|
||||
--observation_type screenshot \
|
||||
--model EvoCUA-S2 \
|
||||
--result_dir ./evocua_s2 \
|
||||
--test_all_meta_path evaluation_examples/test_nogdrive.json \
|
||||
--max_steps 50 \
|
||||
--num_envs 30 \
|
||||
--temperature 0.01 \
|
||||
--max_history_turns 4 \
|
||||
--coordinate_type relative \
|
||||
--resize_factor 32 \
|
||||
--prompt_style S2
|
||||
|
||||
|
||||
Example Usage (S1):
|
||||
python3 run_multienv_evocua.py \
|
||||
--headless \
|
||||
--provider_name aws \
|
||||
--observation_type screenshot \
|
||||
--model EvoCUA-S1 \
|
||||
--result_dir ./evocua_s1 \
|
||||
--test_all_meta_path evaluation_examples/test_nogdrive.json \
|
||||
--max_steps 50 \
|
||||
--num_envs 30 \
|
||||
--max_history_turns 3 \
|
||||
--coordinate_type qwen25 \
|
||||
--max_tokens 10240 \
|
||||
--resize_factor 28 \
|
||||
--prompt_style S1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import time
|
||||
from typing import List
|
||||
from multiprocessing import Process, Manager, Queue
|
||||
from multiprocessing import current_process
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.evocua.evocua_agent import EvoCUAAgent
|
||||
|
||||
# Global variables for signal handling
|
||||
active_environments = []
|
||||
processes = []
|
||||
is_terminating = False
|
||||
|
||||
|
||||
# Thread-local storage for task context (works per-process in multiprocessing)
import threading
_task_context = threading.local()

def get_task_context():
    """Get current task context from thread-local storage."""
    return getattr(_task_context, 'context', {'domain': None, 'example_id': None})

def set_task_context(domain: str, example_id: str):
    """Set current task context in thread-local storage."""
    _task_context.context = {'domain': domain, 'example_id': example_id}

def clear_task_context():
    """Clear current task context."""
    if hasattr(_task_context, 'context'):
        delattr(_task_context, 'context')

class TaskContextFilter(logging.Filter):
    """Filter to add domain and example_id to log records."""
    def filter(self, record):
        ctx = get_task_context()
        domain = ctx.get('domain')
        example_id = ctx.get('example_id')
        if domain and example_id:
            record.domain = domain
            record.example_id = example_id
            # Add prefix to message
            if hasattr(record, 'msg') and isinstance(record.msg, str):
                if not record.msg.startswith(f"[{domain}/{example_id}]"):
                    record.msg = f"[{domain}/{example_id}] {record.msg}"
        else:
            record.domain = domain or "N/A"
            record.example_id = example_id or "N/A"
        return True

# load the environment variables from .env file
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()

def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation with EvoCUAAgent"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=5.0)
    parser.add_argument("--max_steps", type=int, default=50)

    # evaluation config
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--model", type=str, default="evocua", help="Model name.")
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=32768)
    parser.add_argument("--stop_token", type=str, default=None)

    parser.add_argument("--prompt_style", type=str, default="S2", choices=["S1", "S2"], help="Prompt style: 'S1' (structured reasoning) or 'S2' (tool calling)")
    parser.add_argument("--history_type", type=str, default="action_history", help="[S1] History type")
    parser.add_argument("--coordinate_type", type=str, default="relative", help="Coordinate type: relative, absolute, qwen25")
    parser.add_argument("--password", type=str, default="osworld-public-evaluation", help="VM Password")

    # Unified History Parameter
    parser.add_argument("--max_history_turns", type=int, default=3, help="Number of history turns to include")
    parser.add_argument("--resize_factor", type=int, default=32, help="Image resize factor (S1: 28, S2: 32)")

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help="Set the logging level")
    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    parser.add_argument(
        "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
    )
    parser.add_argument(
        "--client_password", type=str, default="", help="Client password"
    )
    parser.add_argument(
        "--screen_width", type=int, default=1920, help="Screen width"
    )
    parser.add_argument(
        "--screen_height", type=int, default=1080, help="Screen height"
    )

    args = parser.parse_args()
    return args

args = config()

logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)

# Add task context filter to all handlers
task_filter = TaskContextFilter()
file_handler.addFilter(task_filter)
debug_handler.addFilter(task_filter)
stdout_handler.addFilter(task_filter)

stdout_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)

logger = logging.getLogger("desktopenv.experiment")


def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))
    return all_tasks


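# Unlike the Dart runner above, this worker constructs a fresh EvoCUAAgent for
# every task, so no agent state leaks between examples.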
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
    active_environments = []
    env = None
    try:
        REGION = args.region
        screen_size = (args.screen_width, args.screen_height)

        # Determine snapshot based on provider
        snapshot_name = "init_state"
        if args.provider_name == "aws":
            from desktop_env.providers.aws.manager import IMAGE_ID_MAP
            ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION].get((1920, 1080)))
            snapshot_name = ami_id

        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=args.action_space,
            provider_name=args.provider_name,
            region=REGION,
            snapshot_name=snapshot_name,
            screen_size=screen_size,
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            client_password=args.client_password
        )

        active_environments.append(env)

        logger.info(f"Process {current_process().name} started.")
        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                break
            domain, example_id = item
            set_task_context(domain, example_id)
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)

                # Initialize EvoCUAAgent
                agent = EvoCUAAgent(
                    model=args.model,
                    max_tokens=args.max_tokens,
                    top_p=args.top_p,
                    temperature=args.temperature,
                    action_space=args.action_space,
                    observation_type=args.observation_type,
                    max_steps=args.max_steps,
                    prompt_style=args.prompt_style,
                    max_history_turns=args.max_history_turns,
                    screen_size=(args.screen_width, args.screen_height),
                    coordinate_type=args.coordinate_type,
                    password=args.password,
                    resize_factor=args.resize_factor,
                )

                try:
                    lib_run_single.run_single_example_evocua(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        example["instruction"],
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback
                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())

                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")

                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
                        f.write("\n")
            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback
                logger.error(traceback.format_exc())
            finally:
                clear_task_context()
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
        except Exception as e:
            logger.error(f"{current_process().name} error during environment cleanup: {e}")


def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shut down environments."""
    global is_terminating, active_environments, processes

    if is_terminating:
        return

    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")

    time.sleep(1)

    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                import signal as sig
                os.kill(p.pid, sig.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")
    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"EnvProcess-{i+1}"
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
        try:
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"EnvProcess-Restart-{idx+1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        scores = list(shared_scores)
        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        try:
                            all_result.append(
                                float(
                                    open(
                                        os.path.join(example_path, "result.txt"), "r"
                                    ).read()
                                )
                            )
                        except Exception:
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Register signal handlers
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        args = config()

        path_to_args = os.path.join(
            args.result_dir,
            args.action_space,
            args.observation_type,
            args.model,
            "args.json",
        )
        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
        with open(path_to_args, "w", encoding="utf-8") as f:
            json.dump(vars(args), f, indent=4)

        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
            test_all_meta = json.load(f)

        if args.domain != "all":
            test_all_meta = {args.domain: test_all_meta[args.domain]}

        test_file_list = get_unfinished(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        left_info = ""
        for domain in test_file_list:
            left_info += f"{domain}: {len(test_file_list[domain])}\n"
        logger.info(f"Left tasks:\n{left_info}")

        get_result(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        test(args, test_file_list)
    except KeyboardInterrupt:
        logger.info("Main process received KeyboardInterrupt.")
    except Exception as e:
        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
        signal_handler(signal.SIGTERM, None)
    finally:
        logger.info("Main process final cleanup...")
        for env in active_environments:
            if env is not None:
                try:
                    logger.info("Closing environment in final cleanup...")
                    env.close()
                except Exception as e:
                    logger.error(f"Error during final environment cleanup: {e}")

        for p in processes:
            if p is not None and p.is_alive():
                try:
                    p.terminate()
                except Exception as e:
                    logger.error(f"Error terminating process: {e}")

        time.sleep(1)
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    os.kill(p.pid, signal.SIGKILL)
                except Exception as e:
                    logger.error(f"Error force killing process: {e}")
@ -0,0 +1,525 @@
"""Run OSWorld evaluation using hosted GBOX service"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import time
|
||||
from typing import List
|
||||
from multiprocessing import Process, Manager
|
||||
from multiprocessing import current_process
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.hosted_gbox_agent import HostedGboxAgent
|
||||
|
||||
# Global variables for signal handling
|
||||
active_environments = []
|
||||
processes = []
|
||||
is_terminating = False
|
||||
|
||||
# load the environment variables from .env file
|
||||
if os.path.exists(".env"):
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Logger Configs {{{ #
def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run OSWorld evaluation with hosted GBOX service"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # agent config
    parser.add_argument("--max_trajectory_length", type=int, default=3)
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # Hosted GBOX service config
    parser.add_argument(
        "--gbox_service_url",
        type=str,
        default=os.getenv("GBOX_SERVICE_URL", "http://44.201.221.203:8000"),
        help="URL of hosted GBOX service"
    )
    parser.add_argument(
        "--gbox_service_api_key",
        type=str,
        default=os.getenv("GBOX_SERVICE_API_KEY"),
        help="API key for hosted GBOX service"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="us.anthropic.claude-sonnet-4-5-20250929-v1:0",
        help="Claude model to use (default: Bedrock Sonnet 4.5)"
    )
    parser.add_argument("--max_tokens", type=int, default=1500)

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results_hosted_gbox")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help="Set the logging level")

    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    parser.add_argument(
        "--provider_name", type=str, default="aws", help="Cloud provider name"
    )
    parser.add_argument(
        "--screen_width", type=int, default=1920, help="Screen width"
    )
    parser.add_argument(
        "--screen_height", type=int, default=1080, help="Screen height"
    )
    parser.add_argument(
        "--client_password",
        type=str,
        default=os.getenv("CLIENT_PASSWORD", "osworld-public-evaluation"),
        help="Client password (default: osworld-public-evaluation)"
    )

    args = parser.parse_args()
    return args


# }}} Logger Configs #

def setup_logger(env_idx: int = None, result_dir: str = "./results_gbox", level: str = 'INFO') -> logging.Logger:
    """Set up a logger for the current process.

    Args:
        env_idx: Environment index for naming (None for main process)
        result_dir: Directory to store logs
        level: Logging level

    Returns:
        Configured logger instance
    """
    # Set log level
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {level}')

    # Create logger
    if env_idx is not None:
        logger_name = f"osworld-worker-{env_idx}"
    else:
        logger_name = "osworld-main"

    logger = logging.getLogger(logger_name)
    logger.setLevel(numeric_level)

    # Remove existing handlers
    logger.handlers.clear()

    # Create formatters and handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(numeric_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler
    os.makedirs(result_dir, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if env_idx is not None:
        log_file = os.path.join(result_dir, f"worker_{env_idx}_{timestamp}.log")
    else:
        log_file = os.path.join(result_dir, f"main_{timestamp}.log")

    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(numeric_level)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger


logger = logging.getLogger("osworld-main")


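# Resume check: a task counts as completed once its result.txt exists under the
# expected result directory layout.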
def check_completed_tasks(result_dir: str, test_all_meta: dict) -> List[str]:
    """Check which tasks have already been completed.

    Args:
        result_dir: Directory containing results
        test_all_meta: Dictionary of domain -> list of task IDs

    Returns:
        List of completed task IDs (format: "domain/task_id")
    """
    completed = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            result_path = os.path.join(
                result_dir,
                "pyautogui",
                "screenshot",
                "claude-sonnet-4-5",  # NOTE: hardcoded model directory name; may not match args.model
                domain,
                example_id,
                "result.txt"
            )
            if os.path.exists(result_path):
                completed.append(f"{domain}/{example_id}")
                logger.info(f"Task {domain}/{example_id} already completed (result found)")

    return completed


def report_current_results(target_dir: str) -> List[float]:
    """Report current results from completed tasks.

    Args:
        target_dir: Directory containing results

    Returns:
        List of scores (0.0 or 1.0)
    """
    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        try:
                            with open(os.path.join(example_path, "result.txt"), "r") as f:
                                all_result.append(float(f.read()))
                        except Exception as e:
                            logger.warning(f"Failed to read result for {domain}/{example_id}: {e}")
                            all_result.append(0.0)

    if not all_result:
        logger.info("New experiment, no results yet.")
        return None
    else:
        success_rate = sum(all_result) / len(all_result) * 100
        logger.info(f"Current Success Rate: {success_rate:.2f}% ({len(all_result)} tasks)")
        return all_result


def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))
    return all_tasks


def process_signal_handler(signum, frame, env_idx):
    """Signal handler for child processes to gracefully shut down their environments."""
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")

    # Get the active_environments from the caller's frame
    local_vars = frame.f_locals
    active_environments = local_vars.get('active_environments', [])

    # Close environment in the current process context
    for env in active_environments:
        if env is not None:
            try:
                logger.info(f"Process {env_idx + 1} closing environment...")
                env.close()
                logger.info(f"Process {env_idx + 1} environment closed successfully")
            except Exception as e:
                logger.error(f"Process {env_idx + 1} error closing environment: {e}")

    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)


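# Worker: the agent logic runs on the hosted GBOX service; this process only
# provisions a DesktopEnv and hands the service the VM's IP address.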
def run_env_tasks(task_queue, args: argparse.Namespace, shared_scores: list):
    """Worker process that runs tasks from the queue using hosted GBOX service."""
    active_environments = []
    env = None
    try:
        from desktop_env.providers.aws.manager import IMAGE_ID_MAP
        REGION = args.region
        screen_size = (args.screen_width, args.screen_height)
        ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])

        # Create environment
        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=args.action_space,
            provider_name=args.provider_name,
            region=REGION,
            snapshot_name=ami_id,
            screen_size=screen_size,
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            client_password=args.client_password
        )
        active_environments.append(env)

        # Get VM IP address - MCP server will handle public IP lookup if needed
        vm_ip = env.vm_ip
        logger.info(f"VM IP: {vm_ip}")

        # Create hosted GBOX agent
        agent = HostedGboxAgent(
            server_url=args.gbox_service_url,
            api_key=args.gbox_service_api_key,
            vm_ip=vm_ip,
            platform="ubuntu",
            model=args.model,
            max_steps=args.max_steps,
        )

        # Process tasks from queue
        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                break

            domain, example_id = item
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)

                logger.info(f"[Domain]: {domain}")
                logger.info(f"[Example ID]: {example_id}")
                logger.info(f"[Instruction]: {example['instruction']}")

                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)

                try:
                    lib_run_single.run_single_example(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        example["instruction"],
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback
                    logger.error(f"Exception {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())
                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")
                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(
                            json.dumps(
                                {"Error": f"{domain}/{example_id} - {e}"}
                            )
                        )
                        f.write("\n")
            except Exception as e:
                logger.error(f"Error processing task: {e}", exc_info=True)

    except KeyboardInterrupt:
        logger.info("Worker received interrupt signal")
    except Exception as e:
        logger.error(f"Worker error: {e}", exc_info=True)
    finally:
        # Cleanup
        if env is not None:
            try:
                logger.info("Closing environment...")
                env.close()
                logger.info("Environment closed successfully")
            except Exception as e:
                logger.error(f"Error closing environment: {e}")


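# Main-process shutdown: terminate all workers, then join each against a shared
# 30-second budget before falling back to kill().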
def main_signal_handler(signum, frame):
    """Signal handler for main process to gracefully shut down all child processes."""
    global is_terminating
    if is_terminating:
        logger.info("Already terminating, please wait...")
        return

    is_terminating = True
    logger.info(f"Main process received signal {signum}. Shutting down all workers...")

    # Terminate all child processes
    for idx, proc in enumerate(processes):
        if proc.is_alive():
            logger.info(f"Terminating worker process {idx + 1}...")
            proc.terminate()

    # Wait for processes to finish with timeout
    timeout = 30
    start_time = time.time()
    for idx, proc in enumerate(processes):
        remaining_time = max(0, timeout - (time.time() - start_time))
        proc.join(timeout=remaining_time)
        if proc.is_alive():
            logger.warning(f"Worker {idx + 1} did not terminate gracefully, forcing...")
            proc.kill()
            proc.join()

    logger.info("All workers terminated. Exiting.")
    sys.exit(0)


if __name__ == "__main__":
|
||||
args = config()
|
||||
|
||||
# Setup main logger
|
||||
logger = setup_logger(env_idx=None, result_dir=args.result_dir, level=args.log_level)
|
||||
|
||||
# Validate hosted service configuration
|
||||
if not args.gbox_service_url:
|
||||
logger.error("GBOX_SERVICE_URL not set (use --gbox_service_url or GBOX_SERVICE_URL env var)")
|
||||
sys.exit(1)
|
||||
|
||||
if not args.gbox_service_api_key:
|
||||
logger.error("GBOX_SERVICE_API_KEY not set (use --gbox_service_api_key or GBOX_SERVICE_API_KEY env var)")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info(f"Using hosted GBOX service at: {args.gbox_service_url}")
|
||||
logger.info(f"Model: {args.model}")
|
||||
logger.info(f"Max steps: {args.max_steps}")
|
||||
logger.info(f"Number of parallel environments: {args.num_envs}")
|
||||
|
||||
# Setup signal handlers
|
||||
signal.signal(signal.SIGINT, main_signal_handler)
|
||||
signal.signal(signal.SIGTERM, main_signal_handler)
|
||||
|
||||
# Load test configuration
|
||||
logger.info(f"Loading test configuration from: {args.test_all_meta_path}")
|
||||
with open(args.test_all_meta_path, "r") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
# Filter by domain if specified
|
||||
if args.domain != "all":
|
||||
if args.domain in test_all_meta:
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
logger.info(f"Filtering to domain: {args.domain}")
|
||||
else:
|
||||
logger.error(f"Domain '{args.domain}' not found in test configuration")
|
||||
sys.exit(1)
|
||||
|
||||
# Check for completed tasks
|
||||
completed_tasks = check_completed_tasks(args.result_dir, test_all_meta)
|
||||
logger.info(f"Found {len(completed_tasks)} completed tasks")
|
||||
|
||||
# Distribute tasks
|
||||
all_tasks = distribute_tasks(test_all_meta)
|
||||
logger.info(f"Total tasks to run: {len(all_tasks)}")
|
||||
|
||||
# Filter out completed tasks
|
||||
all_tasks = [task for task in all_tasks if f"{task[0]}/{task[1]}" not in completed_tasks]
|
||||
logger.info(f"Tasks remaining after filtering completed: {len(all_tasks)}")
|
||||
|
||||
if not all_tasks:
|
||||
logger.info("No tasks to run. All tasks already completed.")
|
||||
|
||||
# Report current results
|
||||
target_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model if getattr(args, 'model_dir_name', None) is None else args.model_dir_name
|
||||
)
|
||||
if os.path.exists(target_dir):
|
||||
report_current_results(target_dir)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
# Create shared task queue
|
||||
manager = Manager()
|
||||
task_queue = manager.Queue()
|
||||
shared_scores = manager.list()
|
||||
|
||||
# Populate queue
|
||||
for task in all_tasks:
|
||||
task_queue.put(task)
|
||||
|
||||
# Start worker processes
|
||||
logger.info(f"Starting {args.num_envs} worker processes...")
|
||||
for env_idx in range(args.num_envs):
|
||||
proc = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores)
|
||||
)
|
||||
proc.start()
|
||||
processes.append(proc)
|
||||
logger.info(f"Started worker process {env_idx + 1} (PID: {proc.pid})")
|
||||
|
||||
# Wait for all processes to complete
|
||||
try:
|
||||
for idx, proc in enumerate(processes):
|
||||
proc.join()
|
||||
logger.info(f"Worker process {idx + 1} completed")
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received interrupt, shutting down...")
|
||||
main_signal_handler(signal.SIGINT, None)
|
||||
|
||||
# Report final results
|
||||
logger.info("=" * 50)
|
||||
logger.info("EVALUATION COMPLETE")
|
||||
logger.info("=" * 50)
|
||||
|
||||
target_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model
|
||||
)
|
||||
|
||||
if os.path.exists(target_dir):
|
||||
final_results = report_current_results(target_dir)
|
||||
if final_results:
|
||||
success_rate = sum(final_results) / len(final_results) * 100
|
||||
logger.info(f"Final Success Rate: {success_rate:.2f}% ({len(final_results)} tasks)")
|
||||
|
||||
logger.info("Exiting...")
|
||||
|
|
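`check_completed_tasks` is defined earlier in this file and is not shown in this excerpt. A minimal sketch of the resume convention it appears to rely on (a task counts as completed once its result directory contains a result.txt); the helper below is an illustrative reconstruction, and the exact directory layout is an assumption:

```
# Illustrative sketch only: treat a task as completed when result.txt exists.
# The real check_completed_tasks may nest results under
# <result_dir>/<action_space>/<observation_type>/<model>/ as well.
import os

def check_completed_tasks_sketch(result_dir: str, test_all_meta: dict) -> set:
    completed = set()
    for domain, example_ids in test_all_meta.items():
        for example_id in example_ids:
            marker = os.path.join(result_dir, domain, example_id, "result.txt")
            if os.path.exists(marker):
                completed.add(f"{domain}/{example_id}")
    return completed
```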
@ -3,29 +3,34 @@
You should first host the OpenCUA model on your local machine or a server.

Command for OpenCUA-72B:
```
python run_multienv_opencua.py \
    --headless \
    --observation_type screenshot \
    --model OpenCUA-72B \
    --result_dir ./results \
    --test_all_meta_path evaluation_examples/test_nogdrive.json \
    --max_steps 100 \
    --num_envs 30 \
    --coordinate_type qwen25

```

Command for OpenCUA-7B and OpenCUA-32B:
```
python run_multienv_opencua.py \
    --headless \
    --observation_type screenshot \
    --model OpenCUA-32B \
    --result_dir ./results --test_all_meta_path evaluation_examples/test_all_no_gdrive.json \
    --max_steps 100 \
    --num_envs 30 \
    --coordinate_type qwen25
```

Command for OpenCUA-Qwen2-7B and OpenCUA-A3B:
```
python run_multienv_opencua.py \
    --headless \
    --observation_type screenshot \
    --model OpenCUA-A3B \
    --result_dir ./results \
    --result_dir ./results\
    --test_all_meta_path evaluation_examples/test_nogdrive.json \
    --max_steps 100 \
    --num_envs 10 \
    --coordinate_type relative
    --num_envs 30 \
    --coordinate_type qwen25 \
    --use_old_sys_prompt

```

"""
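Before launching a run, it can help to confirm the hosted model actually answers. A minimal sketch, assuming the model is served behind an OpenAI-compatible endpoint (for example via vLLM); the URL, API key, and served model name below are placeholders rather than values from this repository:

```
# Connectivity check against an assumed OpenAI-compatible endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholders
resp = client.chat.completions.create(
    model="OpenCUA-32B",  # must match the name the serving framework registered
    messages=[{"role": "user", "content": "ping"}],
    max_tokens=8,
)
print(resp.choices[0].message.content)
```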
@ -44,7 +49,7 @@ from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.opencua_agent import OpenCUAAgent
from mm_agents.opencua import OpenCUAAgent

# Global variables for signal handling
active_environments = []
@ -76,8 +81,8 @@ def config() -> argparse.Namespace:
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=3.0)
    parser.add_argument("--max_steps", type=int, default=15)
    parser.add_argument("--sleep_after_execution", type=float, default=5.0)
    parser.add_argument("--max_steps", type=int, default=100)

    # evaluation config
    parser.add_argument(
@ -85,7 +90,7 @@ def config() -> argparse.Namespace:
    )

    # lm config
    parser.add_argument("--model", type=str, default="opencua")
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=2048)
@ -94,13 +99,14 @@ def config() -> argparse.Namespace:
    # OpenCUAagent config
    parser.add_argument("--cot_level", type=str, default="l2", help="CoT version: l1, l2, l3. Default is l2, which includes 'thought' and 'action'")
    parser.add_argument("--history_type", type=str, default="action_history", help="Use action to represent history steps", choices=["action_history", "thought_history", "observation_history"])
    parser.add_argument("--coordinate_type", type=str, default="relative", help="Type of coordinate: Qwen2-VL or Kimi-VL based models use 'relative'; Qwen2.5-VL based models use 'qwen25'", choices=["relative", "qwen25"])
    parser.add_argument("--coordinate_type", type=str, default="qwen25", help="Type of coordinate: Qwen2-VL or Kimi-VL based models use 'relative'; Qwen2.5-VL based models use 'qwen25'", choices=["relative", "qwen25"])
    parser.add_argument("--max_image_history_length", type=int, default=3, help="The max number of images in the history.")

    parser.add_argument("--use_old_sys_prompt", action="store_true", help="Use the old system prompt for OpenCUA-7B and OpenCUA-32B")

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
        "--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
    )

    # logging related
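The two `--coordinate_type` settings imply different post-processing of predicted click points. A minimal sketch of the distinction, assuming "relative" means values normalized to [0, 1] that are scaled by the real screen size, and "qwen25" means pixel coordinates in the model's resized image space that are mapped back to the screen; the function is illustrative, not the agent's actual code, and the resized dimensions are placeholders:

```
# Illustrative only: map a predicted point onto the real screen.
def to_screen(x, y, coordinate_type, screen_w, screen_h,
              model_img_w=1920, model_img_h=1080):
    if coordinate_type == "relative":   # model emits values in [0, 1]
        return round(x * screen_w), round(y * screen_h)
    if coordinate_type == "qwen25":     # model emits pixels in its resized image
        return round(x * screen_w / model_img_w), round(y * screen_h / model_img_h)
    raise ValueError(f"unknown coordinate_type: {coordinate_type}")
```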
@ -124,6 +130,9 @@ def config() -> argparse.Namespace:
    parser.add_argument(
        "--screen_height", type=int, default=1080, help="Screen height"
    )
    parser.add_argument(
        "--password", type=str, default="osworld-public-evaluation", help="The password for the computer if needed"
    )
    args = parser.parse_args()
    return args
@ -253,6 +262,9 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
        screen_size=(args.screen_width, args.screen_height),
        coordinate_type=args.coordinate_type,
        max_image_history_length=args.max_image_history_length,
        max_steps=args.max_steps,
        use_old_sys_prompt=args.use_old_sys_prompt,
        password=args.password,
    )
    try:
        lib_run_single.run_single_example_opencua(
@ -0,0 +1,907 @@
"""
OS-Symphony Official Evaluation Script

This script serves as the official evaluation entry point for OS-Symphony.
It handles the setup of the desktop environment, agent initialization, and
execution of evaluation tasks.

For detailed evaluation metrics, configuration options, and usage instructions,
please refer to the official repository:
https://github.com/OS-Copilot/OS-Symphony
"""


import argparse
import copy
import datetime
import json
import logging
import os
import subprocess
import sys
import signal
import time
from multiprocessing import Process, Manager, current_process, Queue

from mm_agents.os_symphony.agents.os_symphony import OSSymphony
from mm_agents.os_symphony.agents.os_aci import OSACI
import lib_run_single
# Modify desktop_env, add a new function 'start'
from desktop_env.desktop_env_os_symphony import DesktopEnv as OSWorldDesktopEnv

from dotenv import load_dotenv

load_dotenv()


# only for WAA
def prepare_worker_vm_paths(base_golden_path: str, worker_idx: int):
    # remove the '/' at the end
    base_golden_path = base_golden_path.rstrip(os.sep)

    # get parent directory (like /nvme/yangbowen/vm_stroage/waa)
    parent_dir = os.path.dirname(base_golden_path)

    # define the path of this worker
    worker_storage_path = os.path.join(parent_dir, f"storage_{worker_idx}")
    worker_backup_path = os.path.join(parent_dir, f"storage_{worker_idx}_backup")

    return worker_storage_path, worker_backup_path


# only for WAA
def initialize_worker_files(golden_path: str, worker_backup_path: str, worker_storage_path: str):
    """
    Initialize worker. If the backup doesn't exist, replicate it from the golden path.
    """
    if not os.path.exists(golden_path):
        raise FileNotFoundError(f"Golden VM path not found: {golden_path}")

    if not os.path.exists(worker_backup_path):
        logger.info(f"Initializing backup for worker from {golden_path} to {worker_backup_path} ...")
        try:
            os.makedirs(os.path.dirname(worker_backup_path), exist_ok=True)

            if os.path.isdir(golden_path):
                subprocess.check_call(['cp', '-r', '--sparse=always', golden_path, worker_backup_path])
            else:
                subprocess.check_call(['cp', '--sparse=always', golden_path, worker_backup_path])

            logger.info(f"Backup initialization complete for {worker_backup_path}")
        except subprocess.CalledProcessError as e:
            logger.error(f"Failed to copy golden image to backup using cp: {e}")
            raise e
    else:
        logger.info(f"Worker backup already exists at {worker_backup_path}, skipping copy.")

    if not os.path.exists(worker_storage_path):
        os.makedirs(worker_storage_path, exist_ok=True)
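# Together, the two WAA helpers above give every worker its own copy of the
# golden VM image; `cp --sparse=always` keeps holes in the disk image sparse
# on the destination, so each backup only consumes the blocks actually written.
# Illustrative usage sketch (the golden path below is a placeholder, and this
# loop is not part of the script itself):
#
#     golden = "/data/vms/waa/golden.qcow2"
#     for worker_idx in range(4):
#         storage_path, backup_path = prepare_worker_vm_paths(golden, worker_idx)
#         initialize_worker_files(golden, backup_path, storage_path)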
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)

# Set up Stdout handler
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.INFO)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(stdout_handler)

# Set up File Handler
# file_handler = logging.FileHandler(filename="log.txt")
# file_handler.setLevel(logging.ERROR)
# file_handler.setFormatter(formatter)
# file_handler.addFilter(logging.Filter("desktopenv"))
# logger.addHandler(file_handler)

# Logger Configs
logger = logging.getLogger("desktopenv.experiment")


# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False


def distribute_tasks(test_all_meta: dict) -> list:
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))
    return all_tasks


def process_signal_handler(signum, frame, env_idx):
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
    local_vars = frame.f_locals
    active_environments = local_vars.get("active_environments", [])
    for env in active_environments:
        if env is not None:
            try:
                logger.info(f"Process {env_idx + 1} closing environment...")
                env.close()
                logger.info(f"Process {env_idx + 1} environment closed successfully")
            except Exception as e:
                logger.error(f"Process {env_idx + 1} error closing environment: {e}")
    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)


def run_env_tasks(
    task_queue: Queue,
    args: argparse.Namespace,
    shared_scores: list,
    engine_params_for_orchestrator,
    engine_params_for_grounder,
    engine_params_for_coder,
    engine_params_for_memoryer,
    engine_params_for_searcher,
    worker_id: int,
):
    active_environments = []
    env = None
    search_env = None
    try:
        # Use IMAGE_ID_MAP for AWS provider to get snapshot_name
        snapshot_name = None
        region = getattr(args, "region", "us-east-1")
        platform = 'linux'
        screen_size = (args.screen_width, args.screen_height)

        if "osworld" in args.benchmark:
            if args.provider_name == "aws":
                from desktop_env.providers.aws.manager import IMAGE_ID_MAP
                ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
                env = OSWorldDesktopEnv(
                    path_to_vm=args.path_to_vm,
                    action_space=args.action_space,
                    provider_name=args.provider_name,
                    region=region,
                    snapshot_name=ami_id,
                    screen_size=screen_size,
                    headless=args.headless,
                    os_type="Ubuntu",
                    require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
                )
            elif args.provider_name == "docker":
                env = OSWorldDesktopEnv(
                    path_to_vm=args.path_to_vm,
                    action_space=args.action_space,
                    provider_name=args.provider_name,
                    region=region,
                    snapshot_name=snapshot_name,
                    screen_size=screen_size,
                    headless=args.headless,
                    os_type="Ubuntu",
                    require_a11y_tree=args.observation_type
                    in ["a11y_tree", "screenshot_a11y_tree", "som"],
                    enable_proxy=True,
                    client_password=getattr(args, "client_password", "")
                )
            else:
                raise Exception("Don't support other providers!")

        env.start()

        if args.provider_name == "aws":
            from desktop_env.providers.aws.manager import IMAGE_ID_MAP
            ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
            search_env = OSWorldDesktopEnv(
                path_to_vm=args.path_to_vm,
                action_space=args.action_space,
                provider_name=args.provider_name,
                region=region,
                snapshot_name=ami_id,
                screen_size=screen_size,
                headless=args.headless,
                os_type="Ubuntu",
                require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
            )
        elif args.provider_name == "docker":
            search_env = OSWorldDesktopEnv(
                path_to_vm=args.path_to_vm,
                action_space=args.action_space,
                provider_name=args.provider_name,
                region=region,
                snapshot_name=snapshot_name,
                screen_size=screen_size,
                headless=args.headless,
                os_type="Ubuntu",
                require_a11y_tree=args.observation_type
                in ["a11y_tree", "screenshot_a11y_tree", "som"],
                enable_proxy=True,
                client_password=getattr(args, "client_password", "")
            )
        else:
            raise Exception("Don't support other providers!")

        engine_params_for_ocr = copy.deepcopy(engine_params_for_orchestrator)
        engine_params_for_ocr["agent_name"] = "ocr"
        os_aci = OSACI(
            env=env,
            search_env=search_env,
            platform=platform,
            client_password=args.client_password,
            engine_params_for_ocr=engine_params_for_ocr,
            engine_params_for_grounder=engine_params_for_grounder,
            engine_params_for_coder=engine_params_for_coder,
            engine_params_for_searcher=engine_params_for_searcher,
            screen_width=args.screen_width,
            screen_height=args.screen_height,
        )
        agent = OSSymphony(
            engine_params_for_orchestrator,
            engine_params_for_memoryer,
            os_aci,
            platform=platform,
            client_password=args.client_password,
            max_trajectory_length=args.max_trajectory_length,
            enable_reflection=args.enable_reflection,
        )

        active_environments.append(env)
        active_environments.append(search_env)
        logger.info(f"Process {current_process().name} started.")
        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                break
            domain, example_id = item
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)

                if args.enable_rewrite_instruction and "rewritten_instruction" in example:
                    instruction = example["rewritten_instruction"]
                else:
                    instruction = example["instruction"]

                example_result_dir = os.path.join(
                    args.result_dir,
                    domain,
                    example_id
                )
                os.makedirs(example_result_dir, exist_ok=True)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {instruction}")
                try:
                    lib_run_single.run_single_example_os_symphony(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        instruction,
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback

                    logger.error(
                        f"Exception in {current_process().name} {domain}/{example_id}: {e}"
                    )
                    logger.error(traceback.format_exc())

                    with open(os.path.join(os.path.dirname(example_result_dir), "error.jsonl"), "a") as f:
                        f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
                        f.write("\n")

            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback

                logger.error(traceback.format_exc())
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback

        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
            if search_env:
                search_env.close()
                logger.info(f"{current_process().name} searcher environment closed successfully")
        except Exception as e:
            logger.error(
                f"{current_process().name} error during environment cleanup: {e}"
            )


# exit function
def signal_handler(signum, frame):
    global is_terminating, active_environments, processes
    if is_terminating:
        return
    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")
    for env in active_environments:
        try:
            logger.info(f"Closing environment...")
            env.close()
            logger.info(f"Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")
    time.sleep(1)
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                import signal as sig

                os.kill(p.pid, sig.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")
    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)


def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--provider_name",
        type=str,
        default="vmware",
        help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)",
    )
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument(
        "--num_envs",
        type=int,
        default=1,
        help="Number of environments to run in parallel",
    )
    parser.add_argument("--screen_width", type=int, default=1920, help="Main environment's width")
    parser.add_argument("--screen_height", type=int, default=1080, help="Main environment's height")
    parser.add_argument("--sleep_after_execution", type=float, default=1.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # Benchmark
    parser.add_argument("--benchmark", type=str, default="osworld", help="osworld / waa / macos")

    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/osworld/test_all.json"
    )
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )
    parser.add_argument("--result_dir", type=str, default="./results")

    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM for OSWorld."
    )
    parser.add_argument(
        "--client_password", type=str, default="password", help="Client password for OSWorld. AWS uses 'osworld-public-evaluation'; other providers use 'password'"
    )
    parser.add_argument(
        "--proxy", type=str, default="http://10.1.8.5:23128", help="Important! Proxy setting in the form http://<ip>:<port>; leave it empty to disable the proxy"
    )

    # agent config
    parser.add_argument("--max_trajectory_length", type=int, default=8)
    parser.add_argument("--enable_reflection", action="store_true", default=False)
    parser.add_argument("--enable_rewrite_instruction", action="store_true", default=False)
    parser.add_argument(
        "--tool_config",
        type=str,
        help="The path of tool config yaml"
    )

    # generator-agent config
    parser.add_argument("--orchestrator_provider", type=str, default="openai")
    parser.add_argument("--orchestrator_model", type=str, default="gpt-5")
    parser.add_argument(
        "--orchestrator_url",
        type=str,
        default="",
        help="The URL of the main orchestrator model API.",
    )
    parser.add_argument(
        "--orchestrator_api_key",
        type=str,
        default="",
        help="The API key of the main orchestrator model.",
    )
    parser.add_argument(
        "--orchestrator_temperature",
        type=float,
        default=None,
        help="Temperature to fix the orchestrator model at (e.g. o3 can only be run with 1.0)",
    )
    parser.add_argument("--orchestrator_keep_first_image", action="store_true", default=False, help="Whether to keep the first image (initial state) in the orchestrator agent")

    # code-agent config
    parser.add_argument("--coder_provider", type=str, default="openai")
    parser.add_argument("--coder_model", type=str, default="gpt-4o")
    parser.add_argument(
        "--coder_url",
        type=str,
        default="",
        help="The URL of the coder model API.",
    )
    parser.add_argument(
        "--coder_api_key",
        type=str,
        default="",
        help="The API key of the coder model.",
    )
    parser.add_argument(
        "--coder_temperature",
        type=float,
        default=None,
        help="Temperature to fix the coder model at (e.g. o3 can only be run with 1.0)",
    )
    parser.add_argument(
        "--coder_budget",
        type=int,
        default=20,
        help="Max inner loop steps of coder agent",
    )

    # reflection-memory agent config
    parser.add_argument("--memoryer_provider", type=str, default="openai")
    parser.add_argument("--memoryer_model", type=str, default="gpt-4o")
    parser.add_argument(
        "--memoryer_url",
        type=str,
        default="",
        help="The URL of the memoryer model API.",
    )
    parser.add_argument(
        "--memoryer_api_key",
        type=str,
        default="",
        help="The API key of the memoryer model.",
    )
    parser.add_argument(
        "--memoryer_temperature",
        type=float,
        default=None,
        help="Temperature to fix the memoryer model at (e.g. o3 can only be run with 1.0)",
    )
    parser.add_argument(
        "--memoryer_max_images",
        type=int,
        default=9,
        help="Max images of memoryer model"
    )

    # search model config
    parser.add_argument("--searcher_provider", type=str, default="openai")
    parser.add_argument("--searcher_model", type=str, default="gpt-4o")
    parser.add_argument(
        "--searcher_url",
        type=str,
        default="",
        help="The URL of the searcher model API.",
    )
    parser.add_argument(
        "--searcher_api_key",
        type=str,
        default="",
        help="The API key of the searcher model.",
    )
    parser.add_argument(
        "--searcher_temperature",
        type=float,
        default=None,
        help="Temperature to fix searcher model at (e.g. o3 can only be run with 1.0)",
    )
    parser.add_argument(
        "--searcher_type",
        type=str,
        default="vlm",
        help="Type of search agent: vlm or llm (llm does everything inside the search action); default is vlm",
    )
    parser.add_argument(
        "--searcher_engine",
        type=str,
        default="google",
        help="Type of search engine, google / duckduckgo",
    )
    parser.add_argument(
        "--searcher_budget",
        type=int,
        default=20,
        help="Max inner loop steps of search agent",
    )
    parser.add_argument(
        "--searcher_screen_width",
        type=int,
        default=1920,
        help="Search environment's width",
    )
    parser.add_argument(
        "--searcher_screen_height",
        type=int,
        default=1080,
        help="Search environment's height",
    )
    parser.add_argument(
        "--searcher_path_to_vm",
        type=str,
        default="/nvme/yangbowen/vm_stroage/osworld/Ubuntu.qcow2",
        help="Searcher environment VM's path (the OSWorld VM path)",
    )

    # grounding model config; temperature is hardcoded to 0
    parser.add_argument(
        "--grounder_provider",
        type=str,
        required=True,
        help="The provider for the grounder model",
    )
    parser.add_argument(
        "--grounder_url", type=str, required=True, help="The URL of the grounder model"
    )
    parser.add_argument(
        "--grounder_api_key",
        type=str,
        default="",
        help="The API key of the grounder model.",
    )
    parser.add_argument(
        "--grounder_model",
        type=str,
        required=True,
        help="The model name for the grounder model",
    )
    parser.add_argument(
        "--grounding_width",
        type=int,
        required=True,
        help="Width of screenshot image after processor rescaling",
    )
    parser.add_argument(
        "--grounding_height",
        type=int,
        required=True,
        help="Height of screenshot image after processor rescaling",
    )
    parser.add_argument(
        "--grounding_smart_resize",
        action="store_true", default=False,
        help="UI-TARS-1.5 and ScaleCUA need smart resize; if this is set, grounding_width and grounding_height are ignored.",
    )
    parser.add_argument(
        "--grounder_zoom_in_time",
        type=int,
        default=1,
        help="Number of zoom-in passes for the grounder agent, aiming to enhance grounding ability.",
    )
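    # `--grounding_smart_resize` refers to the Qwen-VL family's preprocessing,
    # which snaps image dimensions to multiples of a patch factor while staying
    # inside a pixel budget; when enabled, the fixed --grounding_width and
    # --grounding_height are ignored. Illustrative sketch of that rounding,
    # following the commonly published smart_resize from qwen-vl-utils (treat
    # the exact constants as assumptions; this helper is not used by the script):
    #
    #     import math
    #     def smart_resize(height, width, factor=28,
    #                      min_pixels=56 * 56, max_pixels=14 * 14 * 4 * 1280):
    #         h_bar = round(height / factor) * factor
    #         w_bar = round(width / factor) * factor
    #         if h_bar * w_bar > max_pixels:
    #             beta = math.sqrt((height * width) / max_pixels)
    #             h_bar = math.floor(height / beta / factor) * factor
    #             w_bar = math.floor(width / beta / factor) * factor
    #         elif h_bar * w_bar < min_pixels:
    #             beta = math.sqrt(min_pixels / (height * width))
    #             h_bar = math.ceil(height * beta / factor) * factor
    #             w_bar = math.ceil(width * beta / factor) * factor
    #         return h_bar, w_bar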
    parser.add_argument(
        "--exp_name",
        type=str,
        default="",
        help="Experiment Name",
    )

    args = parser.parse_args()

    return args


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")

    engine_params_for_orchestrator = {
        "engine_type": args.orchestrator_provider,
        "model": args.orchestrator_model,
        "base_url": getattr(args, "orchestrator_url", ""),
        "api_key": getattr(args, "orchestrator_api_key", ""),
        "temperature": getattr(args, "orchestrator_temperature", None),
        "tool_config": args.tool_config,
        "keep_first_image": args.orchestrator_keep_first_image,
        "agent_name": "orchestrator"
    }


    engine_params_for_grounder = {
        "engine_type": args.grounder_provider,
        "model": args.grounder_model,
        "base_url": getattr(args, "grounder_url", ""),
        "api_key": getattr(args, "grounder_api_key", ""),
        "grounding_width": args.grounding_width,
        "grounding_height": args.grounding_height,
        "grounding_smart_resize": args.grounding_smart_resize,
        "grounder_zoom_in_time": args.grounder_zoom_in_time,
        "agent_name": "grounder"
    }

    engine_params_for_coder = {
        "engine_type": args.coder_provider,
        "model": args.coder_model,
        "base_url": getattr(args, "coder_url", ""),
        "api_key": getattr(args, "coder_api_key", ""),
        "temperature": getattr(args, "coder_temperature", None),
        "budget": args.coder_budget,
        "agent_name": "coder"
    }

    engine_params_for_memoryer = {
        "engine_type": args.memoryer_provider,
        "model": args.memoryer_model,
        "base_url": getattr(args, "memoryer_url", ""),
        "api_key": getattr(args, "memoryer_api_key", ""),
        "temperature": getattr(args, "memoryer_temperature", None),
        "max_images": args.memoryer_max_images,
        "agent_name": "memoryer"
    }

    engine_params_for_searcher = {
        "engine_type": args.searcher_provider,
        "model": args.searcher_model,
        "base_url": getattr(args, "searcher_url", ""),
        "api_key": getattr(args, "searcher_api_key", ""),
        "temperature": getattr(args, "searcher_temperature", None),
        "budget": args.searcher_budget,
        "type": args.searcher_type,
        "engine": args.searcher_engine,
        "agent_name": "searcher"
    }

    # --- Initialize Worker Path ---
    num_envs = args.num_envs
    # only for waa
    if args.benchmark == "waa":
        logger.info(f"[WindowsAgentArena] Initializing storage for {num_envs} workers from golden image: {args.path_to_vm}")
        for i in range(num_envs):
            s_path, b_path = prepare_worker_vm_paths(args.path_to_vm, i)
            initialize_worker_files(args.path_to_vm, b_path, s_path)

    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        processes = []
        for worker_id in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(
                    task_queue,
                    args,
                    shared_scores,
                    engine_params_for_orchestrator,
                    engine_params_for_grounder,
                    engine_params_for_coder,
                    engine_params_for_memoryer,
                    engine_params_for_searcher,
                    worker_id
                ),
                name=f"EnvProcess-{worker_id+1}",
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
        try:
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(
                                task_queue,
                                args,
                                shared_scores,
                                engine_params_for_orchestrator,
                                engine_params_for_grounder,
                                engine_params_for_coder,
                                engine_params_for_memoryer,
                                engine_params_for_searcher,
                                idx
                            ),
                            name=f"EnvProcess-Restart-{idx+1}",
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(
                            f"Restarted process {new_p.name} with PID {new_p.pid}"
                        )
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info(
                "Main process received KeyboardInterrupt. Initiating graceful shutdown..."
            )
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error while waiting for processes: {e}", exc_info=True
            )
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        scores = list(shared_scores)
        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    target_dir, total_file_json
):

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # remove partial outputs so the unfinished example reruns
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(target_dir, total_file_json: dict):
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    # list for all tasks
    all_result = []

    for domain, example_id_list in total_file_json.items():
        for example_id in example_id_list:
            example_path = os.path.join(target_dir, domain, example_id)
            if os.path.isdir(example_path):
                if "result.txt" in os.listdir(example_path):
                    # read the recorded score for this example
                    try:
                        all_result.append(
                            float(
                                open(
                                    os.path.join(example_path, "result.txt"), "r"
                                ).read()
                            )
                        )
                    except Exception:
                        all_result.append(0.0)
                else:
                    all_result.append(0.0)
            else:
                all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result
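# `get_unfinished` and `get_result` together implement resumption: an example
# directory without result.txt is wiped and re-queued, while one with
# result.txt contributes its recorded score. The convention assumed by both is
# one float per file, e.g. (illustrative path):
#
#     with open("results/my_exp/chrome/abc123/result.txt") as f:
#         score = float(f.read())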
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
args = config()
|
||||
|
||||
if args.exp_name != "":
|
||||
args.result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.exp_name
|
||||
)
|
||||
else:
|
||||
args.result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model
|
||||
)
|
||||
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
"args.json"
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
|
||||
logger.info(f"====================\nExperiment on {args.benchmark} is started\n====================")
|
||||
test_file_list = get_unfinished(
|
||||
target_dir=args.result_dir,
|
||||
total_file_json=test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
target_dir=args.result_dir,
|
||||
total_file_json=test_all_meta
|
||||
)
|
||||
test(
|
||||
args,
|
||||
test_file_list
|
||||
)
|
||||
logger.info(f"====================\nExperiment on {args.benchmark} is ended\n====================")
|
||||
|
||||
logger.info(f"====================\nExperiment {args.exp_name} is totally ended!\n====================")
|
||||
|
|
@ -57,13 +57,13 @@ def config() -> argparse.Namespace:
    parser.add_argument("--model", type=str, default="qwen3-vl")
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=1500)
    parser.add_argument("--max_tokens", type=int, default=32768)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument(
        "--coord",
        type=str,
        choices=["absolute", "relative"],
        default="absolute",
        default="relative",
        help="Coordinate system for agent outputs (absolute or relative)",
    )
    parser.add_argument(
@ -99,7 +99,7 @@ def config() -> argparse.Namespace:
        "--provider_name",
        type=str,
        default="docker",
        choices=["aws", "virtualbox", "vmware", "docker", "azure"],
        choices=["aws", "virtualbox", "vmware", "docker", "azure", "aliyun"],
        help="Provider name",
    )
    parser.add_argument(
@ -0,0 +1,540 @@
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List, Dict
from multiprocessing import Process, Manager, Queue
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.seed_agent import SeedAgent


# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False

# load the environment variables from .env file
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()

# Logger Configs {{{ #
def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=5.0)
    parser.add_argument("--max_steps", type=int, default=100)

    # evaluation config
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--model", type=str, default="doubao-1-5-thinking-vision-pro-250428")
    parser.add_argument("--model_type", type=str, default="doubao", choices=["doubao", "qwen25"])
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.7)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--use_thinking", action="store_true", default=False)

    parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
    parser.add_argument("--history_n", type=int, default=5, help="The max number of images in the history.")
    parser.add_argument("--language", type=str, default="Chinese", help="Language for the agent.")

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help="Set the logging level")
    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    parser.add_argument(
        "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
    )
    parser.add_argument(
        "--client_password", type=str, default="", help="Client password"
    )
    parser.add_argument(
        "--screen_width", type=int, default=1920, help="Screen width"
    )
    parser.add_argument(
        "--screen_height", type=int, default=1080, help="Screen height"
    )
    args = parser.parse_args()
    return args
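# `--history_n` caps how many past screenshots accompany each model request,
# bounding prompt size over long trajectories. Illustrative sketch of that
# truncation, assuming the agent keeps a chronological list of
# (screenshot, action) pairs; this is not SeedAgent's actual implementation:
#
#     def truncate_history(history, history_n):
#         if history_n is None:
#             return history
#         return history[-history_n:]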
args = config() # Get command line arguments first
|
||||
|
||||
logger = logging.getLogger()
|
||||
log_level = getattr(logging, args.log_level.upper())
|
||||
logger.setLevel(log_level)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
debug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
file_handler.setLevel(logging.INFO)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setLevel(log_level)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
debug_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
|
||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.addHandler(stdout_handler)
|
||||
# }}} Logger Configs #
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
|
||||
all_tasks = []
|
||||
for domain, examples in test_all_meta.items():
|
||||
for example_id in examples:
|
||||
all_tasks.append((domain, example_id))
|
||||
return all_tasks
|
||||
|
||||
|
||||
def process_signal_handler(signum, frame, env_idx):
|
||||
"""Signal handler for child processes to gracefully shut down their environments."""
|
||||
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
|
||||
|
||||
# Get the active_environments from the caller's frame
|
||||
local_vars = frame.f_locals
|
||||
active_environments = local_vars.get('active_environments', [])
|
||||
|
||||
# Close environment in the current process context
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info(f"Process {env_idx + 1} closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Process {env_idx + 1} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
|
||||
|
||||
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
REGION = args.region
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=REGION,
|
||||
snapshot_name=ami_id,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=args.client_password
|
||||
)
|
||||
active_environments.append(env)
|
||||
agent = SeedAgent(
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
max_tokens=args.max_tokens,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
max_trajectory_length=args.max_trajectory_length,
|
||||
history_n=args.history_n,
|
||||
use_thinking=args.use_thinking,
|
||||
)
|
||||
logger.info(f"Process {current_process().name} started.")
|
||||
while True:
|
||||
try:
|
||||
item = task_queue.get(timeout=5)
|
||||
except Exception:
|
||||
break
|
||||
domain, example_id = item
|
||||
try:
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
try:
|
||||
lib_run_single.run_single_example(
|
||||
agent,
|
||||
env,
|
||||
example,
|
||||
args.max_steps,
|
||||
example["instruction"],
|
||||
args,
|
||||
example_result_dir,
|
||||
shared_scores,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
try:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
except Exception as rec_e:
|
||||
logger.error(f"Failed to end recording: {rec_e}")
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"{domain}/{example_id} - {e}"}
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
logger.info(f"{current_process().name} cleaning up environment...")
|
||||
try:
|
||||
if env:
|
||||
env.close()
|
||||
logger.info(f"{current_process().name} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
|
||||
global is_terminating, active_environments, processes
|
||||
|
||||
# Avoid duplicate handling
|
||||
if is_terminating:
|
||||
return
|
||||
|
||||
is_terminating = True
|
||||
logger.info(f"Received signal {signum}. Gracefully shutting down...")
|
||||
|
||||
# Close all registered environments in the main process
|
||||
for env in active_environments:
|
||||
try:
|
||||
logger.info(f"Closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing environment: {e}")
|
||||
|
||||
# Send termination signal to all child processes first
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Sending termination signal to process {p.name}...")
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending termination signal to process: {e}")
|
||||
|
||||
# Allow a short time for processes to handle their own cleanup
|
||||
time.sleep(1)
|
||||
|
||||
# Forcefully terminate any processes that didn't exit
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Forcefully terminating process {p.name}...")
|
||||
import signal as sig
|
||||
os.kill(p.pid, sig.SIGKILL)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forcefully terminating process: {e}")
|
||||
|
||||
logger.info("Shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
global processes
|
||||
logger.info("Args: %s", args)
|
||||
all_tasks = distribute_tasks(test_all_meta)
|
||||
logger.info(f"Total tasks: {len(all_tasks)}")
|
||||
with Manager() as manager:
|
||||
shared_scores = manager.list()
|
||||
task_queue = manager.Queue()
|
||||
for item in all_tasks:
|
||||
task_queue.put(item)
|
||||
num_envs = args.num_envs
|
||||
processes = []
|
||||
for i in range(num_envs):
|
||||
p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-{i+1}"
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
processes.append(p)
|
||||
logger.info(f"Started process {p.name} with PID {p.pid}")
|
||||
try:
|
||||
while True:
|
||||
alive_count = 0
|
||||
for idx, p in enumerate(processes):
|
||||
if not p.is_alive():
|
||||
logger.warning(f"Process {p.name} died, restarting...")
|
||||
new_p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-Restart-{idx+1}"
|
||||
)
|
||||
new_p.daemon = True
|
||||
new_p.start()
|
||||
processes[idx] = new_p
|
||||
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||
else:
|
||||
alive_count += 1
|
||||
if task_queue.empty():
|
||||
logger.info("All tasks finished.")
|
||||
break
|
||||
if alive_count == 0:
|
||||
logger.error("All processes died, exiting.")
|
||||
break
|
||||
time.sleep(5)
|
||||
for p in processes:
|
||||
p.join()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Terminating process {p.name} due to error...")
|
||||
p.terminate()
|
||||
except Exception as term_e:
|
||||
logger.error(f"Error terminating process {p.name}: {term_e}")
|
||||
raise
|
||||
scores = list(shared_scores)
|
||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||
|
||||
|
||||
def get_unfinished(
|
||||
action_space, use_model, observation_type, result_dir, total_file_json
|
||||
):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
|
||||
if not os.path.exists(target_dir):
|
||||
return total_file_json
|
||||
|
||||
finished = {}
|
||||
for domain in os.listdir(target_dir):
|
||||
finished[domain] = []
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
if example_id == "onboard":
|
||||
continue
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" not in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
for file in os.listdir(example_path):
|
||||
os.remove(os.path.join(example_path, file))
|
||||
else:
|
||||
finished[domain].append(example_id)
|
||||
|
||||
if not finished:
|
||||
return total_file_json
|
||||
|
||||
for domain, examples in finished.items():
|
||||
if domain in total_file_json:
|
||||
total_file_json[domain] = [
|
||||
x for x in total_file_json[domain] if x not in examples
|
||||
]
|
||||
|
||||
return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
if not os.path.exists(target_dir):
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
|
||||
all_result = []
|
||||
|
||||
for domain in os.listdir(target_dir):
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
try:
|
||||
all_result.append(
|
||||
float(
|
||||
open(
|
||||
os.path.join(example_path, "result.txt"), "r"
|
||||
).read()
|
||||
)
|
||||
)
|
||||
except:
|
||||
all_result.append(0.0)
|
||||
|
||||
if not all_result:
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
else:
|
||||
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
|
||||
return all_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Register signal handlers for graceful termination
    signal.signal(signal.SIGINT, signal_handler)   # Handle Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal

    try:
        args = config()

        # save args to json in result_dir/action_space/observation_type/model/args.json
        path_to_args = os.path.join(
            args.result_dir,
            args.action_space,
            args.observation_type,
            args.model,
            "args.json",
        )
        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
        with open(path_to_args, "w", encoding="utf-8") as f:
            json.dump(vars(args), f, indent=4)

        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
            test_all_meta = json.load(f)

        if args.domain != "all":
            test_all_meta = {args.domain: test_all_meta[args.domain]}

        test_file_list = get_unfinished(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        left_info = ""
        for domain in test_file_list:
            left_info += f"{domain}: {len(test_file_list[domain])}\n"
        logger.info(f"Left tasks:\n{left_info}")

        get_result(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        test(args, test_file_list)
    except KeyboardInterrupt:
        logger.info("Main process received KeyboardInterrupt.")
        # Signal handler will take care of cleanup
    except Exception as e:
        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
        # Also trigger cleanup for unhandled exceptions
        signal_handler(signal.SIGTERM, None)
    finally:
        # Final cleanup in case any environments or processes remain
        logger.info("Main process final cleanup...")
        for env in active_environments:
            if env is not None:
                try:
                    logger.info("Closing environment in final cleanup...")
                    env.close()
                    logger.info("Environment closed successfully in final cleanup")
                except Exception as e:
                    logger.error(f"Error during final environment cleanup: {e}")

        # First try gentle termination
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Terminating process {p.name}...")
                    p.terminate()
                except Exception as e:
                    logger.error(f"Error terminating process: {e}")

        # Wait a moment for processes to terminate
        time.sleep(1)

        # Then force kill if needed; SIGKILL cannot be caught or ignored, so it
        # is the last resort once terminate() (SIGTERM) has had a second to work
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Force killing process {p.name}...")
                    os.kill(p.pid, signal.SIGKILL)
                    logger.info(f"Process {p.name} force killed")
                except Exception as e:
                    logger.error(f"Error force killing process: {e}")

@ -0,0 +1,58 @@
EXP_NAME="xxx"
export AWS_SECRET_ACCESS_KEY="xxx"
export AWS_ACCESS_KEY_ID="xxx"
export AWS_REGION="us-east-1"
export AWS_SUBNET_ID="xxx"
export AWS_SECURITY_GROUP_ID="xxx"
# >> logs/${EXP_NAME}.log 2>&1
python run_multienv_os_symphony.py \
    --provider_name "aws" \
    --region "us-east-1" \
    --client_password "osworld-public-evaluation" \
    --headless \
    --num_envs 7 \
    --max_steps 50 \
    --benchmark osworld \
    --domain "all" \
    --test_all_meta_path evaluation_examples/test_nogdrive.json \
    --result_dir "results" \
    --tool_config mm_agents/os_symphony/tool/all_tool_config.yaml \
    --orchestrator_provider "openai" \
    --orchestrator_model "gpt-5" \
    --orchestrator_url "xxx" \
    --orchestrator_api_key "xxx" \
    --orchestrator_temperature 0.1 \
    --orchestrator_keep_first_image \
    --max_trajectory_length 8 \
    --grounder_provider "vllm" \
    --grounder_model "UI-TARS-1.5-7B" \
    --grounder_api_key "none" \
    --grounder_url "xxx" \
    --grounding_smart_resize \
    --grounding_width 1920 \
    --grounding_height 1080 \
    --coder_provider "openai" \
    --coder_model "gpt-5" \
    --coder_url "xxx" \
    --coder_api_key "xxx" \
    --coder_temperature 0.1 \
    --coder_budget 20 \
    --memoryer_provider "openai" \
    --memoryer_model "gpt-5" \
    --memoryer_url "xxx" \
    --memoryer_api_key "xxx" \
    --memoryer_temperature 0.1 \
    --memoryer_max_images 8 \
    --searcher_provider "openai" \
    --searcher_model "gpt-5" \
    --searcher_url "xxx" \
    --searcher_api_key "xxx" \
    --searcher_temperature 0.1 \
    --searcher_type "vlm" \
    --searcher_engine "google" \
    --searcher_budget 20 \
    --searcher_screen_width 1920 \
    --searcher_screen_height 1080 \
    --sleep_after_execution 3 \
    --exp_name ${EXP_NAME} \
    --enable_reflection >> logs/${EXP_NAME}.log 2>&1
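
# The "xxx" values above are placeholders: supply your own experiment name,
# AWS credentials and network IDs, and the endpoint URLs / API keys for each
# model role before launching.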
setup.py
@ -23,7 +23,7 @@ class InstallPlaywrightCommand(install):
setup(
    name="desktop_env",
    version="1.0.0",
    version="1.0.1",
    author="Tianbao Xie, Danyang Zhang, Jixuan Chen, Xiaochuan Li, Siheng Zhao, Ruisheng Cao, Toh Jing Hua, etc.",
    author_email="tianbaoxiexxx@gmail.com",
    description="The package provides a desktop environment for setting and evaluating desktop automation tasks.",
@ -38,7 +38,7 @@ setup(
    ],
    python_requires='>=3.10',
    install_requires=[
        "numpy~=1.24.4",
        "numpy>=1.26,<3",
        "Pillow~=11.0.0",
        "fabric",
        "gymnasium~=0.28.1",
@ -53,7 +53,7 @@ setup(
        "pyautogui~=0.9.54",
        "psutil~=5.9.6",
        "tqdm~=4.65.0",
        "pandas~=2.2.3",
        "pandas>=2.2,<2.3",
        "flask~=3.0.0",
        "requests-toolbelt~=1.0.0",
        "ag2~=0.9.7",