Compare commits

...

27 Commits

Author SHA1 Message Date
蘑菇先生 5ef8bdfa35
EvoCUA Update (2025.01.05) (#412)
* evocua init

* setup max_token

* evocua update

---------

Co-authored-by: xuetaofeng <xuetaofeng@meituan.com>
Co-authored-by: Tianbao Xie <47296835+Timothyxxx@users.noreply.github.com>
2026-01-05 16:14:53 +08:00
Bowen Yang 439e178a2e
fix(os_symphony_evaluation) (#410)
* fix(os_symphony)

* Update desktop_env_os_symphony.py

* fix(os_symphony_desktop)

* fix(os_symphony_start)

* Add docstring to run_multienv_os_symphony.py

Added documentation header for the evaluation script.
2026-01-04 15:56:51 +08:00
Bowen Yang 951e1928c8
fix(desktop_os_symphony):support aws (#406)
* fix(os_symphony)

* Update desktop_env_os_symphony.py
2026-01-01 11:27:34 +08:00
Bowen Yang 02a35be067
fix(os_symphony) (#405) 2025-12-30 22:43:47 +08:00
Bowen Yang 662826f57e
fix(os_symphony):prompt (#402)
* add_os_symphony

* fix(os_symphony)

* fix(os_symphony):prompt

---------

Co-authored-by: Tianbao Xie <47296835+Timothyxxx@users.noreply.github.com>
2025-12-29 20:45:36 +08:00
xuetf 410ec63a89
Add EvoCUA Support (#401)
* evocua init

* setup max_token

---------

Co-authored-by: xuetaofeng <xuetaofeng@meituan.com>
Co-authored-by: Tianbao Xie <47296835+Timothyxxx@users.noreply.github.com>
2025-12-23 20:46:23 +08:00
Bowen Yang 031696e83c
fix os_symphony (#400)
* add_os_symphony

* fix(os_symphony)

---------

Co-authored-by: Tianbao Xie <47296835+Timothyxxx@users.noreply.github.com>
2025-12-23 20:45:30 +08:00
Bowen Yang f593f35b1c
add_os_symphony (#399) 2025-12-23 14:30:44 +08:00
Ubuntu ac31778ee3 Update: requirements.txt for seed agent 2025-12-15 11:47:56 +00:00
Ubuntu 60caa52fc4 Update: requirements.txt for seed agent 2025-12-15 11:47:40 +00:00
Ubuntu 41477a9c40 Update: seed agent 2025-12-15 11:45:57 +00:00
Ubuntu 78433ecfcf Add agent: seed agent 2025-12-12 05:35:20 +00:00
Meshal Nayim 9540454b0a
Fix demo agent (PromptAgent) reset(): add vm_ip and kwargs for compatibility with lib_run_single.py (#388) 2025-12-09 15:59:25 +08:00
MillanK cbc3b590ff
Task fix batch (#383)
* update 873cafdd-a581-47f6-8b33-b9696ddb7b05 task eval

* c1fa57f3-c3db-4596-8f09-020701085416 fix, add tolerance to url matching

* 8df7e444-8e06-4f93-8a1a-c5c974269d82 add more clear instruction to the filename for compress

* add address string normalization for 6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a

---------

Co-authored-by: Jiaqi <dengjiaqi@moonshot.cn>
2025-11-19 17:24:25 +08:00
Qichen Fu 903ed36715
Add Claude Sonnet 4.5 support and improve action handling (#362)
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-14 13:54:32 +08:00
Subash Shibu 3167339e45
Add hosted GBOX agent for OSWorld evaluation (#376) 2025-11-13 13:13:31 +08:00
Pengxiang-Li 00b6468eb7
feat/dart_gui (#371) 2025-11-07 21:50:01 +08:00
yiqilin 6d43dbc532
Update GIMP evaluation examples to replace local file paths with cloud file URLs for consistency and accessibility. (#372) 2025-11-07 21:49:49 +08:00
Timothyxxx 8365edc975 Add new section in README for OSWorld-MCP project 2025-10-30 06:06:48 +00:00
Daphne Barretto 21c2b7629b
Add consistent scores validation (#368)
* Add consistent scores validation

* revert osworld_run_maestro.py changes
2025-10-29 01:44:48 +08:00
Timothyxxx 3bf54c92a9 Merge branch 'main' of github.com:xlang-ai/OSWorld 2025-10-23 14:28:14 +08:00
Timothyxxx a484f2e484 Update setup.py for version bump and dependency adjustments
- Bump version from 1.0.0 to 1.0.1
- Update numpy dependency to allow versions >=1.26 and <3
- Adjust pandas dependency to allow versions >=2.2 and <2.3
- Add new __init__.py file in the docker provider directory
2025-10-23 14:27:52 +08:00
Atharva Gundawar 9f97535ef9
oswrold agent wrapper for trained v7 (#360) 2025-10-18 02:29:15 +08:00
ludunjie.ldj afd29115da support aliyun eval of qwen3vl 2025-10-16 16:20:54 +08:00
Dunjie Lu 55372c4432
Fix API base URLs for OpenAI and DashScope
Updated the base URLs for OpenAI and DashScope API calls.
2025-10-14 12:57:00 +08:00
Dunjie Lu d25464c203
Djlu/qwen3vl dash (#356)
* support dashscopoe sdk to call qwen3-vl-plus

* support dashscopoe sdk to call qwen3-vl-plus

---------

Co-authored-by: Timothyxxx <Timothyxxx@users.noreply.github.com>
2025-10-13 16:31:06 +08:00
Xinyuan Wang f9e9273b3b
OpenCUA-72B (#354)
* use aws pub ip

* os task fix: set the default dim screen time to be 300s

* OpenCUA-72B

* update password

* update

* update

* update opencua72b agent

* change provider ip

---------

Co-authored-by: Jiaqi <dengjiaqi@moonshot.cn>
2025-10-13 10:39:33 +08:00
75 changed files with 16076 additions and 961 deletions

9
.gitignore vendored
View File

@@ -204,4 +204,11 @@ reference/
draft/
manual_examine.py
run_human_examine.sh
quick_start.py
quick_start.py
result_multi_apps_pengxiang_transformers12
evaluation_examples/settings/proxy/dataimpulse.json
# Local test configurations (not for public repo)
evaluation_examples/spiderman.json
evaluation_examples/test_50_random_proportional.json
evaluation_examples/test_chrome.json

View File

@@ -228,3 +228,7 @@ Special thanks to the following institutions that provided feedback and particip
Special thanks to the following students who participated in the specific fixes: [Mengqi Yuan](https://yuanmengqi.github.io/), [Danyang Zhang](https://zdy023.github.io/), [Xinzhuang Xiong](https://thisisxxz.com/), [Zhennan Shen](https://scholar.google.com/citations?user=JPwg5MwAAAAJ&hl=en), [Zilong Zhou](https://github.com/adlsdztony), Yanxu Chen, [Jiaqi Deng](https://millank0817.github.io/), [Tianbao Xie](https://tianbaoxie.com/), Junda Chen, [Jixuan Chen](https://chenjix.github.io/), [Haoyuan Wu](https://www.linkedin.com/in/haoyuan-wu-240878291/).
Special thanks to the following students who participated in running the re-evaluation: [Mengqi Yuan](https://yuanmengqi.github.io/), [Zilong Zhou](https://github.com/adlsdztony), [Xinyuan Wang](https://xinyuanwangcs.github.io/), [Bowen Wang](https://bowenbryanwang.github.io/).
## You might also be interested
- **OSWorld-MCP**: Benchmarking MCP Tool Invocation in Computer-Use Agents. [Website](https://osworld-mcp.github.io/)

View File

@@ -238,12 +238,17 @@ class PythonController:
"returncode": -1
}
def execute_action(self, action: Dict[str, Any]):
def execute_action(self, action):
"""
Executes an action on the server computer.
"""
# Handle string actions
if action in ['WAIT', 'FAIL', 'DONE']:
return
# Handle dictionary actions
if type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']:
return
action_type = action["action_type"]
parameters = action["parameters"] if "parameters" in action else {param: action[param] for param in action if param != 'action_type'}
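For context, here is a minimal standalone sketch (not part of the diff, and not tied to a live VM) of the special-action check the patched execute_action now applies, showing that the plain-string and dict forms are short-circuited identically:

```python
# Standalone illustration of the added check; names here are illustrative only.
SPECIAL_ACTIONS = {'WAIT', 'FAIL', 'DONE'}

def is_special_action(action) -> bool:
    """Return True when the action should short-circuit instead of being dispatched."""
    if isinstance(action, str) and action in SPECIAL_ACTIONS:
        return True
    if isinstance(action, dict) and action.get('action_type') in SPECIAL_ACTIONS:
        return True
    return False

assert is_special_action('DONE')
assert is_special_action({'action_type': 'FAIL'})
assert not is_special_action({'action_type': 'click', 'coordinate': [10, 20]})
```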

View File

@@ -391,12 +391,12 @@ class DesktopEnv(gym.Env):
logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
# handle the special actions
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
if action == 'WAIT':
if action == 'WAIT' or (type(action) == dict and action.get('action_type') == 'WAIT'):
time.sleep(pause)
elif action == 'FAIL':
elif action == 'FAIL' or (type(action) == dict and action.get('action_type') == 'FAIL'):
done = True
info = {"fail": True}
elif action == 'DONE':
elif action == 'DONE' or (type(action) == dict and action.get('action_type') == 'DONE'):
done = True
info = {"done": True}
@@ -404,7 +404,7 @@ class DesktopEnv(gym.Env):
# the set of all possible actions defined in the action representation
self.controller.execute_action(action)
elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
if action in ['WAIT', 'FAIL', 'DONE']:
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']):
self.controller.execute_action(action)
else:
# the set of all possible python commands inside `pyautogui`
@@ -434,13 +434,16 @@ class DesktopEnv(gym.Env):
self.is_environment_used = True
if self.evaluator['func'] == "infeasible":
if len(self.action_history) > 0 and self.action_history[-1] == "FAIL":
return 1
else:
return 0
if len(self.action_history) > 0:
last_action = self.action_history[-1]
if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
return 1
return 0
else:
if len(self.action_history) > 0 and self.action_history[-1] == "FAIL":
return 0
if len(self.action_history) > 0:
last_action = self.action_history[-1]
if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
return 0
if type(self.metric) == list:
# Multiple metrics to evaluate whether the task is successfully completed

View File

@@ -0,0 +1,499 @@
from __future__ import annotations
import logging
import os
import time
import re
from typing import Callable, Any, Optional, Tuple
from typing import List, Dict, Union
import gymnasium as gym
from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import metrics, getters
from desktop_env.providers import create_vm_manager_and_provider
logger = logging.getLogger("desktopenv.env")
Metric = Callable[[Any, Any], float]
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
MAX_RETRIES = 5 # Maximum retries for environment setup
def _fix_pyautogui_less_than_bug(command: str) -> str:
"""
Fix PyAutoGUI '<' character bug by converting it to hotkey("shift", ',') calls.
This fixes the known PyAutoGUI issue where typing '<' produces '>' instead.
References:
- https://github.com/asweigart/pyautogui/issues/198
- https://github.com/xlang-ai/OSWorld/issues/257
Args:
command (str): The original pyautogui command
Returns:
str: The fixed command with '<' characters handled properly
"""
# Pattern to match press('<') or press('\u003c') calls
press_pattern = r'pyautogui\.press\(["\'](?:<|\\u003c)["\']\)'
# Handle press('<') calls
def replace_press_less_than(match):
return 'pyautogui.hotkey("shift", ",")'
# First handle press('<') calls
command = re.sub(press_pattern, replace_press_less_than, command)
# Pattern to match typewrite calls with quoted strings
typewrite_pattern = r'pyautogui\.typewrite\((["\'])(.*?)\1\)'
# Then handle typewrite calls
def process_typewrite_match(match):
quote_char = match.group(1)
content = match.group(2)
# Preprocess: Try to decode Unicode escapes like \u003c to actual '<'
# This handles cases where '<' is represented as escaped Unicode
try:
# Attempt to decode unicode escapes
decoded_content = content.encode('utf-8').decode('unicode_escape')
content = decoded_content
except UnicodeDecodeError:
# If decoding fails, proceed with original content to avoid breaking existing logic
pass # Graceful degradation: fall back to the original content if decoding fails
# Check if content contains '<'
if '<' not in content:
return match.group(0)
# Split by '<' and rebuild
parts = content.split('<')
result_parts = []
for i, part in enumerate(parts):
if i == 0:
# First part
if part:
result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")
else:
# Add hotkey for '<' and then typewrite for the rest
result_parts.append('pyautogui.hotkey("shift", ",")')
if part:
result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")
return '; '.join(result_parts)
command = re.sub(typewrite_pattern, process_typewrite_match, command)
return command
class DesktopEnv(gym.Env):
"""
DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting up and evaluating desktop automation tasks.
"""
def __init__(
self,
provider_name: str = "vmware",
region: str = None,
path_to_vm: str = None,
snapshot_name: str = "init_state",
action_space: str = "pyautogui",
cache_dir: str = "cache",
screen_size: Tuple[int] = (int(os.environ.get("SCREEN_WIDTH", 1920)), int(os.environ.get("SCREEN_HEIGHT", 1080))),
headless: bool = False,
require_a11y_tree: bool = True,
require_terminal: bool = False,
os_type: str = "Ubuntu",
enable_proxy: bool = False,
client_password: str = "",
):
"""
Args:
provider_name (str): virtualization provider name, default to "vmware"
region (str): region used to allocate machines (applies to cloud providers), default to "us-east-1"
path_to_vm (str): path to .vmx file
snapshot_name (str): snapshot name to revert to, default to "init_state"
action_space (str): "computer_13" | "pyautogui"
cache_dir (str): cache directory for task-related files such as
reference files for evaluation
screen_size (Tuple[int]): screen size of the VM
headless (bool): whether to run the VM in headless mode
require_a11y_tree (bool): whether to require accessibility tree
require_terminal (bool): whether to require terminal output
os_type (str): operating system type, default to "Ubuntu"
enable_proxy (bool): whether to enable proxy support, default to False
"""
# Initialize VM manager and virtualization provider
self.region = region
self.provider_name = provider_name
self.enable_proxy = enable_proxy # Store proxy enablement setting
if client_password == "":
if self.provider_name == "aws":
self.client_password = "osworld-public-evaluation"
else:
self.client_password = "password"
else:
self.client_password = client_password
self.screen_width = screen_size[0]
self.screen_height = screen_size[1]
# Default
self.server_port = 5000
self.chromium_port = 9222
self.vnc_port = 8006
self.vlc_port = 8080
# Initialize with default (no proxy) provider
self.current_use_proxy = False
self.manager, self.provider = None, None
self.os_type = os_type
self.path_to_vm = path_to_vm
# Track whether environment has been used (step/setup) to optimize snapshot revert
# docker, aws, gcp, azure are always unused as the emulator starts from a clean state
# vmware, virtualbox are always used as the emulator starts from a dirty state
if self.provider_name in {"docker", "aws", "gcp", "azure", "aliyun", "volcengine"}:
self.is_environment_used = False
elif self.provider_name in {"vmware", "virtualbox"}:
self.is_environment_used = True
else:
raise ValueError(f"Invalid provider name: {self.provider_name}")
self.snapshot_name = snapshot_name
self.cache_dir_base: str = cache_dir
self.headless = headless
self.require_a11y_tree = require_a11y_tree
self.require_terminal = require_terminal
# mode: human or machine
self.instruction = None
assert action_space in ["computer_13", "pyautogui", "claude_computer_use", "autoglm_computer_use"]
self.action_space = action_space # todo: refactor it to the ActType
# episodic stuffs, like counters, will be updated or reset
# when calling self.reset()
self._traj_no: int = -1
self._step_no: int = 0
self.action_history: List[Dict[str, any]] = []
def start(self):
# Initialize emulator and controller
if not self.manager and not self.provider:
logger.info("Initializing...")
self.manager, self.provider = create_vm_manager_and_provider(self.provider_name, self.region, use_proxy=False)
if self.path_to_vm:
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(self.path_to_vm))) \
if self.provider_name in {"vmware", "virtualbox"} else self.path_to_vm
else:
self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=self.region, screen_size=(self.screen_width, self.screen_height))
self._start_emulator()
def _start_emulator(self):
try:
# Power on the virtual machine
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
# Get the ip from the virtual machine, and setup the controller
vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
self.vm_ip = vm_ip_ports[0]
# Get the ports from the virtual machine (for Docker provider only)
if len(vm_ip_ports) > 1:
self.server_port = int(vm_ip_ports[1])
self.chromium_port = int(vm_ip_ports[2])
self.vnc_port = int(vm_ip_ports[3])
self.vlc_port = int(vm_ip_ports[4])
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base, client_password=self.client_password, screen_width=self.screen_width, screen_height=self.screen_height)
except Exception as e:
try:
self.provider.stop_emulator(self.path_to_vm)
except Exception as stop_err:
logger.warning(f"Cleanup after interrupt failed: {stop_err}")
raise
def _revert_to_snapshot(self):
# Revert to a given snapshot of the virtual machine, and refresh the VM path and IP,
# since these can change when the provider is a cloud service
path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
if path_to_vm and not path_to_vm == self.path_to_vm:
# path_to_vm has to be a new path
self.manager.delete_vm(self.path_to_vm, self.region)
self.manager.add_vm(path_to_vm, self.region)
self.manager.occupy_vm(path_to_vm, os.getpid(), self.region)
self.path_to_vm = path_to_vm
def _save_state(self, snapshot_name=None):
# Save the current virtual machine state to a certain snapshot name
self.provider.save_state(self.path_to_vm, snapshot_name)
def close(self):
# Close (release) the virtual machine
self.provider.stop_emulator(self.path_to_vm)
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
# Reset to certain task in OSWorld
logger.info("Resetting environment...")
logger.info("Switching task...")
logger.info("Setting counters...")
self._traj_no += 1
self._step_no = 0
self.action_history.clear()
for attempt in range(MAX_RETRIES):
# Only revert to snapshot if environment has been used (step/setup)
# This optimization is especially important for cloud providers like AWS
# where unnecessary snapshot operations are costly and time-consuming
if task_config is not None:
# Only consider task proxy requirement if proxy is enabled at system level
task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
if not self.enable_proxy and task_config.get("proxy", False):
logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")
if task_use_proxy != self.current_use_proxy:
# keep because get_info_from_website depend on this
self.current_use_proxy = task_use_proxy
if self.is_environment_used:
logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
self._revert_to_snapshot()
logger.info("Starting emulator...")
self._start_emulator()
logger.info("Emulator started.")
# Reset the usage flag after reverting
self.is_environment_used = False
else:
logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
if task_config is not None:
if task_config.get("proxy", False) and self.enable_proxy:
# If using proxy and proxy is enabled, set up the proxy configuration
self.setup_controller._proxy_setup(self.client_password)
self._set_task_info(task_config)
self.setup_controller.reset_cache_dir(self.cache_dir)
logger.info("Setting up environment...")
success = self.setup_controller.setup(self.config, task_config.get("proxy", False) and self.enable_proxy)
if success:
# Mark environment as used when setup is successfully executed
if self.config: # Only mark as used if there were actual setup operations
self.is_environment_used = True
break
else:
logger.error(
"Environment setup failed, retrying (%d/%d)...",
attempt + 1,
MAX_RETRIES,
)
time.sleep(5)
else:
break
logger.info("Environment setup complete.")
observation = self._get_obs()
return observation
def _get_obs(self):
# We provide screenshot, accessibility_tree (optional), terminal (optional), and instruction.
# can be customized and scaled
return {
"screenshot": self.controller.get_screenshot(),
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
"terminal": self.controller.get_terminal_output() if self.require_terminal else None,
"instruction": self.instruction
}
@property
def vm_platform(self):
return self.controller.get_vm_platform()
@property
def vm_screen_size(self):
return self.controller.get_vm_screen_size()
def _set_task_info(self, task_config: Dict[str, Any]):
"""Set task info (proxy logic is handled in reset method)"""
self.task_id: str = task_config["id"]
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"]
self.config = task_config["config"] if "config" in task_config else []
self._set_evaluator_info(task_config)
def _set_evaluator_info(self, task_config: Dict[str, Any]):
"""Set evaluator information from task config"""
if "evaluator" not in task_config:
return
# evaluator dict
# func -> metric function string, or list of metric function strings
# conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
# result -> result getter config, or list of result getter configs
# expected (optional) -> expected getter config, or list of expected getter configs
# options (optional) -> metric options, or list of metric options
# if func is a str list, then result, expected (if exists), options (if exists) should also be lists of the same length
# even if one of the metrics does not need expected or options field, it should be included in the list with None
self.evaluator = task_config["evaluator"]
self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] \
if isinstance(self.evaluator["func"], list) \
else getattr(metrics, self.evaluator["func"])
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics
if "result" in self.evaluator and len(self.evaluator["result"]) > 0:
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
self.evaluator["result"]] \
if isinstance(self.evaluator["result"], list) \
else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
else:
self.result_getter = [None] * len(self.metric) \
if isinstance(self.metric, list) \
else None
if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0:
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
self.evaluator["expected"]] \
if isinstance(self.evaluator["expected"], list) \
else getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
else:
self.expected_getter = [None] * len(self.metric) \
if isinstance(self.metric, list) \
else None
self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in
self.evaluator["options"]] \
if isinstance(self.evaluator.get("options", {}), list) \
else self.evaluator["options"] \
if "options" in self.evaluator \
else [{}] * len(self.metric) \
if isinstance(self.metric, list) \
else {}
assert (not isinstance(self.evaluator["func"], list)
or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(
self.metric_options)))
def step(self, action, pause=2):
self._step_no += 1
self.action_history.append(action)
# Mark environment as used when step is called
self.is_environment_used = True
reward = 0 # todo: Define reward calculation for each example
done = False # todo: Define episode termination condition for each example
info = {}
logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
# handle the special actions
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
if action == 'WAIT' or (type(action) == dict and action.get('action_type') == 'WAIT'):
time.sleep(pause)
elif action == 'FAIL' or (type(action) == dict and action.get('action_type') == 'FAIL'):
done = True
info = {"fail": True}
elif action == 'DONE' or (type(action) == dict and action.get('action_type') == 'DONE'):
done = True
info = {"done": True}
if self.action_space == "computer_13":
# the set of all possible actions defined in the action representation
self.controller.execute_action(action)
elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']):
self.controller.execute_action(action)
else:
# the set of all possible python commands inside `pyautogui`
if type(action) == str:
# Fix PyAutoGUI '<' character bug before execution
fixed_command = _fix_pyautogui_less_than_bug(action)
self.controller.execute_python_command(fixed_command)
elif type(action) == dict:
# Fix PyAutoGUI '<' character bug before execution
fixed_command = _fix_pyautogui_less_than_bug(action['command'])
self.controller.execute_python_command(fixed_command)
time.sleep(pause)
observation = self._get_obs()
return observation, reward, done, info
def evaluate(self):
"""
Evaluate whether the task is successfully completed.
"""
postconfig = self.evaluator.get("postconfig", [])
self.setup_controller.setup(postconfig, self.enable_proxy)
# Mark environment as used if there were postconfig setup operations
if postconfig:
self.is_environment_used = True
if self.evaluator['func'] == "infeasible":
if len(self.action_history) > 0:
last_action = self.action_history[-1]
if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
return 1
return 0
else:
if len(self.action_history) > 0:
last_action = self.action_history[-1]
if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
return 0
if type(self.metric) == list:
# Multiple metrics to evaluate whether the task is successfully completed
results = []
assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same"
if "expected" in self.evaluator:
assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same"
for idx, metric in enumerate(self.metric):
try:
config = self.evaluator["result"][idx]
result_state = self.result_getter[idx](self, config)
except FileNotFoundError:
logger.error("File not found!")
if self.metric_conj == 'and':
return 0
if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx])
metric: int = metric(result_state, expected_state, **self.metric_options[idx])
else:
metric: int = metric(result_state, **self.metric_options[idx])
if self.metric_conj == 'and' and float(metric) == 0.0:
return 0
elif self.metric_conj == 'or' and float(metric) == 1.0:
return 1
else:
results.append(metric)
return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
else:
# Single metric to evaluate whether the task is successfully completed
try:
result_state = self.result_getter(self, self.evaluator["result"])
except FileNotFoundError:
logger.error("File not found!")
return 0
if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
expected_state = self.expected_getter(self, self.evaluator["expected"])
metric: float = self.metric(result_state, expected_state, **self.metric_options)
else:
metric: float = self.metric(result_state, **self.metric_options)
return metric
def render(self, mode='rgb_array'):
if mode == 'rgb_array':
return self.controller.get_screenshot()
else:
raise ValueError('Unsupported render mode: {}'.format(mode))
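As a quick illustration of the `_fix_pyautogui_less_than_bug` helper defined near the top of this new file, the rewrites below show its effect on two command strings; the import path is an assumption (the new file is presumably the os_symphony variant of desktop_env.py) and may need adjusting:

```python
# Import path is assumed, not confirmed by the diff; adjust to your checkout.
from desktop_env.desktop_env_os_symphony import _fix_pyautogui_less_than_bug

# press('<') is rewritten to a shift+',' hotkey (pyautogui issue #198 workaround).
print(_fix_pyautogui_less_than_bug("pyautogui.press('<')"))
# -> pyautogui.hotkey("shift", ",")

# typewrite payloads containing '<' are split around the character.
print(_fix_pyautogui_less_than_bug("pyautogui.typewrite('a<b')"))
# -> pyautogui.typewrite('a'); pyautogui.hotkey("shift", ","); pyautogui.typewrite('b')
```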

View File

@@ -827,8 +827,8 @@ def get_active_tab_info(env, config: Dict[str, str]):
try:
logger.info(f"[ACTIVE_TAB_INFO] Navigating to URL: {active_tab_url}")
page.goto(active_tab_url, wait_until='networkidle', timeout=timeout_ms)
page.wait_for_load_state('networkidle', timeout=timeout_ms) # Wait for the 'load' event to complete
page.goto(active_tab_url, wait_until='load', timeout=timeout_ms)
page.wait_for_load_state('load', timeout=timeout_ms) # Wait for the 'load' event to complete
active_tab_info = {
'title': page.title(),

View File

@@ -2,6 +2,8 @@ import functools
import itertools
import logging
import os.path
import re
import unicodedata
# import operator
from numbers import Number
@@ -744,6 +746,18 @@ def compare_table(result: str, expected: str = None, **options) -> float:
# }}} function compare_table #
def _normalize_city_string(value: Any) -> str:
"""Lowercase, strip punctuation, and remove accents for tolerant matching."""
if value is None:
return ""
if not isinstance(value, str):
value = str(value)
normalized = unicodedata.normalize("NFKD", value)
normalized = "".join(ch for ch in normalized if not unicodedata.combining(ch))
normalized = re.sub(r"[^a-z0-9]+", " ", normalized.lower())
return normalized.strip()
def compare_conference_city_in_order(actual_city_list_path, expected_city):
expected_city_list = expected_city["expected"]
wb = openpyxl.load_workbook(actual_city_list_path)
@@ -752,38 +766,35 @@ def compare_conference_city_in_order(actual_city_list_path, expected_city):
for row in sheet["C2:C22"]:
for cell in row:
actual_city_list.append(cell.value)
# expected_city is the city that we want to compare with the actual city list
# must in order index
# debug
try:
for i in range(len(actual_city_list)):
if isinstance(expected_city_list[i], str):
if expected_city_list[i] not in actual_city_list[i]:
logger.debug(
f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
)
print(
f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
)
return 0.0
elif isinstance(expected_city_list[i], List):
if not any(
possible_str in actual_city_list[i]
for possible_str in expected_city_list[i]
):
logger.debug(
f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
)
print(
f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
)
return 0.0
for i, actual_city in enumerate(actual_city_list):
actual_normalized = _normalize_city_string(actual_city)
expected_entry = expected_city_list[i]
if isinstance(expected_entry, str):
expected_candidates = [expected_entry]
elif isinstance(expected_entry, List):
expected_candidates = expected_entry
else:
raise TypeError("Expected city should be a string or a list of strings")
except:
matched = False
for candidate in expected_candidates:
normalized_candidate = _normalize_city_string(candidate)
if normalized_candidate and normalized_candidate in actual_normalized:
matched = True
break
if not matched:
logger.debug(
f"Expected city {expected_entry}; Actual city {actual_city}"
)
print(f"Expected city {expected_entry}; Actual city {actual_city}")
return 0.0
except Exception as exc:
logger.error(f"Error comparing conference cities: {exc}")
return 0.0
return 1.0
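To make the new tolerance concrete, here is a self-contained sketch of the normalization (copied from the hunk above so it runs on its own) together with the kinds of inputs it now treats as equal:

```python
import re
import unicodedata
from typing import Any

def _normalize_city_string(value: Any) -> str:
    """Lowercase, strip punctuation, and remove accents (copy of the helper above)."""
    if value is None:
        return ""
    if not isinstance(value, str):
        value = str(value)
    normalized = unicodedata.normalize("NFKD", value)
    normalized = "".join(ch for ch in normalized if not unicodedata.combining(ch))
    normalized = re.sub(r"[^a-z0-9]+", " ", normalized.lower())
    return normalized.strip()

# Accents, case, and punctuation no longer break the substring check.
assert _normalize_city_string("São Paulo") == "sao paulo"
assert _normalize_city_string("ZÜRICH ") == "zurich"
assert _normalize_city_string("St. John's") == "st john s"
assert _normalize_city_string("sao paulo") in _normalize_city_string("São Paulo, Brazil")
```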

View File

@@ -10,7 +10,7 @@ from desktop_env.providers.aws.config import ENABLE_TTL, DEFAULT_TTL_MINUTES, AW
from desktop_env.providers.aws.scheduler_utils import schedule_instance_termination
INSTANCE_TYPE = "t3.medium"
INSTANCE_TYPE = "t3.xlarge"
# Load environment variables from .env file
dotenv.load_dotenv()
@@ -40,9 +40,9 @@ DEFAULT_REGION = "us-east-1"
# todo: public the AMI images
IMAGE_ID_MAP = {
"us-east-1": {
# (1920, 1080): "ami-0d23263edb96951d8"
(1920, 1080): "ami-0d23263edb96951d8",
# For CoACT-1, uncomment to use the following AMI
(1920, 1080): "ami-0b505e9d0d99ba88c"
# (1920, 1080): "ami-0b505e9d0d99ba88c"
},
"ap-east-1": {
(1920, 1080): "ami-06850864d18fad836"

View File

View File

@@ -52,7 +52,7 @@
"type": "rule",
"rules": {
"expected": [
"united.com/en/us/checked-bag-fee-calculator"
"united\\.com/en/us/checked-bag-fee-calculator(/.*)?"
]
}
}
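The tightened rule above is a regular expression (note the escaped dot) that now tolerates an optional trailing path segment. Assuming the evaluator applies these rule strings with Python's re module, the effect is roughly:

```python
import re

pattern = r"united\.com/en/us/checked-bag-fee-calculator(/.*)?"

# The bare calculator URL and URLs with a trailing path both match now.
assert re.search(pattern, "https://www.united.com/en/us/checked-bag-fee-calculator")
assert re.search(pattern, "https://www.united.com/en/us/checked-bag-fee-calculator/results")
assert not re.search(pattern, "https://www.united.com/en/us/baggage-restrictions")
```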

View File

@@ -88,7 +88,7 @@
],
"func": "check_image_mirror",
"expected": {
"type": "vm_file",
"type": "cloud_file",
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/72f83cdc-bf76-4531-9a1b-eb893a13f8aa/berry.jpeg",
"dest": "berry.png"
},

View File

@@ -32,8 +32,8 @@
"evaluator": {
"func": "check_file_exists_and_structure_sim",
"expected": {
"type": "vm_file",
"path": "/home/user/Desktop/The_Lost_River_Of_Dreams.jpg",
"type": "cloud_file",
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/77b8ab4d-994f-43ac-8930-8ca087d7c4b4/The_Lost_River_Of_Dreams.jpg",
"dest": "The_Lost_River_Of_Dreams.jpg"
},
"result": {

View File

@@ -60,7 +60,7 @@
"rules": {
"expected": [
"Zoom Chrome Extension",
"Speechify Text to Speech Voice Reader",
"Speechify — Voice AI Assistant",
"React Developer Tools",
"Momentum",
"Google Translate"

View File

@@ -40,8 +40,8 @@
},
"result": {
"type": "vm_file",
"path": "/home/user/Recruitment_and_retention_of_health_professionals_across_Europe.zip",
"dest": "Recruitment_and_retention_of_health_professionals_across_Europe.zip"
"path": "/home/user/essay_submission.zip",
"dest": "essay_submission.zip"
}
},
"proxy": false,

135
lib_results_logger.py Normal file
View File

@@ -0,0 +1,135 @@
#!/usr/bin/env python3
"""
Thread-safe results logging for OSWorld evaluations.
Appends task completion results to results.json in real-time.
"""
import json
import os
import time
import fcntl
from pathlib import Path
from typing import Dict, Any, Optional
def extract_domain_from_path(result_path: str) -> str:
"""
Extract domain/application from result directory path.
Expected structure: results/{action_space}/{observation_type}/{model}/{domain}/{task_id}/
"""
path_parts = Path(result_path).parts
if len(path_parts) >= 2:
return path_parts[-2] # Second to last part should be domain
return "unknown"
def append_task_result(
task_id: str,
domain: str,
score: float,
result_dir: str,
args: Any,
error_message: Optional[str] = None
) -> None:
"""
Thread-safely append a task result to results.json.
Args:
task_id: UUID of the task
domain: Application domain (chrome, vlc, etc.)
score: Task score (0.0 or 1.0)
result_dir: Full path to the task result directory
args: Command line arguments object
error_message: Error message if task failed
"""
# Create result entry
result_entry = {
"application": domain,
"task_id": task_id,
"status": "error" if error_message else "success",
"score": score,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}
if error_message:
result_entry["err_message"] = error_message
# Determine summary directory and results file path
# Extract base result directory from args
base_result_dir = Path(args.result_dir)
summary_dir = base_result_dir / "summary"
results_file = summary_dir / "results.json"
# Ensure summary directory exists
summary_dir.mkdir(parents=True, exist_ok=True)
# Thread-safe JSON append with file locking
try:
with open(results_file, 'a+') as f:
# Lock the file for exclusive access
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
try:
# Move to beginning to read existing content
f.seek(0)
content = f.read().strip()
# Parse existing JSON array or create new one
if content:
try:
existing_results = json.loads(content)
if not isinstance(existing_results, list):
existing_results = []
except json.JSONDecodeError:
existing_results = []
else:
existing_results = []
# Add new result
existing_results.append(result_entry)
# Write back the complete JSON array
f.seek(0)
f.truncate()
json.dump(existing_results, f, indent=2)
f.write('\n') # Add newline for readability
finally:
# Always unlock the file
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
print(f"📝 Logged result: {domain}/{task_id} -> {result_entry['status']} (score: {score})")
except Exception as e:
# Don't let logging errors break the main evaluation
print(f"⚠️ Failed to log result for {task_id}: {e}")
def log_task_completion(example: Dict, result: float, result_dir: str, args: Any) -> None:
"""
Convenience wrapper for logging successful task completion.
Args:
example: Task configuration dictionary
result: Task score
result_dir: Path to task result directory
args: Command line arguments
"""
task_id = example.get('id', 'unknown')
domain = extract_domain_from_path(result_dir)
append_task_result(task_id, domain, result, result_dir, args)
def log_task_error(example: Dict, error_msg: str, result_dir: str, args: Any) -> None:
"""
Convenience wrapper for logging task errors.
Args:
example: Task configuration dictionary
error_msg: Error message
result_dir: Path to task result directory
args: Command line arguments
"""
task_id = example.get('id', 'unknown')
domain = extract_domain_from_path(result_dir)
append_task_result(task_id, domain, 0.0, result_dir, args, error_msg)
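A minimal usage sketch for the new logger. The directory layout and args object below are illustrative (only args.result_dir is read here); the expected structure is results/{action_space}/{observation_type}/{model}/{domain}/{task_id}:

```python
# Hypothetical usage; the model name and paths are placeholders.
from types import SimpleNamespace
from lib_results_logger import extract_domain_from_path, log_task_completion

args = SimpleNamespace(result_dir="./results")
example = {"id": "873cafdd-a581-47f6-8b33-b9696ddb7b05"}
task_dir = "./results/pyautogui/screenshot/gpt-4o/chrome/873cafdd-a581-47f6-8b33-b9696ddb7b05"

# The domain is taken from the second-to-last path component ("chrome" here).
assert extract_domain_from_path(task_dir) == "chrome"

# Appends one entry to ./results/summary/results.json under an exclusive flock,
# so parallel workers can log results without corrupting the file.
log_task_completion(example, 1.0, task_dir, args)
```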

View File

@@ -4,18 +4,22 @@ import logging
import os
import time
from wrapt_timeout_decorator import *
from lib_results_logger import log_task_completion
logger = logging.getLogger("desktopenv.experiment")
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
runtime_logger = setup_logger(example, example_result_dir)
try:
agent.reset(runtime_logger)
except Exception as e:
agent.reset()
# Reset environment first to get fresh VM IP
env.reset(task_config=example)
# Reset agent with fresh VM IP (for snapshot reverts)
try:
agent.reset(runtime_logger, vm_ip=env.vm_ip)
except Exception as e:
agent.reset(vm_ip=env.vm_ip)
time.sleep(60) # Wait for the environment to be ready
obs = env._get_obs() # Get the initial observation
@@ -29,7 +33,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
)
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
logger.info("Step %d: %s", step_idx + 1, action)
obs, reward, done, info = env.step(action, args.sleep_after_execution)
@@ -55,11 +59,16 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
logger.info("The episode is done.")
break
step_idx += 1
time.sleep(20) # Wait for the environment to settle
result = env.evaluate()
logger.info("Result: %.2f", result)
scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
# Log task completion to results.json
log_task_completion(example, result, example_result_dir, args)
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
@@ -96,6 +105,67 @@ def run_single_example_human(env, example, max_steps, instruction, args, example
def run_single_example_agi(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
runtime_logger = setup_logger(example, example_result_dir)
agent.reset(runtime_logger)
env.reset(task_config=example)
time.sleep(60) # Wait for the environment to be ready
obs = env._get_obs() # Get the initial observation
done = False
step_idx = 0
env.controller.start_recording()
while not done and step_idx < max_steps:
response, actions = agent.predict(
instruction,
obs
)
done = not response.get('state_correct', False)
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action)
obs, reward, done, info, step_info = agent.step(action)
if not done:
if not response.get('state_correct', False):
done = True
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
# Save screenshot and trajectory information
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
"wb") as _f:
_f.write(obs['screenshot'])
# Remove pending checks if they exist which will cause issues with json serialization
if action.get('pending_checks', None):
del action['pending_checks']
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"step_num": step_idx + 1,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
step_idx += 1
result = env.evaluate()
logger.info("Result: %.2f", result)
scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
def run_single_example_openaicua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
runtime_logger = setup_logger(example, example_result_dir)
agent.reset(runtime_logger)
@@ -186,23 +256,25 @@ def run_single_example_opencua(agent, env, example, max_steps, instruction, args
"wb") as _f:
_f.write(obs['screenshot'])
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps({
"step_num": step_idx + 1,
"action_timestamp": action_timestamp,
"action": action,
"natural_language_action": info_dict.get("action"),
"action_timestamp": action_timestamp,
"response": response,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
}))
}, ensure_ascii=False))
f.write("\n")
if done:
logger.info("The episode is done.")
break
step_idx += 1
time.sleep(20) # Wait for the environment to settle
result = env.evaluate()
logger.info("Result: %.2f", result)
scores.append(result)
@@ -389,3 +461,185 @@ def run_single_example_uipath(agent, env, example, max_steps, instruction, args,
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
from mm_agents.os_symphony.utils.common_utils import draw_coordinates
from mm_agents.os_symphony.utils.process_context import set_current_result_dir
logger = logging.getLogger("desktopenv.experiment")
def run_single_example_os_symphony(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
set_current_result_dir(example_result_dir)
agent.reset(result_dir=example_result_dir)
env.reset(task_config=example)
time.sleep(30) # Wait for the environment to be ready
obs = env._get_obs() # Get the initial observation
done = False
step_idx = 0
# env.controller.start_recording()
start_time = time.time()
while not done and step_idx < max_steps:
response, actions = agent.predict(
instruction,
obs,
step_idx == max_steps - 1
)
for action in actions:
# Save screenshot and trajectory information
if "reflection" in response and response["reflection"].get("is_milestone"):
img_name = f"step_{step_idx + 1}_milestone.png"
else:
img_name = f"step_{step_idx + 1}.png"
with open(os.path.join(example_result_dir, img_name),
"wb") as _f:
_f.write(obs['screenshot'])
if "coordinates" in response and response["coordinates"]:
draw_coordinates(
image_bytes=obs['screenshot'],
coordinates=response["coordinates"],
save_path=os.path.join(example_result_dir, img_name[:-4] + "_draw.png")
)
logger.info("Step %d: %s", step_idx + 1, action)
obs, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info("Done: %s", done)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps({
"instruction": instruction,
"step_num": step_idx + 1,
"action": action,
"response": response,
"done": done,
"info": info,
"screenshot_file": img_name
}))
f.write("\n")
with open(os.path.join(example_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
json.dump({
"step_num": step_idx + 1,
"action": action,
"response": response,
"done": done,
"info": info,
"screenshot_file": img_name
}, f, indent=4, ensure_ascii=False)
if done:
logger.info("The episode is done.")
time.sleep(60)
break
step_idx += 1
end_time = time.time()
result = float(env.evaluate())
logger.info("Result: %.2f", result)
scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
with open(os.path.join(example_result_dir, "time.txt"), "w", encoding="utf-8") as f:
f.write(f"{end_time-start_time:.2f}\n")
def run_single_example_evocua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
"""
Unified run function for EvoCUAAgent (supporting both S1 and S2 modes).
"""
runtime_logger = setup_logger(example, example_result_dir)
# Reset Environment
env.reset(task_config=example)
# Reset Agent
# Handle agent reset signature differences if any
try:
agent.reset(runtime_logger, vm_ip=env.vm_ip)
except Exception:
try:
agent.reset(runtime_logger)
except Exception:
agent.reset()
time.sleep(60) # Wait for the environment to be ready
obs = env._get_obs() # Get the initial observation
done = False
step_idx = 0
env.controller.start_recording()
while not done and step_idx < max_steps:
# EvoCUAAgent.predict unified signature: returns (response, actions)
# It handles both modes internally.
predict_res = agent.predict(instruction, obs)
# Check return signature logic
if len(predict_res) == 3:
# Compatibility with S1 original signature if agent was updated to match
response, actions, info_dict = predict_res
else:
response, actions = predict_res
info_dict = {}
logger.info(f"Step {step_idx + 1} Actions: {actions}")
# Break if no actions (fail-safe)
if not actions or (len(actions) == 1 and (actions[0] == "" or "error" in actions[0].lower())):
# Allow "FAIL" or "DONE" to process through execution loop if agent outputs them as actions
if not (actions and actions[0] in ["FAIL", "DONE"]):
logger.warning("No valid actions returned. Breaking loop.")
break
for action in actions:
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
logger.info("Executing action: %s", action)
# Execute
obs, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
# Save screenshot
screenshot_file = f"step_{step_idx + 1}_{action_timestamp}.png"
with open(os.path.join(example_result_dir, screenshot_file), "wb") as _f:
_f.write(obs['screenshot'])
# Log Trajectory
log_entry = {
"step_num": step_idx + 1,
"action_timestamp": action_timestamp,
"action": action,
"response": response,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": screenshot_file
}
# Add natural language info if available (S1 style)
if info_dict:
log_entry["natural_language_action"] = info_dict.get("action")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(log_entry, ensure_ascii=False))
f.write("\n")
if done:
logger.info("The episode is done.")
break
step_idx += 1
time.sleep(20) # Wait for environment to settle
result = env.evaluate()
logger.info("Result: %.2f", result)
scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
log_task_completion(example, result, example_result_dir, args)
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
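One small detail in the hunks above worth calling out: the screenshot timestamps gain %f (microseconds), so several actions executed within the same second no longer produce colliding step_{n}_{timestamp}.png names. For illustration:

```python
import datetime

ts = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
# e.g. '20251013@103955123456' -- the microsecond suffix keeps filenames
# unique when an agent emits multiple actions inside one second.
print(f"step_3_{ts}.png")
```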

84
lib_run_single_os_symphony.py Executable file
View File

@@ -0,0 +1,84 @@
import json
import logging
import os
import time
from wrapt_timeout_decorator import *
from mm_agents.os_symphony.utils.common_utils import draw_coordinates
from mm_agents.os_symphony.utils.process_context import set_current_result_dir
logger = logging.getLogger("desktopenv.experiment")
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
set_current_result_dir(example_result_dir)
agent.reset(result_dir=example_result_dir)
env.reset(task_config=example)
time.sleep(30) # Wait for the environment to be ready
obs = env._get_obs() # Get the initial observation
done = False
step_idx = 0
# env.controller.start_recording()
start_time = time.time()
while not done and step_idx < max_steps:
response, actions = agent.predict(
instruction,
obs,
step_idx == max_steps - 1
)
for action in actions:
# Save screenshot and trajectory information
if "reflection" in response and response["reflection"].get("is_milestone"):
img_name = f"step_{step_idx + 1}_milestone.png"
else:
img_name = f"step_{step_idx + 1}.png"
with open(os.path.join(example_result_dir, img_name),
"wb") as _f:
_f.write(obs['screenshot'])
if "coordinates" in response and response["coordinates"]:
draw_coordinates(
image_bytes=obs['screenshot'],
coordinates=response["coordinates"],
save_path=os.path.join(example_result_dir, img_name[:-4] + "_draw.png")
)
logger.info("Step %d: %s", step_idx + 1, action)
obs, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info("Done: %s", done)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps({
"instruction": instruction,
"step_num": step_idx + 1,
"action": action,
"response": response,
"done": done,
"info": info,
"screenshot_file": img_name
}))
f.write("\n")
with open(os.path.join(example_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
json.dump({
"step_num": step_idx + 1,
"action": action,
"response": response,
"done": done,
"info": info,
"screenshot_file": img_name
}, f, indent=4, ensure_ascii=False)
if done:
logger.info("The episode is done.")
time.sleep(60)
break
step_idx += 1
end_time = time.time()
result = float(env.evaluate())
logger.info("Result: %.2f", result)
scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
with open(os.path.join(example_result_dir, "time.txt"), "w", encoding="utf-8") as f:
f.write(f"{end_time-start_time:.2f}\n")

View File

@@ -1134,10 +1134,12 @@ class PromptAgent:
return actions
def reset(self, _logger=None):
def reset(self, _logger=None, vm_ip=None, **kwargs):
global logger
logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")
self.vm_ip = vm_ip
self.thoughts = []
self.actions = []
self.observations = []
self.observations = []
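The widened reset() signature absorbs the new vm_ip argument (and any future keyword) without breaking older call sites. A rough sketch of the call shapes it now accepts, assuming PromptAgent lives in mm_agents/agent.py and is constructible with its defaults:

```python
import logging
from mm_agents.agent import PromptAgent  # module path assumed

agent = PromptAgent()  # real runs pass model/action_space/observation_type kwargs
runtime_logger = logging.getLogger("desktopenv.example")

agent.reset()                                      # legacy call, still valid
agent.reset(runtime_logger, vm_ip="172.31.0.10")   # new harness call from lib_run_single.py
agent.reset(runtime_logger, vm_ip=None, extra=1)   # unexpected kwargs are swallowed by **kwargs
```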

219
mm_agents/agi_agent.py Normal file
View File

@@ -0,0 +1,219 @@
import base64
import logging
import time
from typing import Dict, List, Tuple, Any, Optional
import httpx
logger = logging.getLogger("desktopenv.agent")
class Timer:
"""Context manager for timing code blocks."""
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, *args):
self.duration = time.time() - self.start
class AGIAgent:
"""Agent that communicates with your private AGI server for decision-making."""
def __init__(
self,
env,
server_url: str = "https://your-private-agi-endpoint", # Contact the authors for access to a private deployment endpoint.
platform: str = "ubuntu",
action_space: str = "pyautogui",
observation_type: str = "screenshot",
max_trajectory_length: int = 100,
client_password: str = "",
provider_name: str = "aws",
screen_width: int = 1920,
screen_height: int = 1080,
timeout: int = 1800,
):
"""Initialize the AGI client.
Args:
env: The desktop environment
server_url: URL of your private AGI server
"""
self.env = env
self.server_url = server_url.rstrip("/")
self.platform = platform
self.action_space = action_space
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.client_password = client_password
self.provider_name = provider_name
self.screen_width = screen_width
self.screen_height = screen_height
# Session management
self.session_id: Optional[str] = None
self.instruction: Optional[str] = None
# HTTP client
self.client = httpx.Client(timeout=timeout)
# Tracking
self.thoughts = []
self.actions = []
self.observations = []
logger.info(f"Initialized AGIAgent with server URL: {self.server_url}")
def reset(self, runtime_logger=None):
"""Reset the agent and create a new session on the server.
Args:
runtime_logger: Optional logger for runtime information
"""
global logger
logger = runtime_logger if runtime_logger is not None else logging.getLogger("desktopenv.agent")
# Clear local state
self.thoughts = []
self.actions = []
self.observations = []
self.session_id = None
logger.info("AGIAgent reset complete")
def _create_session(self, instruction: str) -> str:
"""Create a new session on the server.
Args:
instruction: The task instruction
Returns:
The session ID
Equivalent curl request:
curl -X POST {server_url}/sessions \
-H "Content-Type: application/json" \
-d '{"task_description": "{instruction}"}'
"""
try:
# print(f"Creating session with instruction: {instruction}")
# print(f"Server URL: {self.server_url}")
response = self.client.post(
f"{self.server_url}/sessions",
json={"task_description": instruction}
)
response.raise_for_status()
session_id = response.json()["session_id"]
logger.info(f"Created session: {session_id}")
return session_id
except Exception as e:
logger.error(f"Failed to create session: {e}")
raise
def predict(self, instruction: str, obs: Dict) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""Predict the next action based on the current observation.
Args:
instruction: The task instruction
obs: Observation dictionary containing 'screenshot' key with image bytes
Returns:
Tuple of (predict_info dict, list of action dicts)
"""
# Create session on first prediction
if self.session_id is None:
self.instruction = instruction
self.session_id = self._create_session(instruction)
# input("Session created, press Enter to continue")
# Encode screenshot to base64
screenshot_bytes = obs["screenshot"]
screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
# Call the server
with Timer() as model_timer:
try:
response = self.client.post(
f"{self.server_url}/sessions/{self.session_id}/step",
json={
"screenshot_base64_png": screenshot_b64,
"error": None # Could be populated from previous step errors
}
)
response.raise_for_status()
result = response.json()
parsed_action = result["parsed_response"]
logger.info(f"Server returned action: {parsed_action[:100]}...")
except Exception as e:
logger.error(f"Error calling server: {e}")
raise
# Format response as expected by lib_run_single
actions = [{
"action_space": "pyautogui",
"action": parsed_action,
"pending_checks": [],
"call_id": ""
}]
# Check if task is complete or failed
state_correct = parsed_action not in ["FAIL", "DONE"]
predict_info = {
"model_usage": {
"model_time": model_timer.duration,
"prompt_tokens": 0, # Server doesn't expose these
"completion_tokens": 0,
},
"messages": [], # Server manages conversation history
"response": parsed_action,
"state_correct": state_correct,
}
return predict_info, actions
def step(self, action: Dict[str, Any]) -> Tuple[Dict, float, bool, Dict, Dict]:
"""Execute an action in the environment.
Args:
action: Action dictionary with 'action' key containing PyAutoGUI command
Returns:
Tuple of (observation, reward, done, info, step_info)
"""
try:
if not action:
logger.warning("Empty action received, terminating episode")
# Get observation without executing action
obs = self.env._get_obs()
return obs, 0.0, True, {}, {"step_time": 0.0, "action": action}
action_str = action.get("action", "")
logger.info(f"Executing action: {action_str[:100]}...")
with Timer() as step_timer:
# Execute the action directly (it's already a PyAutoGUI command string)
obs, reward, terminated, info = self.env.step(action_str)
logger.debug(f"Action completed in {step_timer.duration:.2f}s")
if terminated:
logger.info("Environment signaled termination")
return obs, reward, terminated, info, {
"step_time": step_timer.duration,
"action": action
}
except Exception as e:
logger.exception(f"Environment step failed: {str(e)}")
raise
def close(self):
"""Close the HTTP client."""
self.client.close()
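Finally, a rough driver sketch for the new AGIAgent. The endpoint is the placeholder from the file, the task-config path follows OSWorld's usual evaluation_examples layout, and the DesktopEnv arguments are assumptions rather than a prescribed configuration:

```python
import json
from desktop_env.desktop_env import DesktopEnv
from mm_agents.agi_agent import AGIAgent

# Hypothetical task file path; any OSWorld example JSON works here.
with open("evaluation_examples/examples/chrome/<task_id>.json") as f:
    task_config = json.load(f)

env = DesktopEnv(provider_name="aws", action_space="pyautogui")   # provider choice is illustrative
agent = AGIAgent(env, server_url="https://your-private-agi-endpoint")

obs = env.reset(task_config=task_config)
agent.reset()

done, step_idx = False, 0
while not done and step_idx < 15:
    predict_info, actions = agent.predict(task_config["instruction"], obs)
    for action in actions:
        # AGIAgent.step forwards the pyautogui command string to env.step and
        # returns the 5-tuple (obs, reward, done, info, step_info).
        obs, reward, done, info, step_info = agent.step(action)
    step_idx += 1
```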

View File

@@ -17,7 +17,7 @@ from anthropic.types.beta import (
BetaMessageParam,
BetaTextBlockParam,
)
from .utils import COMPUTER_USE_BETA_FLAG, PROMPT_CACHING_BETA_FLAG,SYSTEM_PROMPT, SYSTEM_PROMPT_WINDOWS, APIProvider, PROVIDER_TO_DEFAULT_MODEL_NAME
from .utils import COMPUTER_USE_BETA_FLAG, PROMPT_CACHING_BETA_FLAG,SYSTEM_PROMPT, SYSTEM_PROMPT_WINDOWS, APIProvider, PROVIDER_TO_DEFAULT_MODEL_NAME, get_model_name
from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to_n_most_recent_images
import logging
@@ -30,14 +30,18 @@ API_RETRY_INTERVAL = 5
class AnthropicAgent:
def __init__(self,
platform: str = "Ubuntu",
model: str = "claude-3-5-sonnet-20241022",
provider: APIProvider = APIProvider.BEDROCK,
model: str = "claude-sonnet-4-5-20250929",
provider: APIProvider = APIProvider.ANTHROPIC,
max_tokens: int = 4096,
api_key: str = os.environ.get("ANTHROPIC_API_KEY", None),
system_prompt_suffix: str = "",
only_n_most_recent_images: Optional[int] = 10,
action_space: str = "claude_computer_use",
screen_size: tuple[int, int] = (1920, 1080),
no_thinking: bool = False,
use_isp: bool = False,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
*args, **kwargs
):
self.platform = platform
@@ -52,10 +56,24 @@ class AnthropicAgent:
self.only_n_most_recent_images = only_n_most_recent_images
self.messages: list[BetaMessageParam] = []
self.screen_size = screen_size
self.no_thinking = no_thinking
self.use_isp = use_isp
self.temperature = temperature
self.top_p = top_p
self.resize_factor = (
screen_size[0] / 1280, # Assuming 1280 is the base width
screen_size[1] / 720 # Assuming 720 is the base height
)
def _get_sampling_params(self):
"""Get sampling parameters (temperature and/or top_p) - let API validate exclusivity"""
params = {}
if self.temperature is not None:
params['temperature'] = self.temperature
if self.top_p is not None:
params['top_p'] = self.top_p
return params
def add_tool_result(self, tool_call_id: str, result: str, screenshot: bytes = None):
"""Add tool result to message history"""
@ -84,6 +102,21 @@ class AnthropicAgent:
"content": tool_result_content
})
def _extract_raw_response_string(self, response) -> str:
"""Extract and concatenate raw response content into a single string."""
raw_response_str = ""
if response.content:
for block in response.content:
if hasattr(block, 'text') and block.text:
raw_response_str += f"[TEXT] {block.text}\n"
elif hasattr(block, 'thinking') and block.thinking:
raw_response_str += f"[THINKING] {block.thinking}\n"
elif hasattr(block, 'name') and hasattr(block, 'input'):
raw_response_str += f"[TOOL_USE] {block.name}: {block.input}\n"
else:
raw_response_str += f"[OTHER] {str(block)}\n"
return raw_response_str.strip()
def parse_actions_from_tool_call(self, tool_call: Dict) -> str:
result = ""
function_args = (
@ -194,13 +227,23 @@ class AnthropicAgent:
result += (f"pyautogui.keyUp('{key}')\n")
expected_outcome = f"Key {key} pressed."
elif action == "type":
result += (
f"pyautogui.typewrite(\"\"\"{text}\"\"\", interval=0.01)\n"
)
for char in text:
if char == '\n':
result += "pyautogui.press('enter')\n"
elif char == "'":
result += 'pyautogui.press("\'")\n'
elif char == '\\':
result += "pyautogui.press('\\\\')\n"
elif char == '"':
result += "pyautogui.press('\"')\n"
else:
result += f"pyautogui.press('{char}')\n"
expected_outcome = f"Text {text} written."
# Handle scroll actions
elif action == "scroll":
if text is not None:
result += (f"pyautogui.keyDown('{text.lower()}')\n")
if coordinate is None:
if scroll_direction in ("up", "down"):
result += (
@ -221,6 +264,8 @@ class AnthropicAgent:
result += (
f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount}, {x}, {y})\n"
)
if text is not None:
result += (f"pyautogui.keyUp('{text.lower()}')\n")
expected_outcome = "Scroll action finished"
# Handle click actions
@ -285,7 +330,7 @@ class AnthropicAgent:
expected_outcome = "Call user"
elif action == "screenshot":
result += "pyautogui.sleep(0.1)\n"
expected_outcome = "Screenshot taken"
expected_outcome = "Screenshot taken"
else:
raise ValueError(f"Invalid action: {action}")
@ -303,6 +348,9 @@ class AnthropicAgent:
screenshot_bytes = obs["screenshot"]
screenshot_image = Image.open(io.BytesIO(screenshot_bytes))
# Store original unresized screenshot for zoom processing
obs["screenshot_original"] = screenshot_bytes
# Calculate new size based on resize factor
new_width, new_height = 1280, 720
@ -334,23 +382,45 @@ class AnthropicAgent:
]
})
if self.messages and "tool_use" in [content_block["type"] for content_block in self.messages[-1]["content"]]:
self.add_tool_result(
self.messages[-1]["content"][-1]["id"],
f"Success",
screenshot=obs.get("screenshot") if obs else None
)
# Add tool_result for ALL tool_use blocks in the last message
if self.messages:
last_message_content = self.messages[-1]["content"]
tool_use_blocks = [block for block in last_message_content if block.get("type") == "tool_use"]
for i, tool_block in enumerate(tool_use_blocks):
tool_input = tool_block.get("input", {})
action = tool_input.get("action")
is_last_tool = i == len(tool_use_blocks) - 1
include_screenshot = None
if obs:
if action == "screenshot":
# Screenshot action always gets regular screenshot
include_screenshot = obs.get("screenshot")
elif is_last_tool:
# Auto-screenshot: last tool gets regular screenshot (unless it's zoom, handled above)
include_screenshot = obs.get("screenshot")
self.add_tool_result(
tool_block["id"],
f"Success",
screenshot=include_screenshot
)
enable_prompt_caching = False
betas = ["computer-use-2025-01-24"]
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
betas = ["computer-use-2025-01-24"]
elif self.model_name == "claude-3-5-sonnet-20241022":
betas = [COMPUTER_USE_BETA_FLAG]
betas = [COMPUTER_USE_BETA_FLAG]
# Add interleaved thinking beta if ISP is requested
if self.use_isp:
betas.append("interleaved-thinking-2025-05-14")
logger.info(f"Added interleaved thinking beta. Betas: {betas}")
image_truncation_threshold = 10
if self.provider == APIProvider.ANTHROPIC:
client = Anthropic(api_key=self.api_key, max_retries=4)
client = Anthropic(api_key=self.api_key, max_retries=4).with_options(
default_headers={"anthropic-beta": COMPUTER_USE_BETA_FLAG}
)
enable_prompt_caching = True
elif self.provider == APIProvider.VERTEX:
client = AnthropicVertex()
@ -368,7 +438,7 @@ class AnthropicAgent:
if enable_prompt_caching:
betas.append(PROMPT_CACHING_BETA_FLAG)
_inject_prompt_caching(self.messages)
image_truncation_threshold = 50
image_truncation_threshold = 20
system["cache_control"] = {"type": "ephemeral"}
if self.only_n_most_recent_images:
@ -378,49 +448,65 @@ class AnthropicAgent:
min_removal_threshold=image_truncation_threshold,
)
try:
if self.model_name == "claude-3-5-sonnet-20241022":
tools = [
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
# {'type': 'bash_20241022', 'name': 'bash'},
# {'name': 'str_replace_editor', 'type': 'text_editor_20241022'}
] if self.platform == 'Ubuntu' else [
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
]
elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
tools = [
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
# {'type': 'bash_20250124', 'name': 'bash'},
# {'name': 'str_replace_editor', 'type': 'text_editor_20250124'}
] if self.platform == 'Ubuntu' else [
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
]
# Configure tool settings - use modern computer tool for all models
tool_config = {
'name': 'computer',
'type': 'computer_20250124',
'display_width_px': 1280,
'display_height_px': 720,
'display_number': 1
}
tools = [
tool_config,
] if self.platform == 'Ubuntu' else [
tool_config,
]
# Configure thinking mode based on user preferences
if self.no_thinking:
# Disable thinking mode - omit the thinking parameter
extra_body = {}
actual_max_tokens = self.max_tokens # Use default when no thinking
logger.info("Thinking mode: DISABLED")
else:
# Enable thinking mode (regular or interleaved)
# Use consistent 2048 budget for both regular and ISP thinking
budget_tokens = 2048
# For regular thinking: max_tokens > budget_tokens (API requirement)
# For ISP: budget_tokens can exceed max_tokens (represents total across all thinking blocks)
if self.max_tokens <= budget_tokens:
required_max_tokens = budget_tokens + 500 # Give some headroom
logger.warning(f"Regular thinking requires max_tokens > budget_tokens. Increasing max_tokens from {self.max_tokens} to {required_max_tokens}")
actual_max_tokens = required_max_tokens
else:
actual_max_tokens = self.max_tokens
extra_body = {
"thinking": {"type": "enabled", "budget_tokens": 1024}
"thinking": {"type": "enabled", "budget_tokens": budget_tokens}
}
if self.use_isp:
logger.info("Thinking mode: INTERLEAVED SCRATCHPAD (ISP)")
else:
logger.info("Thinking mode: REGULAR SCRATCHPAD")
try:
response = None
for attempt in range(API_RETRY_TIMES):
try:
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
response = client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
system=[system],
tools=tools,
betas=betas,
extra_body=extra_body
)
elif self.model_name == "claude-3-5-sonnet-20241022":
response = client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
system=[system],
tools=tools,
betas=betas,
)
response = client.beta.messages.create(
max_tokens=actual_max_tokens,
messages=self.messages,
model=get_model_name(self.provider, self.model_name),
system=[system],
tools=tools,
betas=betas,
extra_body=extra_body,
**self._get_sampling_params()
)
logger.info(f"Response: {response}")
break
except (APIError, APIStatusError, APIResponseValidationError) as e:
@ -450,26 +536,20 @@ class AnthropicAgent:
try:
logger.warning("Retrying with backup API key...")
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
response = backup_client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
system=[system],
tools=tools,
betas=betas,
extra_body=extra_body
)
elif self.model_name == "claude-3-5-sonnet-20241022":
response = backup_client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
system=[system],
tools=tools,
betas=betas,
)
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4).with_options(
default_headers={"anthropic-beta": COMPUTER_USE_BETA_FLAG}
)
response = backup_client.beta.messages.create(
max_tokens=actual_max_tokens,
messages=self.messages,
model=get_model_name(self.provider, self.model_name),
system=[system],
tools=tools,
betas=betas,
extra_body=extra_body,
**self._get_sampling_params()
)
logger.info("Successfully used backup API key")
except Exception as backup_e:
backup_error_msg = str(backup_e)
@ -497,9 +577,16 @@ class AnthropicAgent:
logger.exception(f"Error in Anthropic API: {str(e)}")
return None, None
if response is None:
logger.error("Response is None after API call - this should not happen")
return None, None
response_params = _response_to_params(response)
logger.info(f"Received response params: {response_params}")
# Convert raw response to concatenated string for trajectory logging
raw_response_str = self._extract_raw_response_string(response)
# Store response in message history
self.messages.append({
"role": "assistant",
@ -518,7 +605,8 @@ class AnthropicAgent:
"input": cast(dict[str, Any], content_block["input"]),
"id": content_block["id"],
"action_type": content_block.get("type"),
"command": self.parse_actions_from_tool_call(content_block)
"command": self.parse_actions_from_tool_call(content_block),
"raw_response": raw_response_str # Add raw response to each action
})
elif content_block["type"] == "text":
reasonings.append(content_block["text"])
@ -526,10 +614,23 @@ class AnthropicAgent:
reasonings = reasonings[0]
else:
reasonings = ""
# Check if the model indicated the task is infeasible
if raw_response_str and "[INFEASIBLE]" in raw_response_str:
logger.info("Detected [INFEASIBLE] pattern in response, triggering FAIL action")
# Override actions with FAIL
actions = [{
"action_type": "FAIL",
"raw_response": raw_response_str
}]
logger.info(f"Received actions: {actions}")
logger.info(f"Received reasonings: {reasonings}")
if len(actions) == 0:
actions = ["DONE"]
actions = [{
"action_type": "DONE",
"raw_response": raw_response_str
}]
return reasonings, actions
except Exception as e:
logger.warning(f"parse_actions_from_tool_call parsing failed (attempt {parse_retry+1}/3), will retry API request: {e}")
@ -539,25 +640,17 @@ class AnthropicAgent:
response = None
for attempt in range(API_RETRY_TIMES):
try:
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
response = client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
system=[system],
tools=tools,
betas=betas,
extra_body=extra_body
)
elif self.model_name == "claude-3-5-sonnet-20241022":
response = client.beta.messages.create(
max_tokens=self.max_tokens,
messages=self.messages,
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
system=[system],
tools=tools,
betas=betas,
)
response = client.beta.messages.create(
max_tokens=actual_max_tokens,
messages=self.messages,
model=get_model_name(self.provider, self.model_name),
system=[system],
tools=tools,
betas=betas,
extra_body=extra_body,
**self._get_sampling_params()
)
logger.info(f"Response: {response}")
break # Success, exit retry loop
except (APIError, APIStatusError, APIResponseValidationError) as e2:
@ -569,13 +662,20 @@ class AnthropicAgent:
raise
response_params = _response_to_params(response)
logger.info(f"Received response params: {response_params}")
# Update raw response string for retry case (will be used in next loop iteration)
raw_response_str = self._extract_raw_response_string(response)
self.messages.append({
"role": "assistant",
"content": response_params
})
if parse_retry == max_parse_retry - 1:
logger.error(f"parse_actions_from_tool_call parsing failed 3 times consecutively, terminating: {e}")
actions = ["FAIL"]
actions = [{
"action_type": "FAIL",
"raw_response": f"Failed to parse actions from tool call after {max_parse_retry} attempts: {e}"
}]
return reasonings, actions
def reset(self, _logger = None, *args, **kwargs):
"""

View File

@ -27,7 +27,7 @@ from datetime import datetime
from .tools import ToolResult
COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
@ -47,12 +47,25 @@ PROVIDER_TO_DEFAULT_MODEL_NAME: dict[(APIProvider, str), str] = {
(APIProvider.ANTHROPIC, "claude-4-opus-20250514"): "claude-4-opus-20250514",
(APIProvider.BEDROCK, "claude-4-opus-20250514"): "us.anthropic.claude-opus-4-20250514-v1:0",
(APIProvider.VERTEX, "claude-4-opus-20250514"): "claude-4-opus-v1@20250514",
# Add mapping for the alternative model name format
(APIProvider.ANTHROPIC, "claude-opus-4-20250514"): "claude-opus-4-20250514",
(APIProvider.ANTHROPIC, "claude-opus-4-1-20250805"): "claude-opus-4-1-20250805",
(APIProvider.ANTHROPIC, "claude-4-sonnet-20250514"): "claude-4-sonnet-20250514",
(APIProvider.ANTHROPIC, "claude-sonnet-4-20250514"): "claude-sonnet-4-20250514",
(APIProvider.BEDROCK, "claude-4-sonnet-20250514"): "us.anthropic.claude-sonnet-4-20250514-v1:0",
(APIProvider.VERTEX, "claude-4-sonnet-20250514"): "claude-sonnet-4-v1@20250514",
}
def get_model_name(provider: APIProvider, model_name: str) -> str:
"""
Get the actual model name to use for API calls.
Simply returns the model name as-is for direct API usage.
"""
return model_name
# This system prompt is optimized for the Docker environment in this repository and
# specific tool combinations enabled.
# We encourage modifying this system prompt to ensure the model has context for the
@ -67,8 +80,15 @@ SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools.
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
* TASK FEASIBILITY: You can declare a task infeasible at any point during execution - whether at the beginning after taking a screenshot, or later after attempting some actions and discovering barriers. Carefully evaluate whether the task is feasible given the current system state, available applications, and task requirements. If you determine that a task cannot be completed due to:
- Missing required applications or dependencies that cannot be installed
- Insufficient permissions or system limitations
- Contradictory or impossible requirements
- Any other fundamental barriers that make completion impossible
Then you MUST output exactly "[INFEASIBLE]" (including the square brackets) anywhere in your response to trigger the fail action. The system will automatically detect this pattern and terminate the task appropriately.
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
* Home directory of this Ubuntu system is '/home/user'.
* If you need a password for sudo, the password of the computer is 'osworld-public-evaluation'.
</SYSTEM_CAPABILITY>
<IMPORTANT>
@ -82,6 +102,7 @@ SYSTEM_PROMPT_WINDOWS = f"""<SYSTEM_CAPABILITY>
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
* Home directory of this Windows system is 'C:\\Users\\user'.
* When you want to open some applications on Windows, please use Double Click on it instead of clicking once.
* If you need a password for sudo, the password of the computer is 'osworld-public-evaluation'.
</SYSTEM_CAPABILITY>"""
@ -154,21 +175,30 @@ def _inject_prompt_caching(
one cache breakpoint is left for tools/system prompt, to be shared across sessions
"""
breakpoints_remaining = 3
breakpoints_remaining = 2 # Use full budget for recent messages
messages_processed = 0
for message in reversed(messages):
if message["role"] == "user" and isinstance(
content := message["content"], list
):
if breakpoints_remaining:
breakpoints_remaining -= 1
messages_processed += 1
# Check if this message would fit within the remaining budget
if breakpoints_remaining >= len(content):
# We have enough budget, spend it and add cache_control
breakpoints_remaining -= len(content)
# Use type ignore to bypass TypedDict check until SDK types are updated
content[-1]["cache_control"] = BetaCacheControlEphemeralParam( # type: ignore
{"type": "ephemeral"}
)
else:
content[-1].pop("cache_control", None)
# we'll only ever have one extra turn per loop
break
# Check if this is the first message (contains image + text with task description)
is_first_message = messages_processed == len([msg for msg in messages if msg["role"] == "user"])
if not is_first_message:
# Not enough budget, remove any existing cache_control from this message
content[-1].pop("cache_control", None)
# Continue to clean up older messages that might have cache_control from previous turns
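# --- Editorial sketch (not part of this diff): the budgeted breakpoint logic above,
# restated as a standalone helper. It walks user messages newest-first, spends the
# breakpoint budget on whole messages, and strips stale cache_control markers once
# the budget is exhausted; the first-message special case is omitted for brevity.
def _sketch_assign_cache_breakpoints(messages, budget=2):
    remaining = budget
    for message in reversed(messages):
        content = message.get("content")
        if message.get("role") != "user" or not isinstance(content, list):
            continue
        if remaining >= len(content):
            # Enough budget left: spend it and mark the newest block as cacheable
            remaining -= len(content)
            content[-1]["cache_control"] = {"type": "ephemeral"}
        else:
            # Budget exhausted: drop any marker left over from a previous turn
            content[-1].pop("cache_control", None)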
def _maybe_filter_to_n_most_recent_images(
@ -220,6 +250,105 @@ def _maybe_filter_to_n_most_recent_images(
tool_result["content"] = new_content
def validate_model_support(model_name: str, api_key: str = None, temperature: float = None, top_p: float = None, no_thinking: bool = False, use_isp: bool = False) -> bool:
"""
Validate model support with the same API call pattern as the main agent.
Args:
model_name: The model name to validate
api_key: Optional API key, defaults to ANTHROPIC_API_KEY env var
temperature: Optional temperature parameter for testing
top_p: Optional top_p parameter for testing
no_thinking: Disable thinking mode (matches AnthropicAgent)
use_isp: Use interleaved scratchpad mode (matches AnthropicAgent)
Returns:
True if model is supported and API call succeeds, False otherwise
"""
print(f"🔍 Validating model support: {model_name}")
try:
from anthropic import Anthropic
import os
import time
# Same client setup as the main agent, with a manual retry loop below for clearer feedback
client = Anthropic(
api_key=api_key or os.environ.get("ANTHROPIC_API_KEY"),
max_retries=4
).with_options(default_headers={"anthropic-beta": COMPUTER_USE_BETA_FLAG})
# Same message format as main agent - always use structured format with cache_control
messages = [{"role": "user", "content": [{"type": "text", "text": "Respond with 'OK'", "cache_control": {"type": "ephemeral"}}]}]
# Same betas configuration as main agent
betas = [COMPUTER_USE_BETA_FLAG]
if use_isp:
betas.append("interleaved-thinking-2025-05-14")
system = [{"type": "text", "text": "You are Claude. Respond with 'OK'."}]
# Same tools configuration as main agent - use modern computer tool for all models
tools = [{"name": "computer", "type": "computer_20250124",
"display_width_px": 1280, "display_height_px": 720, "display_number": 1}]
# Same thinking configuration as main agent
max_tokens = 50 # Base validation max_tokens
if no_thinking:
extra_body = {}
actual_max_tokens = max_tokens
else:
budget_tokens = 2048
# Same logic as main agent: if max_tokens <= budget_tokens, increase it
if max_tokens <= budget_tokens:
actual_max_tokens = budget_tokens + 500
else:
actual_max_tokens = max_tokens
extra_body = {
"thinking": {"type": "enabled", "budget_tokens": budget_tokens}
}
# Sampling parameters (same logic as main agent)
sampling_params = {}
if temperature is not None:
sampling_params['temperature'] = temperature
if top_p is not None:
sampling_params['top_p'] = top_p
# Retry logic with 5 attempts, 5 second delays
for attempt in range(5):
try:
# Same API call pattern as main agent
client.beta.messages.create(
max_tokens=actual_max_tokens,
messages=messages,
model=get_model_name(APIProvider.ANTHROPIC, model_name),
system=system,
tools=tools,
betas=betas,
extra_body=extra_body,
**sampling_params
)
print(f"✅ Model {model_name} validated successfully")
return True
except Exception as e:
if attempt < 4:  # Not the final attempt: log the failure and retry
print(f"🔄 Validation attempt {attempt + 1}/5 failed: {e}")
print(f"⏳ Retrying in 5 seconds...")
time.sleep(5)
else:
print(f"❌ All validation attempts failed. Final error: {e}")
return False
except ValueError:
return False
except Exception as e:
print(f"❌ API validation setup failed: {e}")
return False
def _response_to_params(
response: BetaMessage,
) -> list[BetaContentBlockParam]:

View File

@ -0,0 +1,161 @@
COMPUTER_USE_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
- My computer's password is 'password', feel free to use it when you need sudo rights.
## User Instruction
{instruction}
"""
COMPUTER_USE_PROMPT_WITH_CALL_USER = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
- My computer's password is 'password', feel free to use it when you need sudo rights.
## User Instruction
{instruction}
"""
UITARS_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
"""
UITARS_CALL_USR_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
"""
UITARS_NORMAL_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
"""
UITARS_USR_PROMPT_NOTHOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## User Instruction
{instruction}
"""
UITARS_USR_PROMPT_THOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
{action_space}
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
FAILURE_INDICATORS = [
# Direct inability expressions
"无法", "不能", "不可以", "做不到", "实现不了", "完成不了","没法",
# Regret/apology expressions
"遗憾", "抱歉", "很抱歉", "非常抱歉", "对不起",
# Not supported/available
"不直接支持", "不支持", "不提供", "不具备", "没有权限", "权限不足", "不在这里面","不符合",#"不存在",
# Cannot access/handle
"无权访问", "访问不了", "处理不了", "操作不了", "执行不了", "没找到", "空空如也",
# Not possible/feasible
"不可能", "无法实现", "实现不了", "办不到", "做不了","找不到","存在技术限制","没有找到","没有内置",
# System limitations
"超出范围", "不在我的能力范围", "能力有限", "功能限制","没有成功","没成功","硬件的问题",
# Refusal indicators
"拒绝", "不允许", "禁止", "不合适", "不恰当",
# Trying Restart
"从头开始", "藏在", "浪费时间","一个更合理的思路","正确的方向","没有意义",#, "重新","重启",
]
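FAILURE_INDICATORS is a list of (mostly Chinese) substrings that signal the model has given up; the Dart parsing code scans the text of a final finished(...) response for any of them and downgrades the terminal action from DONE to FAIL. A minimal sketch of that check (the helper name below is illustrative, not part of the repository):

from mm_agents.dart_gui.prompts import FAILURE_INDICATORS

def looks_like_failure(response_text: str) -> bool:
    # True if the model's final response contains any known give-up phrase
    return any(indicator in response_text for indicator in FAILURE_INDICATORS)

# Example: a "finished" response that actually admits defeat maps to FAIL
assert looks_like_failure("很抱歉,当前应用不支持该操作")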

View File

@ -0,0 +1,202 @@
import asyncio
from typing import List, Optional, Union, Dict, Any
import json
import os
import hashlib
from pathlib import Path
from omegaconf import DictConfig
from dataclasses import dataclass, asdict
import copy
import logging
import random
from prompts import COMPUTER_USE_PROMPT, COMPUTER_USE_PROMPT_WITH_CALL_USER
from log_config import setup_logging
# Set up the unified logging system
setup_logging()
logger = logging.getLogger(__name__)
class TaskLoader:
def __init__(self, task_cfg: DictConfig, storage_root):
self.task_file = Path(task_cfg.task_file)
#self.task_root = Path(task_cfg.task_root)
self.osworld_root = Path(task_cfg.osworld_root)
self._latest_sha: Optional[str] = None
self.storage_root = storage_root
self.resume = task_cfg.resume
def poll_for_tasks(self) -> List[Dict]:
"""find new tasks json file
return list of TaskInfo dict if there is new json
else return []
"""
self._maybe_refresh_dataset()
tasks_list = [task.to_dict() for task in self._tasks]
random.shuffle(tasks_list)
return tasks_list
def _maybe_refresh_dataset_bak(self):
# check new json
latest_json = self._find_latest_json()
if latest_json is None:
return False # no json file
sha = self._calc_sha1(latest_json)
if sha == self._latest_sha:
return False # no change
with open(latest_json) as f:
data = json.load(f)
raw_tasks = [
{"task_type": task_type, "task_id": task_id}
for task_type, task_ids in data.items()
for task_id in task_ids
]
self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
self._latest_sha = sha
logger.info(f"当前任务文件: {str(latest_json)}")
logger.info(f"任务总数: {len(raw_tasks)}")
return True
def _maybe_refresh_dataset(self):
latest_json = self.task_file
print("Current tasks file: ", str(latest_json))
with open(latest_json) as f:
data = json.load(f)
raw_tasks = [
{"task_type": task_type, "task_id": task_id}
for task_type, task_ids in data.items()
for task_id in task_ids
]
if self.resume:
# Filter out tasks that are already finished or whose type does not match
filtered_tasks = []
storage_root = Path(self.storage_root)
for raw in raw_tasks:
task_id = str(raw["task_id"])
task_type_expected = raw["task_type"]
# Find all subdirectories whose names start with task_id (multiple versions allowed)
candidate_dirs = [
d for d in storage_root.iterdir()
if d.is_dir() and d.name.startswith(task_id)
]
# Assume the task is unfinished by default
task_finished = False
for d in candidate_dirs:
cfg_path = d / "task_config.json"
if not cfg_path.exists():
print("找不到config文件")
continue
try:
with cfg_path.open("r", encoding="utf-8") as cf:
cfg = json.load(cf)
except Exception:
print("配置损坏,忽略此目录")
continue
# 3.1 Different task_type => not the same task, skip this directory
if cfg.get("raw", {}).get("task_type") != task_type_expected:
continue
# 3.2 Same task_type, check reward.txt
if (d / "reward.txt").exists():
task_finished = True
break  # Completion record found, no need to check other directories
if not task_finished:
filtered_tasks.append(raw)
self._tasks = [build_task(raw, self.osworld_root) for raw in filtered_tasks]
print(f"Total number of tasks: {len(raw_tasks)}, Remained:{len(filtered_tasks)}")
else:
self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
print(f"Total number of tasks: {len(raw_tasks)}")
return True
def _find_latest_json(self) -> Optional[Path]:
files = list(self.task_root.glob("*.json"))
return max(files, key=lambda p: p.stat().st_mtime) if files else None
@staticmethod
def _calc_sha1(fp: Path, chunk_size=2<<20) -> str:
h = hashlib.sha1()
with fp.open("rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
h.update(chunk)
return h.hexdigest()
@dataclass
class TaskInfo:
messages: List
instruction: str
task_config: Dict
def to_dict(self):
return asdict(self)
def build_task(raw: Dict, osworld_root: Path, use_call_user: bool = False) -> TaskInfo:
task_type = raw["task_type"]
task_id = raw["task_id"]
task_path = os.path.join(osworld_root, task_type, task_id + ".json")
with open(task_path) as f:
task_data = json.load(f)
task_data["raw"] = {
"task_type": task_type,
"task_id": task_id
}
instruction = task_data["instruction"]
if "human-ground-truth" in task_data and "single-action" in task_data["human-ground-truth"]:
plan = task_data["human-ground-truth"]["single-action"]
plan_text = "\n".join(plan)
instruction = instruction.strip() + "\nHere is an instruction to help you complete the task: \n" + plan_text
system_prompt = COMPUTER_USE_PROMPT if not use_call_user else COMPUTER_USE_PROMPT_WITH_CALL_USER
messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": [
{
"type": "text",
"text": system_prompt.format(
instruction=instruction,
language="English"
)}
]
}
]
return TaskInfo(
messages = messages,
instruction = instruction,
task_config = task_data
)
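For reference, poll_for_tasks expects task_cfg.task_file to point at a JSON object mapping task types to lists of task IDs, and build_task then loads <osworld_root>/<task_type>/<task_id>.json for each entry. A minimal sketch of preparing such a file (the task type and IDs below are hypothetical placeholders):

import json

# Hypothetical task list: one chrome task and one gimp task (placeholder IDs)
tasks_by_type = {
    "chrome": ["00000000-0000-0000-0000-000000000001"],
    "gimp": ["00000000-0000-0000-0000-000000000002"],
}
with open("tasks.json", "w", encoding="utf-8") as f:
    json.dump(tasks_by_type, f, indent=2)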

mm_agents/dart_gui/utils.py Normal file
View File

@ -0,0 +1,511 @@
import ast
import base64
import logging
import math
import re
import xml.etree.ElementTree as ET
from io import BytesIO
from typing import Dict, List
import numpy as np
import openai
from openai import OpenAI
from PIL import Image
from requests.exceptions import SSLError
from mm_agents.dart_gui.prompts import FAILURE_INDICATORS
# Set up logging
logger = logging.getLogger(__name__)
FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
pure_text_settings = ["a11y_tree"]
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
# More namespaces defined in OSWorld, please check desktop_env/server/main.py
# Parse a single action string into a function name and keyword arguments
def parse_action(action_str):
try:
# Parse the string into an AST node
node = ast.parse(action_str, mode='eval')
# 确保节点是一个表达式
if not isinstance(node, ast.Expression):
raise ValueError("Not an expression")
# Get the body of the expression
call = node.body
# Make sure the body is a function call
if not isinstance(call, ast.Call):
raise ValueError("Not a function call")
# Get the function name
if isinstance(call.func, ast.Name):
func_name = call.func.id
elif isinstance(call.func, ast.Attribute):
func_name = call.func.attr
else:
func_name = None
# Collect the keyword arguments
kwargs = {}
for kw in call.keywords:
key = kw.arg
# Handle different value types; constants are assumed here
if isinstance(kw.value, ast.Constant):
value = kw.value.value
elif isinstance(kw.value, ast.Str):  # Compatibility with older Python versions
value = kw.value.s
else:
value = None
kwargs[key] = value
return {
'function': func_name,
'args': kwargs
}
except Exception as e:
logger.error(f"Failed to parse action '{action_str}': {e}")
return None
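# Editorial usage sketch (not part of the original file): parse_action converts a
# single model action string into a {'function': ..., 'args': ...} dict.
_example_action = parse_action("click(start_box='(100,200)')")
# _example_action == {'function': 'click', 'args': {'start_box': '(100,200)'}}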
def escape_single_quotes(text):
# Match unescaped single quotes (but not \\')
pattern = r"(?<!\\)'"
return re.sub(pattern, r"\\'", text)
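# Editorial usage sketch (not part of the original file): unescaped single quotes
# are backslash-escaped so the action string stays parseable.
_example_escaped = escape_single_quotes("it's done")
# _example_escaped == "it\\'s done"  (a literal backslash before the quote)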
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def linear_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
if width * height > max_pixels:
"""
If the image exceeds (or falls below) the pixel limits, compute a resize factor so that the total pixel count shrinks to at most max_pixels (or grows to at least min_pixels). The factor is taken as a square root, which keeps the aspect ratio unchanged, so the original relative coordinates can be reused without any conversion.
"""
resize_factor = math.sqrt(max_pixels / (width * height))
width, height = int(width * resize_factor), int(height * resize_factor)
if width * height < min_pixels:
resize_factor = math.sqrt(min_pixels / (width * height))
width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
return height, width
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
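# Editorial usage sketch (not part of the original file): with the default
# constants above, a 1080x1920 screenshot snaps to the nearest multiples of 28
# while staying within [MIN_PIXELS, MAX_PIXELS].
_example_h, _example_w = smart_resize(1080, 1920)
# _example_h == 1092, _example_w == 1932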
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
text = text.strip()
if model_type == "qwen25vl":
smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
# Use regular expressions to match the Action string
if text.startswith("Thought:"):
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
thought_hint = "Thought: "
elif text.startswith("Reflection:"):
thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
thought_hint = "Reflection: "
elif text.startswith("Action_Summary:"):
thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
thought_hint = "Action_Summary: "
else:
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
thought_hint = "Thought: "
reflection, thought = None, None
thought_match = re.search(thought_pattern, text, re.DOTALL)
if thought_match:
if len(thought_match.groups()) == 1:
thought = thought_match.group(1).strip()
elif len(thought_match.groups()) == 2:
thought = thought_match.group(2).strip()
reflection = thought_match.group(1).strip()
assert "Action:" in text
action_str = text.split("Action:")[-1]
tmp_all_action = action_str.split("\n\n")
all_action = []
for action_str in tmp_all_action:
if "type(content" in action_str:
# Use a regex to extract the string inside content and escape single quotes
def escape_quotes(match):
content = match.group(1)  # Get the value of content
return content
# Apply the regex substitution
pattern = r"type\(content='(.*?)'\)" # 匹配 type(content='...')
content = re.sub(pattern, escape_quotes, action_str)
# Process the string
action_str = escape_single_quotes(content)
action_str = "type(content='" + action_str + "')"
if "finished(content" in action_str:
# Use a regex to extract the string inside content and escape single quotes
def escape_quotes(match):
content = match.group(1)  # Get the value of content
return content
# Apply the regex substitution
pattern = r"finished\(content='(.*?)'\)" # 匹配 type(content='...')
content = re.sub(pattern, escape_quotes, action_str)
# Process the string
action_str = escape_single_quotes(content)
action_str = "finished(content='" + action_str + "')"
all_action.append(action_str)
parsed_actions = [parse_action(action.replace("\n","\\n").lstrip()) for action in all_action]
actions = []
for action_instance, raw_str in zip(parsed_actions, all_action):
if action_instance is None:
logger.error(f"Action can't parse: {raw_str}")
# raise ValueError(f"Action can't parse: {raw_str}")
continue
action_type = action_instance["function"]
params = action_instance["args"]
# import pdb; pdb.set_trace()
action_inputs = {}
for param_name, param in params.items():
if param == "": continue
param = param.lstrip()  # Strip extra whitespace
# Handle start_box / end_box parameters in the format '<bbox>x1 y1 x2 y2</bbox>'
action_inputs[param_name.strip()] = param
if "start_box" in param_name or "end_box" in param_name:
ori_box = param
# Remove parentheses and split the string by commas
numbers = ori_box.replace("(", "").replace(")", "").split(",")
# Convert to float and scale by 1000
# Qwen2.5-VL outputs absolute coordinates; Qwen2-VL outputs relative coordinates
if model_type == "qwen25vl":
float_numbers = []
for num_idx, num in enumerate(numbers):
num = float(num)
if (num_idx + 1) % 2 == 0:
float_numbers.append(float(num/smart_resize_height))
else:
float_numbers.append(float(num/smart_resize_width))
else:
float_numbers = [float(num) / factor for num in numbers]
if len(float_numbers) == 2:
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
action_inputs[param_name.strip()] = str(float_numbers)
# import pdb; pdb.set_trace()
actions.append(
{
"reflection": reflection,
"thought": thought,
"action_type": action_type,
"action_inputs": action_inputs,
"text": text
})
return actions
def parsing_response_to_pyautogui_code(responses, image_height: int, image_width:int, input_swap:bool=True) -> str:
'''
Parse the model output into an OSWorld action and generate a pyautogui code string.
Args:
responses: a dict (or list of dicts) containing the model output, structured like
{
"action_type": "hotkey",
"action_inputs": {
"hotkey": "v ctrl",
"start_box": None,
"end_box": None
}
}
Returns:
The generated pyautogui code string.
'''
pyautogui_code = "import pyautogui\nimport time\n"
if isinstance(responses, dict):
responses = [responses]
for response_id, response in enumerate(responses):
if "observation" in response:
observation = response["observation"]
else:
observation = ""
if "thought" in response:
thought = response["thought"]
else:
thought = ""
if response_id == 0:
pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
else:
pyautogui_code += "\ntime.sleep(1)\n"
action_dict = response
response_text = action_dict.get("text", "")
action_type = action_dict.get("action_type")
action_inputs = action_dict.get("action_inputs", {})
if action_type == "hotkey":
# Parsing hotkey action
if "key" in action_inputs:
hotkey = action_inputs.get("key", "")
else:
hotkey = action_inputs.get("hotkey", "")
if hotkey == "arrowleft":
hotkey = "left"
elif hotkey == "arrowright":
hotkey = "right"
elif hotkey == "arrowup":
hotkey = "up"
elif hotkey == "arrowdown":
hotkey = "down"
if hotkey:
# Handle other hotkeys
keys = hotkey.split() # Split the keys by space
convert_keys = []
for key in keys:
if key == "space":
key = ' '
convert_keys.append(key)
pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"
elif action_type == "press":
# Parsing press action
if "key" in action_inputs:
key_to_press = action_inputs.get("key", "")
else:
key_to_press = action_inputs.get("press", "")
if hotkey == "arrowleft":
hotkey = "left"
elif hotkey == "arrowright":
hotkey = "right"
elif hotkey == "arrowup":
hotkey = "up"
elif hotkey == "arrowdown":
hotkey = "down"
elif hotkey == "space":
hotkey = " "
if key_to_press:
# Simulate pressing a single key
pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"
elif action_type == "keyup":
key_to_up = action_inputs.get("key", "")
pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"
elif action_type == "keydown":
key_to_down = action_inputs.get("key", "")
pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"
elif action_type == "type":
# Parsing typing action using clipboard
content = action_inputs.get("content", "")
content = escape_single_quotes(content)
stripped_content = content
if content.endswith("\n") or content.endswith("\\n"):
stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
if content:
if input_swap:
pyautogui_code += "\nimport pyperclip"
pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')"
pyautogui_code += "\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += "\npyautogui.press('enter')"
else:
pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
pyautogui_code += "\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += "\npyautogui.press('enter')"
elif action_type in ["drag", "select"]:
# Parsing drag or select action based on start and end_boxes
start_box = action_inputs.get("start_box")
end_box = action_inputs.get("end_box")
if start_box and end_box:
x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2]
sx = round(float((x1 + x2) / 2) * image_width, 3)
sy = round(float((y1 + y2) / 2) * image_height, 3)
x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2]
ex = round(float((x1 + x2) / 2) * image_width, 3)
ey = round(float((y1 + y2) / 2) * image_height, 3)
pyautogui_code += (
f"\npyautogui.moveTo({sx}, {sy})\n"
f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
)
elif action_type == "scroll":
# Parsing scroll action
start_box = action_inputs.get("start_box")
if start_box:
x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2]
x = round(float((x1 + x2) / 2) * image_width, 3)
y = round(float((y1 + y2) / 2) * image_height, 3)
# # Click the target area first, then scroll
# pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
else:
x = None
y = None
direction = action_inputs.get("direction", "")
if x is None:
if "up" in direction.lower():
pyautogui_code += "\npyautogui.scroll(5)"
elif "down" in direction.lower():
pyautogui_code += "\npyautogui.scroll(-5)"
else:
if "up" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
elif "down" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})"
elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
# Parsing mouse click actions
start_box = action_inputs.get("start_box")
start_box = str(start_box)
if start_box:
start_box = eval(start_box)
if start_box is None:
logger.warning(f"[Warning] start_box is None and wired condition:\n{action_inputs}")
if len(start_box) == 4:
x1, y1, x2, y2 = start_box # Assuming box is in [x1, y1, x2, y2]
elif len(start_box) == 2:
x1, y1 = start_box
x2 = x1
y2 = y1
x = round(float((x1 + x2) / 2) * image_width, 3)
y = round(float((y1 + y2) / 2) * image_height, 3)
if action_type == "left_single" or action_type == "click":
pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
elif action_type == "left_double":
pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
elif action_type == "right_single":
pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
elif action_type == "hover":
pyautogui_code += f"\npyautogui.moveTo({x}, {y})"
elif action_type in ["finished"]:
pyautogui_code = "DONE"
print(f"FINISHED:response_text: {response_text}")
print(f"FINISHED:response: {str(response)}")
for failure_indicator in FAILURE_INDICATORS:
if failure_indicator in response_text:
pyautogui_code = "FAIL"
break
elif action_type in ["wait"]:
pyautogui_code = "WAIT"
elif action_type in ["call_user"]:
pyautogui_code = "FAIL"
else:
pyautogui_code += f"\n# Unrecognized action type: {action_type}"
return pyautogui_code
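# Editorial usage sketch (not part of the original file): a parsed click action with
# a relative start_box becomes an executable pyautogui string.
_example_code = parsing_response_to_pyautogui_code(
    {
        "action_type": "click",
        "action_inputs": {"start_box": "[0.5, 0.5, 0.5, 0.5]"},
        "thought": "Click the center of the screen",
    },
    image_height=1080,
    image_width=1920,
)
# _example_code ends with: pyautogui.click(960.0, 540.0, button='left')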
def add_box_token(input_string):
# Step 1: Split the string into individual actions
if "Action: " in input_string and "start_box=" in input_string:
suffix = input_string.split("Action: ")[0] + "Action: "
actions = input_string.split("Action: ")[1:]
processed_actions = []
for action in actions:
action = action.strip()
# Step 2: Extract coordinates (start_box or end_box) using regex
coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)
updated_action = action # Start with the original action
for coord_type, x, y in coordinates:
# Convert x and y to integers
updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'")
processed_actions.append(updated_action)
# Step 5: Reconstruct the final string
final_string = suffix + "\n\n".join(processed_actions)
else:
final_string = input_string
# print(f"Input string: {input_string}")
# print(f"Final string: {final_string}")
return [{"type": "text", "text": final_string}]
def pil_to_base64(image):
"""Convert PIL Image or bytes to base64 string"""
if isinstance(image, bytes):
# If it's already bytes, just encode to base64
return base64.b64encode(image).decode("utf-8")
else:
# If it's a PIL Image, convert it
buffer = BytesIO()
image.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")

mm_agents/dart_gui_agent.py Normal file
View File

@ -0,0 +1,686 @@
"""
Dart Agent - Custom agent for GUI automation using Dart models
Based on UITARSAgent structure but using Dart-specific utilities and prompts
"""
import ast
import base64
import logging
import math
import os
import re
import time
from io import BytesIO
from typing import Dict, List, Any
from PIL import Image
from openai import OpenAI
import backoff
import openai
import requests
from requests.exceptions import SSLError
from google.api_core.exceptions import (
BadRequest,
InternalServerError,
InvalidArgument,
ResourceExhausted,
)
# Import Dart-specific utilities and prompts
from mm_agents.dart_gui.utils import (
pil_to_base64,
parse_action_to_structure_output,
parsing_response_to_pyautogui_code,
parse_action,
escape_single_quotes,
round_by_factor,
ceil_by_factor,
floor_by_factor,
linear_resize,
smart_resize,
add_box_token,
IMAGE_FACTOR,
MIN_PIXELS,
MAX_PIXELS,
MAX_RATIO,
FINISH_WORD,
WAIT_WORD,
ENV_FAIL_WORD,
CALL_USER
)
from mm_agents.dart_gui.prompts import (
COMPUTER_USE_PROMPT,
COMPUTER_USE_PROMPT_WITH_CALL_USER,
UITARS_ACTION_SPACE,
UITARS_CALL_USR_ACTION_SPACE,
UITARS_USR_PROMPT_THOUGHT,
UITARS_USR_PROMPT_NOTHOUGHT
)
logger = logging.getLogger("desktopenv.agent")
class DartAgent:
def __init__(
self,
model: str,
runtime_conf: Dict,
platform="ubuntu",
max_tokens=1000,
top_p=0.9,
top_k=1.0,
temperature=0.0,
action_space="pyautogui",
observation_type="screenshot",
max_trajectory_length=50,
model_type="qwen25vl",
**kwargs
):
self.model = model
self.platform = platform
self.action_space = action_space
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.model_type = model_type
self.runtime_conf = runtime_conf
# Extract runtime configuration parameters
self.max_tokens = self.runtime_conf.get("max_tokens", max_tokens)
self.top_p = self.runtime_conf.get("top_p", top_p)
self.top_k = self.runtime_conf.get("top_k", top_k)
self.temperature = self.runtime_conf.get("temperature", temperature)
self.infer_mode = self.runtime_conf.get("infer_mode", "dart_mode")
self.prompt_style = self.runtime_conf.get("prompt_style", "dart_style")
self.input_swap = self.runtime_conf.get("input_swap", False)
self.language = self.runtime_conf.get("language", "English")
self.max_pixels = self.runtime_conf.get("max_pixels", MAX_PIXELS)
self.min_pixels = self.runtime_conf.get("min_pixels", MIN_PIXELS)
self.history_n = self.runtime_conf.get("history_n", 5)
# Dart specific configurations
self.max_images = self.runtime_conf.get("max_images", 5)
self.max_texts = self.runtime_conf.get("max_texts", 35)
# Initialize OpenAI client - use Dart API if provided
dart_api_key = self.runtime_conf.get("dart_api_key", "")
dart_base_url = self.runtime_conf.get("dart_base_url", "")
if dart_base_url:
# Check whether this is a direct generation endpoint (URL contains /generate)
if '/generate' in dart_base_url:
# Use the provided URL directly without appending /v1
logger.info(f"Using direct generation endpoint: {dart_base_url}")
self.dart_direct_url = dart_base_url
self.vlm = None  # Do not use the OpenAI client
else:
# Standard OpenAI-compatible endpoint; make sure it ends with /v1
if not dart_base_url.endswith('/v1'):
dart_base_url = dart_base_url.rstrip('/') + '/v1'
self.vlm = OpenAI(
base_url=dart_base_url,
api_key=dart_api_key,
)
self.dart_direct_url = None
else:
# Fallback to environment variables
base_url = os.environ.get('DART_API_URL', os.environ.get('DOUBAO_API_URL'))
if base_url:
if '/generate' in base_url:
# Direct generation endpoint
self.dart_direct_url = base_url
self.vlm = None
else:
if not base_url.endswith('/v1'):
base_url = base_url.rstrip('/') + '/v1'
self.vlm = OpenAI(
base_url=base_url,
api_key=os.environ.get('DART_API_KEY', os.environ.get('DOUBAO_API_KEY')),
)
self.dart_direct_url = None
else:
self.vlm = None
self.dart_direct_url = None
# Initialize trajectory storage - similar to trajectory_runner.py
self.thoughts = []
self.actions = []
self.observations = []
self.history_images = []
self.history_responses = []
# Message handling similar to trajectory_runner.py
self.base_messages = [] # for model client (with base64 images)
self.base_messages_for_save = [] # for storage (with file paths)
self.prompt_dialogue = [] # for model client
self.save_dialogue = [] # for storage
self.save_dialogue_full = []  # for full storage (keeps all image paths)
self.image_refs = [] # record image position
# All image paths storage - to keep track of all images even when trimmed
self.all_image_paths = []
# Current screenshot file path for proper saving
self.current_screenshot_path = None
# Configure prompt and action space based on mode
if self.infer_mode == "dart_mode":
self.prompt_action_space = UITARS_ACTION_SPACE
self.prompt_template = COMPUTER_USE_PROMPT
else:
# For qwen2vl_user mode
self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
if self.prompt_style == "qwen2vl_user":
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
elif self.prompt_style == "qwen2vl_no_thought":
self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
else:
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
self.action_parse_res_factor = 1000
logger.info(f"Initialized DartAgent with model: {self.model}, mode: {self.infer_mode}")
def reset(self, runtime_logger=None):
"""Reset the agent state"""
self.thoughts = []
self.actions = []
self.observations = []
self.history_images = []
self.history_responses = []
# Reset message handling
self.base_messages = []
self.base_messages_for_save = []
self.prompt_dialogue = []
self.save_dialogue = []
self.save_dialogue_full = []
self.image_refs = []
self.all_image_paths = []
self.current_screenshot_path = None
logger.info("DartAgent reset")
def set_base_messages(self, instruction: str):
"""Initialize base messages similar to task_loader.py"""
system_prompt = COMPUTER_USE_PROMPT
self.base_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": [
{
"type": "text",
"text": system_prompt.format(
instruction=instruction,
language=self.language
)
}
]
}
]
# Copy for save version
from copy import deepcopy
self.base_messages_for_save = deepcopy(self.base_messages)
def set_current_screenshot_path(self, screenshot_path: str):
"""Set the current screenshot file path for proper saving"""
self.current_screenshot_path = screenshot_path
def predict(
self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
) -> tuple:
"""
Predict the next action(s) based on the current observation.
Returns: (response_text, actions_list)
"""
# Initialize base messages if not set
if not self.base_messages:
self.set_base_messages(instruction)
# Store current observation
self._add_observation(obs)
# For first step, set the first frame
if len(self.observations) == 1:
self._set_first_frame(obs["screenshot"], self.current_screenshot_path)
else:
# For subsequent steps, add the new image to dialogue
# This represents the result of the previous action
self._add_image(obs["screenshot"], self.current_screenshot_path)
# Build prompt messages (base_messages + prompt_dialogue)
messages = self._build_messages()
# Call model to get response
prediction = self._call_model(messages)
if prediction is None:
return "client error", ["DONE"]
# Store response and parse actions
self._add_text(prediction)
# Parse response to actions
try:
image_size = self._get_current_image_size()
actions = self._parse_and_convert_actions(prediction, image_size)
# Check for terminal actions
terminal_action = self._check_terminal_actions(actions)
if terminal_action:
self.actions.append(actions)
return prediction, [terminal_action]
except Exception as e:
logger.error(f"Parsing action error: {prediction}, error: {e}")
return f"Parsing action error: {prediction}, error: {e}", ["DONE"]
self.actions.append(actions)
# Check max steps
if len(self.history_responses) >= self.max_trajectory_length:
actions = ["FAIL"]
return prediction, actions
@backoff.on_exception(
backoff.constant,
(
# General exceptions
SSLError,
# OpenAI exceptions
openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
# Google exceptions
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
),
interval=30,
max_tries=10,
)
def predict_with_backoff(self, instruction: str, obs: Dict, last_action_after_obs: Dict = None):
"""Predict with backoff for rate limiting and temporary errors"""
return self.predict(instruction, obs, last_action_after_obs)
def get_trajectory(self) -> List[Dict]:
"""Get the current trajectory for saving"""
trajectory = []
for i in range(len(self.observations)):
trajectory.append({
"observation": self.observations[i],
"thought": self.thoughts[i] if i < len(self.thoughts) else "",
"action": self.actions[i] if i < len(self.actions) else []
})
return trajectory
def get_full_messages(self) -> List[Dict]:
"""Get the complete conversation messages for saving (including base messages and dialogue)"""
# Combine base_messages_for_save with save_dialogue_full to get complete conversation
full_messages = []
# Add base messages (system prompt and initial user message)
full_messages.extend(self.base_messages_for_save)
# Add dialogue messages (user images + assistant responses) with all images
full_messages.extend(self.save_dialogue_full)
return full_messages
def get_all_image_paths(self) -> List[str]:
"""Get all image paths that have been used throughout the conversation"""
return self.all_image_paths.copy()
# ========== Private Methods ==========
def _validate_trajectory(self):
"""Validate trajectory consistency"""
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
self.thoughts
), "The number of observations and actions should be the same."
def _add_observation(self, obs: Dict):
"""Process observation and add to history"""
# Store observation
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = obs["screenshot"]
try:
# Handle accessibility tree if needed
linearized_accessibility_tree = None
if self.observation_type == "screenshot_a11y_tree" and "accessibility_tree" in obs:
# For now, we'll skip accessibility tree processing in Dart mode
linearized_accessibility_tree = None
except:
linearized_accessibility_tree = None
if self.observation_type == "screenshot_a11y_tree":
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree,
})
else:
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": None
})
else:
raise ValueError("Invalid observation_type type: " + self.observation_type)
def _build_messages(self) -> List[Dict]:
"""Build messages for model API call - similar to trajectory_runner._build_messages"""
return self.base_messages + self.prompt_dialogue
def _call_model(self, messages: List[Dict]) -> str:
"""Call model with retry logic"""
try_times = 3
while try_times > 0:
try:
# If a direct generation endpoint is configured, use it
if hasattr(self, 'dart_direct_url') and self.dart_direct_url:
prediction = self._call_direct_generate_endpoint(messages)
else:
# Otherwise use the standard OpenAI client
response = self.vlm.chat.completions.create(
model=self.model,
messages=messages,
frequency_penalty=1,
max_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p
)
prediction = response.choices[0].message.content
logger.info(f"Model response: {prediction}")
return prediction
except Exception as e:
logger.error(f"Error when fetching response from client: {e}")
try_times -= 1
if try_times <= 0:
logger.error("Reach max retry times to fetch response from client")
return None
return None
def _call_direct_generate_endpoint(self, messages: List[Dict]) -> str:
"""直接调用生成端点"""
try:
# Build the request payload
payload = {
"messages": messages,
"model": self.model,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"frequency_penalty": 1
}
# Add the API key to the request headers
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.runtime_conf.get('dart_api_key', '')}"
}
# Retry up to 3 times, with a 60-second timeout per attempt
max_retries = 3
response = None
for attempt in range(max_retries):
try:
logger.info(f"尝试第 {attempt + 1} 次请求...")
response = requests.post(
self.dart_direct_url,
json=payload,
headers=headers,
timeout=60
)
response.raise_for_status()
break  # success: exit the retry loop
except Exception as e:
logger.warning(f"Request attempt {attempt + 1} failed: {e}")
if attempt == max_retries - 1:  # last attempt also failed
logger.error(f"All {max_retries} retry attempts failed")
raise e
else:
logger.info("Retrying after a short wait...")
import time
time.sleep(2)  # wait 2 seconds before retrying
# Parse the response
result = response.json()
# Try several possible response formats
if 'choices' in result and len(result['choices']) > 0:
# OpenAI-compatible format
return result['choices'][0]['message']['content']
elif 'response' in result:
# Plain 'response' field
return result['response']
elif 'text' in result:
# 'text' field
return result['text']
elif 'content' in result:
# 'content' field
return result['content']
else:
# No known field found; return the whole response as a string
logger.warning(f"Unknown response format: {result}")
return str(result)
except Exception as e:
logger.error(f"直接端点调用失败: {e}")
raise e
def _add_text(self, assistant_txt: str):
"""Add text response to history - similar to trajectory_runner.py"""
self.history_responses.append(assistant_txt)
self.thoughts.append(assistant_txt)
# Add to dialogue similar to trajectory_runner._add_text
msg = {
"role": "assistant",
"content": add_box_token(assistant_txt)
}
self.prompt_dialogue.append(msg)
self.save_dialogue.append(msg)
self.save_dialogue_full.append(msg)
self._trim()
def _set_first_frame(self, obs_img: bytes, frame_path: str = None):
"""Set first frame in base_messages - similar to trajectory_runner._set_first_frame"""
self.base_messages[1]["content"].append(
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64," + pil_to_base64(obs_img)}
}
)
# Use actual frame path if provided, otherwise use current_screenshot_path or placeholder
if frame_path:
first_frame_path = frame_path
elif self.current_screenshot_path:
first_frame_path = self.current_screenshot_path
else:
first_frame_path = "first_frame.png"
# Store in all_image_paths
self.all_image_paths.append(first_frame_path)
self.base_messages_for_save[1]["content"].append(
{
"type": "image_url",
"image_url": first_frame_path
}
)
self.image_refs.append(
{"source": "base", "msg_idx": 1,
"content_idx": len(self.base_messages[1]["content"]) - 1}
)
def _add_image(self, img_bytes: bytes, frame_path: str = None):
"""Add image to dialogue - similar to trajectory_runner._add_image"""
self.prompt_dialogue.append({
"role": "user",
"content": [{
"type": "image_url",
"image_url": {"url": "data:image/png;base64," + pil_to_base64(img_bytes)}
}]
})
# Use actual frame path if provided, otherwise use current_screenshot_path
if frame_path:
image_url = frame_path
elif self.current_screenshot_path:
image_url = self.current_screenshot_path
else:
# Fallback to a placeholder - this should rarely happen in practice
image_url = f"frame_{len(self.save_dialogue)}.png"
# Store in all_image_paths for complete record
self.all_image_paths.append(image_url)
# Add to save_dialogue (trimmed version)
self.save_dialogue.append({
"role": "user",
"content": [{
"type": "image_url",
"image_url": image_url
}]
})
# Add to save_dialogue_full (complete version - never trimmed)
self.save_dialogue_full.append({
"role": "user",
"content": [{
"type": "image_url",
"image_url": image_url
}]
})
self.image_refs.append(
{"source": "dialogue", "msg_idx": len(self.prompt_dialogue) - 1,
"content_idx": None}
)
self._trim()
def _trim(self):
"""Ensure image num ≤ max_images and assistant text num ≤ max_texts - similar to trajectory_runner._trim"""
img_cnt = len(self.image_refs)
txt_cnt = sum(m["role"] == "assistant" for m in self.prompt_dialogue)
while img_cnt > self.max_images or txt_cnt > self.max_texts:
# Too many images: drop the earliest one
if img_cnt > self.max_images:
ref = self.image_refs.pop(0)
if ref["source"] == "base":
self.base_messages[ref["msg_idx"]]["content"].pop(ref["content_idx"])
else:  # image from the dialogue
self._remove_dialogue_msg(ref["msg_idx"])
img_cnt -= 1
continue
# Too many assistant texts: drop the earliest one
if txt_cnt > self.max_texts:
for i, m in enumerate(self.prompt_dialogue):
if m["role"] == "assistant":
self._remove_dialogue_msg(i)
txt_cnt -= 1
break
def _remove_dialogue_msg(self, idx: int):
"""Remove dialogue message and update refs - similar to trajectory_runner._remove_dialogue_msg"""
self.prompt_dialogue.pop(idx)
self.save_dialogue.pop(idx)
# Note: save_dialogue_full is never trimmed, so we don't remove from it
# Update image_refs
self.image_refs = [
r if not (r["source"] == "dialogue" and r["msg_idx"] == idx)
else None  # drop references to images in the removed message
for r in self.image_refs
]
self.image_refs = [
(
{**r, "msg_idx": r["msg_idx"] - 1}
if r and r["source"] == "dialogue" and r["msg_idx"] > idx # idx后的图片索引均-1
else r
)
for r in self.image_refs
if r  # drop the None entries
]
def _get_current_image_size(self) -> tuple:
"""Get current image size for coordinate conversion"""
if len(self.observations) > 0:
try:
current_image_bytes = self.observations[-1]["screenshot"]
if isinstance(current_image_bytes, bytes):
current_image = Image.open(BytesIO(current_image_bytes))
return (current_image.height, current_image.width)
except Exception as e:
logger.warning(f"Error getting image size: {e}")
# Fallback to default screen size
return (1080, 1920)
def _parse_and_convert_actions(self, prediction: str, image_size: tuple) -> List[str]:
"""Parse response and convert to pyautogui actions - similar to trajectory_runner._parse"""
image_height, image_width = image_size
# Parse the response to structured actions
parsed_responses = parse_action_to_structure_output(
prediction,
factor=self.action_parse_res_factor,
origin_resized_height=image_height,
origin_resized_width=image_width,
model_type=self.model_type,
max_pixels=self.max_pixels,
min_pixels=self.min_pixels
)
# Convert parsed responses to pyautogui actions
actions = []
for parsed_response in parsed_responses:
try:
pyautogui_code = parsing_response_to_pyautogui_code(
parsed_response,
image_height=image_height,
image_width=image_width,
input_swap=self.input_swap
)
actions.append(pyautogui_code)
except Exception as e:
logger.error(f"Error generating pyautogui code: {e}")
actions.append("FAIL")
return actions
def _check_terminal_actions(self, actions: List[str]) -> str:
"""Check if any action is terminal and return appropriate code"""
for action in actions:
if isinstance(action, dict) and "action_type" in action:
action_type = action["action_type"]
if action_type == FINISH_WORD:
return "DONE"
elif action_type == WAIT_WORD:
return "WAIT"
elif action_type == ENV_FAIL_WORD:
return "FAIL"
elif action_type == CALL_USER:
return "FAIL"
return None
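# --- Usage sketch (illustrative only, not part of the agent) ---
# A minimal driving loop, assuming a DartAgent instance `agent` constructed with the
# configuration above and a hypothetical helper `capture_screenshot()` that returns
# the current screen as PNG bytes:
#
#   agent.reset()
#   for step in range(15):
#       screenshot = capture_screenshot()
#       agent.set_current_screenshot_path(f"step_{step}.png")
#       response, actions = agent.predict(instruction, {"screenshot": screenshot})
#       if actions and actions[0] in ("DONE", "FAIL", "WAIT"):
#           break
#       # execute the returned pyautogui code in the environment here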


@ -0,0 +1,653 @@
import os
import re
import json
import logging
import backoff
import openai
from typing import Dict, List, Tuple, Optional
from io import BytesIO
from PIL import Image
from mm_agents.evocua.utils import (
process_image,
encode_image,
rewrite_pyautogui_text_inputs,
project_coordinate_to_absolute_scale,
log_messages
)
from mm_agents.evocua.prompts import (
S1_SYSTEM_PROMPT,
S1_INSTRUTION_TEMPLATE,
S1_STEP_TEMPLATE,
S1_ACTION_HISTORY_TEMPLATE,
S2_ACTION_DESCRIPTION,
S2_DESCRIPTION_PROMPT_TEMPLATE,
S2_SYSTEM_PROMPT,
build_s2_tools_def
)
logger = logging.getLogger("desktopenv.evocua")
class EvoCUAAgent:
"""
EvoCUA - A Native GUI agent model for desktop automation.
"""
def __init__(
self,
model: str = "EvoCUA-S2",
max_tokens: int = 32768,
top_p: float = 0.9,
temperature: float = 0.0,
action_space: str = "pyautogui",
observation_type: str = "screenshot",
max_steps: int = 50,
prompt_style: str = "S2", # "S1" or "S2"
max_history_turns: int = 4,
screen_size: Tuple[int, int] = (1920, 1080),
coordinate_type: str = "relative",
password: str = "osworld-public-evaluation",
resize_factor: int = 32,
**kwargs
):
self.model = model
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.observation_type = observation_type
self.max_steps = max_steps
self.prompt_style = prompt_style
assert self.prompt_style in ["S1", "S2"], f"Invalid prompt_style: {self.prompt_style}"
self.max_history_turns = max_history_turns
self.screen_size = screen_size
self.coordinate_type = coordinate_type
self.password = password
self.resize_factor = resize_factor
# Action space assertion
assert self.action_space == "pyautogui", f"Invalid action space: {self.action_space}"
assert self.observation_type == "screenshot", f"Invalid observation type: {self.observation_type}"
# State
self.thoughts = []
self.actions = []
self.observations = []
self.responses = []
self.screenshots = [] # Stores encoded string
self.cots = [] # For S1 style history
def reset(self, _logger=None, vm_ip=None):
global logger
if _logger:
logger = _logger
self.thoughts = []
self.actions = []
self.observations = []
self.responses = []
self.screenshots = []
self.cots = []
def predict(self, instruction: str, obs: Dict) -> List:
"""
Main prediction loop.
"""
logger.info(f"========================== {self.model} ===================================")
logger.info(f"Instruction: \n{instruction}")
screenshot_bytes = obs["screenshot"]
try:
original_img = Image.open(BytesIO(screenshot_bytes))
original_width, original_height = original_img.size
except Exception as e:
logger.warning(f"Failed to read screenshot size, falling back to screen_size: {e}")
original_width, original_height = self.screen_size
if self.prompt_style == "S1":
raw_b64 = encode_image(screenshot_bytes)
self.screenshots.append(raw_b64)
return self._predict_s1(instruction, obs, raw_b64)
else:
processed_b64, p_width, p_height = process_image(screenshot_bytes, factor=self.resize_factor)
self.screenshots.append(processed_b64)
return self._predict_s2(
instruction,
obs,
processed_b64,
p_width,
p_height,
original_width,
original_height,
)
def _predict_s2(self, instruction, obs, processed_b64, p_width, p_height, original_width, original_height):
current_step = len(self.actions)
current_history_n = self.max_history_turns
response = None
if self.coordinate_type == "absolute":
resolution_info = f"* The screen's resolution is {p_width}x{p_height}."
else:
resolution_info = "* The screen's resolution is 1000x1000."
description_prompt = S2_DESCRIPTION_PROMPT_TEMPLATE.format(resolution_info=resolution_info)
tools_def = build_s2_tools_def(description_prompt)
system_prompt = S2_SYSTEM_PROMPT.format(tools_xml=json.dumps(tools_def))
# Retry loop for context length
while True:
messages = self._build_s2_messages(
instruction,
processed_b64,
current_step,
current_history_n,
system_prompt
)
try:
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
})
break
except Exception as e:
# Handle Context Too Large
if self._should_giveup_on_context_error(e) and current_history_n > 0:
current_history_n -= 1
logger.warning(f"Context too large, retrying with history_n={current_history_n}")
else:
logger.error(f"Error in predict: {e}")
break
self.responses.append(response)
low_level_instruction, pyautogui_code = self._parse_response_s2(
response, p_width, p_height, original_width, original_height
)
# Enforce the step limit: force termination with FAIL once max_steps is reached
current_step = len(self.actions) + 1
first_action = pyautogui_code[0] if pyautogui_code else ""
if current_step >= self.max_steps and str(first_action).upper() not in ("DONE", "FAIL"):
logger.warning(f"Reached maximum steps {self.max_steps}. Forcing termination with FAIL.")
low_level_instruction = "Fail the task because reaching the maximum step limit."
pyautogui_code = ["FAIL"]
logger.info(f"Low level instruction: {low_level_instruction}")
logger.info(f"Pyautogui code: {pyautogui_code}")
self.actions.append(low_level_instruction)
return response, pyautogui_code
def _build_s2_messages(self, instruction, current_img, step, history_n, system_prompt):
messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
previous_actions = []
history_start_idx = max(0, step - history_n)
for i in range(history_start_idx):
if i < len(self.actions):
previous_actions.append(f"Step {i+1}: {self.actions[i]}")
previous_actions_str = "\n".join(previous_actions) if previous_actions else "None"
# Add History
history_len = min(history_n, len(self.responses))
if history_len > 0:
hist_responses = self.responses[-history_len:]
hist_imgs = self.screenshots[-history_len-1:-1]
for i in range(history_len):
if i < len(hist_imgs):
screenshot_b64 = hist_imgs[i]
if i == 0:
# First history item: Inject Instruction + Previous Actions Context
img_url = f"data:image/png;base64,{screenshot_b64}"
instruction_prompt = f"""
Please generate the next move according to the UI screenshot, instruction and previous actions.
Instruction: {instruction}
Previous actions:
{previous_actions_str}"""
messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": img_url}},
{"type": "text", "text": instruction_prompt}
]
})
else:
img_url = f"data:image/png;base64,{screenshot_b64}"
messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": img_url}},
]
})
messages.append({
"role": "assistant",
"content": [{"type": "text", "text": hist_responses[i]}]
})
# Current Turn
# We re-use previous_actions_str logic for the case where history_len == 0
if history_len == 0:
# First turn logic: Include Instruction + Previous Actions
instruction_prompt = f"""
Please generate the next move according to the UI screenshot, instruction and previous actions.
Instruction: {instruction}
Previous actions:
{previous_actions_str}"""
messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{current_img}"}},
{"type": "text", "text": instruction_prompt}
]
})
else:
# Subsequent turns logic (context already in first history message): Image Only
messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{current_img}"}}
]
})
return messages
def _parse_response_s2(
self,
response: str,
processed_width: int = None,
processed_height: int = None,
original_width: Optional[int] = None,
original_height: Optional[int] = None,
) -> Tuple[str, List[str]]:
"""
Parse LLM response and convert it to low level action and pyautogui code.
"""
# Prefer the real screenshot resolution (passed from predict), fallback to configured screen_size.
if not (original_width and original_height):
original_width, original_height = self.screen_size
low_level_instruction = ""
pyautogui_code: List[str] = []
if response is None or not response.strip():
return low_level_instruction, pyautogui_code
def adjust_coordinates(x: float, y: float) -> Tuple[int, int]:
if not (original_width and original_height):
return int(x), int(y)
if self.coordinate_type == "absolute":
# scale from processed pixels to original
if processed_width and processed_height:
x_scale = original_width / processed_width
y_scale = original_height / processed_height
return int(x * x_scale), int(y * y_scale)
return int(x), int(y)
# relative: scale from 0..999 grid
x_scale = original_width / 999
y_scale = original_height / 999
return int(x * x_scale), int(y * y_scale)
def process_tool_call(json_str: str) -> None:
try:
tool_call = json.loads(json_str)
if tool_call.get("name") == "computer_use":
args = tool_call["arguments"]
action = args["action"]
def _clean_keys(raw_keys):
keys = raw_keys if isinstance(raw_keys, list) else [raw_keys]
cleaned_keys = []
for key in keys:
if isinstance(key, str):
if key.startswith("keys=["):
key = key[6:]
if key.endswith("]"):
key = key[:-1]
if key.startswith("['") or key.startswith('["'):
key = key[2:] if len(key) > 2 else key
if key.endswith("']") or key.endswith('"]'):
key = key[:-2] if len(key) > 2 else key
key = key.strip()
cleaned_keys.append(key)
else:
cleaned_keys.append(key)
return cleaned_keys
if action == "left_click" or action == "click":
if "coordinate" in args:
x, y = args["coordinate"]
adj_x, adj_y = adjust_coordinates(x, y)
pyautogui_code.append(f"pyautogui.click({adj_x}, {adj_y})")
else:
pyautogui_code.append("pyautogui.click()")
elif action == "right_click":
if "coordinate" in args:
x, y = args["coordinate"]
adj_x, adj_y = adjust_coordinates(x, y)
pyautogui_code.append(
f"pyautogui.rightClick({adj_x}, {adj_y})"
)
else:
pyautogui_code.append("pyautogui.rightClick()")
elif action == "middle_click":
if "coordinate" in args:
x, y = args["coordinate"]
adj_x, adj_y = adjust_coordinates(x, y)
pyautogui_code.append(
f"pyautogui.middleClick({adj_x}, {adj_y})"
)
else:
pyautogui_code.append("pyautogui.middleClick()")
elif action == "double_click":
if "coordinate" in args:
x, y = args["coordinate"]
adj_x, adj_y = adjust_coordinates(x, y)
pyautogui_code.append(
f"pyautogui.doubleClick({adj_x}, {adj_y})"
)
else:
pyautogui_code.append("pyautogui.doubleClick()")
elif action == "triple_click":
if "coordinate" in args:
x, y = args["coordinate"]
adj_x, adj_y = adjust_coordinates(x, y)
pyautogui_code.append(
f"pyautogui.tripleClick({adj_x}, {adj_y})"
)
else:
pyautogui_code.append("pyautogui.tripleClick()")
elif action == "type":
text = args.get("text", "")
try:
text = text.encode('latin-1', 'backslashreplace').decode('unicode_escape')
except Exception as e:
logger.error(f"Failed to unescape text: {e}")
logger.info(f"Pyautogui code[before rewrite]: {text}")
result = ""
for char in text:
if char == '\n':
result += "pyautogui.press('enter')\n"
elif char == "'":
result += 'pyautogui.press("\'")\n'
elif char == '\\':
result += "pyautogui.press('\\\\')\n"
elif char == '"':
result += "pyautogui.press('\"')\n"
else:
result += f"pyautogui.press('{char}')\n"
pyautogui_code.append(result)
logger.info(f"Pyautogui code[after rewrite]: {pyautogui_code}")
elif action == "key":
keys = _clean_keys(args.get("keys", []))
keys_str = ", ".join([f"'{key}'" for key in keys])
if len(keys) > 1:
pyautogui_code.append(f"pyautogui.hotkey({keys_str})")
else:
pyautogui_code.append(f"pyautogui.press({keys_str})")
elif action == "key_down":
keys = _clean_keys(args.get("keys", []))
for k in keys:
pyautogui_code.append(f"pyautogui.keyDown('{k}')")
elif action == "key_up":
keys = _clean_keys(args.get("keys", []))
for k in reversed(keys):
pyautogui_code.append(f"pyautogui.keyUp('{k}')")
elif action == "scroll":
pixels = args.get("pixels", 0)
pyautogui_code.append(f"pyautogui.scroll({pixels})")
elif action == "wait":
pyautogui_code.append("WAIT")
elif action == "terminate":
# Termination should respect status:
# - success -> DONE
# - failure -> FAIL
# Backward compatible: missing status defaults to success.
status = args.get("status", "success")
if str(status).lower() == "failure":
pyautogui_code.append("FAIL")
else:
pyautogui_code.append("DONE")
elif action == "mouse_move":
if "coordinate" in args:
x, y = args["coordinate"]
adj_x, adj_y = adjust_coordinates(x, y)
pyautogui_code.append(
f"pyautogui.moveTo({adj_x}, {adj_y})"
)
else:
pyautogui_code.append("pyautogui.moveTo(0, 0)")
elif action == "left_click_drag":
if "coordinate" in args:
x, y = args["coordinate"]
adj_x, adj_y = adjust_coordinates(x, y)
duration = args.get("duration", 0.5)
pyautogui_code.append(
f"pyautogui.dragTo({adj_x}, {adj_y}, duration={duration})"
)
else:
pyautogui_code.append("pyautogui.dragTo(0, 0)")
except (json.JSONDecodeError, KeyError) as e:
logger.error(f"Failed to parse tool call: {e}")
lines = response.split("\n")
inside_tool_call = False
current_tool_call: List[str] = []
for line in lines:
line = line.strip()
if not line:
continue
if line.lower().startswith("action:"):
if not low_level_instruction:
low_level_instruction = line.split(":", 1)[-1].strip()
continue
if line.startswith("<tool_call>"):
inside_tool_call = True
continue
elif line.startswith("</tool_call>"):
if current_tool_call:
process_tool_call("\n".join(current_tool_call))
current_tool_call = []
inside_tool_call = False
continue
if inside_tool_call:
current_tool_call.append(line)
continue
if line.startswith("{") and line.endswith("}"):
try:
json_obj = json.loads(line)
if "name" in json_obj and "arguments" in json_obj:
process_tool_call(line)
except json.JSONDecodeError:
pass
if current_tool_call:
process_tool_call("\n".join(current_tool_call))
if not low_level_instruction and len(pyautogui_code) > 0:
first_action = pyautogui_code[0]
if "." in first_action:
action_type = first_action.split(".", 1)[1].split("(", 1)[0]
else:
action_type = first_action.lower()
low_level_instruction = f"Performing {action_type} action"
return low_level_instruction, pyautogui_code
def _predict_s1(self, instruction, obs, processed_b64):
messages = [{"role": "system", "content": S1_SYSTEM_PROMPT.format(password=self.password)}]
# Reconstruct History Logic for S1 mode
history_step_texts = []
for i in range(len(self.actions)):
cot = self.cots[i] if i < len(self.cots) else {}
# Step Content string
step_content = S1_STEP_TEMPLATE.format(step_num=i+1) + S1_ACTION_HISTORY_TEMPLATE.format(action=cot.get('action', ''))
if i > len(self.actions) - self.max_history_turns:
# Recent history: Add User(Image) and Assistant(Text)
if i < len(self.screenshots) - 1: # Screenshot exists for this step
img = self.screenshots[i]
messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}
]
})
messages.append({"role": "assistant", "content": step_content})
else:
# Old history: Collect text
history_step_texts.append(step_content)
# If this is the last step before the recent window, flush collected texts
if i == len(self.actions) - self.max_history_turns:
messages.append({
"role": "assistant",
"content": "\n".join(history_step_texts)
})
# Current
messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{processed_b64}"}},
{"type": "text", "text": S1_INSTRUTION_TEMPLATE.format(instruction=instruction)}
]
})
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens
})
low_level, codes, cot_data = self._parse_response_s1(response)
self.observations.append(obs)
self.cots.append(cot_data)
self.actions.append(low_level)
self.responses.append(response)
return response, codes
def _parse_response_s1(self, response):
sections = {}
# Simple Regex Parsing
for key, pattern in [
('observation', r'#{1,2}\s*Observation\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)'),
('thought', r'#{1,2}\s*Thought\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)'),
('action', r'#{1,2}\s*Action\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)')
]:
m = re.search(pattern, response, re.DOTALL | re.MULTILINE)
if m: sections[key] = m.group(1).strip()
code_blocks = re.findall(r'```(?:code|python)?\s*(.*?)\s*```', response, re.DOTALL | re.IGNORECASE)
code = code_blocks[-1].strip() if code_blocks else "FAIL"
sections['code'] = code
# Post-process code
if "computer.terminate" in code:
final_code = ["DONE"] if "success" in code.lower() else ["FAIL"]
elif "computer.wait" in code:
final_code = ["WAIT"]
else:
# Project coordinates
code = project_coordinate_to_absolute_scale(
code,
self.screen_size[0],
self.screen_size[1],
self.coordinate_type,
self.resize_factor
)
logger.info(f"[rewrite before]: {code}")
final_code = [rewrite_pyautogui_text_inputs(code)]
logger.info(f"[rewrite after]: {final_code}")
return sections.get('action', 'Acting'), final_code, sections
@staticmethod
def _should_giveup_on_context_error(e):
"""对于 context length 相关的错误,立即放弃重试,交给外层处理"""
error_str = str(e)
return "Too Large" in error_str or "context_length_exceeded" in error_str or "413" in error_str
@backoff.on_exception(backoff.constant, Exception, interval=30, max_tries=10, giveup=_should_giveup_on_context_error.__func__)
def call_llm(self, payload):
"""Unified OpenAI-compatible API call"""
# Get env vars
base_url = os.environ.get("OPENAI_BASE_URL", "url-xxx")
api_key = os.environ.get("OPENAI_API_KEY", "sk-xxx")
client = openai.OpenAI(base_url=base_url, api_key=api_key)
messages = payload["messages"]
log_messages(messages, "LLM Request")
params = {
"model": payload["model"],
"messages": messages,
"max_tokens": payload["max_tokens"],
"temperature": self.temperature,
"top_p": self.top_p
}
try:
resp = client.chat.completions.create(**params)
content = resp.choices[0].message.content
logger.info(f"LLM Response:\n{content}")
return content
except Exception as e:
logger.error(f"LLM Call failed: {e}")
raise e
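# --- Usage sketch (illustrative only) ---
# Assumes OPENAI_BASE_URL / OPENAI_API_KEY point at an OpenAI-compatible endpoint that
# serves the EvoCUA model, and `screenshot_bytes` holds a PNG screenshot of the desktop:
#
#   agent = EvoCUAAgent(model="EvoCUA-S2", prompt_style="S2", coordinate_type="relative")
#   agent.reset()
#   response, actions = agent.predict("Open the file manager", {"screenshot": screenshot_bytes})
#   # `actions` is a list of pyautogui code strings, or ["WAIT"], ["DONE"], ["FAIL"]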

mm_agents/evocua/prompts.py Normal file

@ -0,0 +1,148 @@
S1_SYSTEM_PROMPT = """You are a GUI agent. You are given a task, a screenshot of the screen and your previous interactions with the computer. You need to perform a series of actions to complete the task. The password of the computer is "{password}", use it when you need sudo rights. You need to **wait** explicitly for installation, waiting website loading or running commands to finish. Don't terminate the task unless you are sure the task is finished. If you find that you can't finish the task, or the task is not finished exactly as the instruction indicates (you have made progress but not finished the task completely), or the task is impossible to complete, you must report **failure**.
For each step, provide your response in this format:
# Step: {{step number}}
## Thought:
{{thought}}
## Action:
{{action}}
## Code:
{{code}}
For the Thought section, you should include the following parts:
- Reflection on the task when there is previous action:
- Consider the correctness of the previous action and its outcomes
- If the previous action was correct, describe the change in the state of the computer and reason
- If the previous action was incorrect, reflect on what went wrong and why
- Step by Step Progress Assessment:
- Add necessary information according to the history screenshots, former actions and current screenshot.
- Analyze what parts of the task have already been completed and how they contribute to the overall goal.
- Make a plan on how to complete the task based on the history and current screenshot.
- Next Action Prediction:
- Propose the most possible next action and state the reason
- For Text Input Actions:
- Note current cursor position
- Consolidate repetitive actions (specify count for multiple keypresses)
- Describe expected final text outcome
- Use first-person perspective in reasoning
For the action section, you should provide clear, concise, and actionable instructions in one sentence.
- If the action involves interacting with a specific target:
- Describe target explicitly (if multiple elements share that name, you should distinguish the target) without using coordinates
- Specify element names when possible (use original language if non-English)
- Describe features (shape, color, position) if name unavailable
- If the action involves keyboard actions like 'press', 'write', 'hotkey':
- Consolidate repetitive keypresses with count
- Specify expected text outcome for typing actions
For the code section, you should output the corresponding code for the action. The code should be either PyAutoGUI code or one of the following functions wrapped in a code block:
- {{"name": "computer.wait", "description": "Make the computer wait for 20 seconds for installation, running code, etc.", "parameters": {{"type": "object", "properties": {{}}, "required": []}}}}
- {{"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {{"type": "object", "properties": {{"status": {{"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}}, {{"answer": {{"type": "string", "description": "The answer of the task"}}}}, "required": ["status"]}}}}
Examples for the code section:
```python
pyautogui.click(x=123, y=456)
```
```code
computer.terminate(status="success")
```
```code
computer.terminate(status="success", answer='''text''')
```"""
# S1 prompt templates for generating trajectories
S1_STEP_TEMPLATE = "# Step {step_num}:\n"
S1_INSTRUTION_TEMPLATE = "# Task Instruction:\n{instruction}\n\nPlease generate the next move according to the screenshot, task instruction and previous steps (if provided).\n"
S1_ACTION_HISTORY_TEMPLATE = "## Action:\n{action}\n"
# S2 Prompts
S2_ACTION_DESCRIPTION = """
* `key`: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.
* `key_down`: Press and HOLD the specified key(s) down in order (no release). Use this for stateful holds like holding Shift while clicking.
* `key_up`: Release the specified key(s) in reverse order.
* `type`: Type a string of text on the keyboard.
* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen.
* `left_click`: Click the left mouse button at a specified (x, y) pixel coordinate on the screen.
* `left_click_drag`: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.
* `right_click`: Click the right mouse button at a specified (x, y) pixel coordinate on the screen.
* `middle_click`: Click the middle mouse button at a specified (x, y) pixel coordinate on the screen.
* `double_click`: Double-click the left mouse button at a specified (x, y) pixel coordinate on the screen.
* `triple_click`: Triple-click the left mouse button at a specified (x, y) pixel coordinate on the screen.
* `scroll`: Performs a scroll of the mouse scroll wheel.
* `hscroll`: Performs a horizontal scroll (mapped to regular scroll).
* `wait`: Wait specified seconds for the change to happen.
* `terminate`: Terminate the current task and report its completion status.
* `answer`: Answer a question.
"""
S2_DESCRIPTION_PROMPT_TEMPLATE = """Use a mouse and keyboard to interact with a computer, and take screenshots.
* This is an interface to a desktop GUI. You must click on desktop icons to start applications.
* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try waiting and taking another screenshot.
{resolution_info}
* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.
* If you tried clicking on a program or link but it failed to load even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.
* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked."""
S2_SYSTEM_PROMPT = """# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tools_xml}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>
# Response format
Response format for every step:
1) Action: a short imperative describing what to do in the UI.
2) A single <tool_call>...</tool_call> block containing only the JSON: {{"name": <function-name>, "arguments": <args-json-object>}}.
Rules:
- Output exactly in the order: Action, <tool_call>.
- Be brief: one sentence for Action.
- Do not output anything else outside those parts.
- If finishing, use action=terminate in the tool call."""
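# Example of a model turn in the S2 format above (illustrative); the JSON inside the
# <tool_call> tags must parse cleanly, since _parse_response_s2 feeds it to json.loads:
#   Action: Click the Firefox icon on the desktop.
#   <tool_call>
#   {"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [512, 384]}}
#   </tool_call>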
def build_s2_tools_def(description_prompt):
return {
"type": "function",
"function": {
"name_for_human": "computer_use",
"name": "computer_use",
"description": description_prompt,
"parameters": {
"properties": {
"action": {
"description": S2_ACTION_DESCRIPTION,
"enum": ["key", "type", "mouse_move", "left_click", "left_click_drag",
"right_click", "middle_click", "double_click", "triple_click", "scroll",
"wait", "terminate", "key_down", "key_up"],
"type": "string"
},
"keys": {"description": "Required only by `action=key`.", "type": "array"},
"text": {"description": "Required only by `action=type`.", "type": "string"},
"coordinate": {"description": "The x,y coordinates for mouse actions.", "type": "array"},
"pixels": {"description": "The amount of scrolling.", "type": "number"},
"time": {"description": "The seconds to wait.", "type": "number"},
"status": {
"description": "The status of the task.",
"type": "string",
"enum": ["success", "failure"]
}
},
"required": ["action"],
"type": "object"
},
"args_format": "Format the arguments as a JSON object."
}
}
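# Illustrative: this definition is serialized straight into the system prompt, e.g.
#   tools_def = build_s2_tools_def(S2_DESCRIPTION_PROMPT_TEMPLATE.format(
#       resolution_info="* The screen's resolution is 1000x1000."))
#   system_prompt = S2_SYSTEM_PROMPT.format(tools_xml=json.dumps(tools_def))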

mm_agents/evocua/utils.py Normal file

@ -0,0 +1,302 @@
import base64
import re
import ast
import logging
from io import BytesIO
import json
from PIL import Image
from mm_agents.utils.qwen_vl_utils import smart_resize
logger = logging.getLogger("desktopenv.evocua.utils")
def encode_image(image_content):
"""Encode image bytes to base64 string."""
return base64.b64encode(image_content).decode("utf-8")
def process_image(image_bytes, factor=32):
"""
Process an image for VL models.
factor: resize factor; 32 is used for S2 mode, 28 for S1 mode by default
"""
image = Image.open(BytesIO(image_bytes))
width, height = image.size
resized_height, resized_width = smart_resize(
height=height,
width=width,
factor=factor,
max_pixels=16 * 16 * 4 * 12800, # Large buffer
)
image = image.resize((resized_width, resized_height))
buffer = BytesIO()
image.save(buffer, format="PNG")
processed_bytes = buffer.getvalue()
return base64.b64encode(processed_bytes).decode("utf-8"), resized_width, resized_height
def _fallback_rewrite_pyautogui_text_inputs(code: str) -> str:
"""
Regex-based fallback to handle malformed pyautogui.write/typewrite calls.
"""
logger.info(f"SyntaxError detected in code, using regex fallback. Original code: {code}")
def _replacer(match):
call_content = match.group(0)
m = re.search(r'pyautogui\.(?:write|typewrite)\s*\(', call_content)
if not m:
return call_content
args_part = call_content[m.end():].strip()
args_part = re.sub(r'^(?:message|text)\s*=\s*', '', args_part)
text_content = ""
if args_part.startswith(("'''", '"""')):
quote_type = args_part[:3]
content = args_part[3:]
end_idx = content.rfind(quote_type)
if end_idx != -1:
text_content = content[:end_idx]
else:
text_content = content[:-1] if content.endswith(')') else content
elif args_part.startswith(("'", '"')):
quote_type = args_part[0]
content = args_part[1:]
if content.endswith(quote_type + ")"):
text_content = content[:-2]
elif content.endswith(")"):
if len(content) > 1 and content[-2] == quote_type:
text_content = content[:-2]
else:
text_content = content[:-1]
elif content.endswith(quote_type):
text_content = content[:-1]
else:
text_content = content
else:
text_content = args_part[:-1] if args_part.endswith(')') else args_part
new_cmds = []
for char in text_content:
p = "enter" if char == "\n" else char
p_esc = p.replace("'", "\\'")
new_cmds.append(f"pyautogui.press('{p_esc}')")
return "; ".join(new_cmds)
pattern = r"pyautogui\.(?:write|typewrite)\s*\(.*?(?=\s*;|\s*$|\n)"
new_code = re.sub(pattern, _replacer, code)
if new_code == code and ("pyautogui.write" in code or "pyautogui.typewrite" in code):
new_code = re.sub(r"pyautogui\.(?:write|typewrite)\s*\(.*", _replacer, code)
return new_code
def rewrite_pyautogui_text_inputs(code: str) -> str:
"""
Expand pyautogui.write/typewrite string literals into per-character presses.
"""
try:
tree = ast.parse(code)
class _TextCallRewriter(ast.NodeTransformer):
def _extract_text(self, call: ast.Call):
if not (
isinstance(call.func, ast.Attribute)
and isinstance(call.func.value, ast.Name)
and call.func.value.id == "pyautogui"
and call.func.attr in ("write", "typewrite")
):
return None
message_node = call.args[0] if call.args else None
if message_node is None:
for kw in call.keywords:
if kw.arg in ("message", "text"):
message_node = kw.value
break
if isinstance(message_node, ast.Constant) and isinstance(message_node.value, str):
return message_node.value
return None
def visit_Expr(self, node):
self.generic_visit(node)
if isinstance(node.value, ast.Call):
text = self._extract_text(node.value)
if text is not None:
new_nodes = []
for char in text:
press_value = "enter" if char == "\n" else char
press_call = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="pyautogui", ctx=ast.Load()),
attr="press",
ctx=ast.Load(),
),
args=[ast.Constant(value=press_value)],
keywords=[],
)
)
new_nodes.append(press_call)
return new_nodes if new_nodes else node
return node
tree = _TextCallRewriter().visit(tree)
tree = ast.fix_missing_locations(tree)
new_code = ast.unparse(tree)
return new_code
except Exception:
return _fallback_rewrite_pyautogui_text_inputs(code)
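# Example (illustrative): a single write call is expanded into per-character key presses,
# and newlines in the text become pyautogui.press('enter'):
#   rewrite_pyautogui_text_inputs("pyautogui.write('hi')")
#   -> "pyautogui.press('h')\npyautogui.press('i')"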
def project_coordinate_to_absolute_scale(pyautogui_code_relative_coordinates, screen_width, screen_height, coordinate_type="relative", resize_factor=28):
"""
Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.
"""
def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
if coordinate_type == "qwen25":
height, width = smart_resize(
height=screen_height,
width=screen_width,
factor=resize_factor,
min_pixels=3136,
max_pixels=12845056
)
if 0 <= x <= 1 and 0 <= y <= 1:
# If already normalized, treat like "relative"
return int(round(x * width)), int(round(y * height))
return int(x / width * screen_width), int(y / height * screen_height)
else:
raise ValueError(f"Invalid coordinate type: {coordinate_type}. Expected 'qwen25'")
pattern = r'(pyautogui\.\w+\([^\)]*\))'
matches = re.findall(pattern, pyautogui_code_relative_coordinates)
new_code = pyautogui_code_relative_coordinates
for full_call in matches:
func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
func_match = re.match(func_name_pattern, full_call, re.DOTALL)
if not func_match:
continue
func_name = func_match.group(1)
args_str = func_match.group(2)
try:
parsed = ast.parse(f"func({args_str})").body[0].value
parsed_args = parsed.args
parsed_keywords = parsed.keywords
except SyntaxError:
continue
function_parameters = {
'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
'rightClick': ['x', 'y', 'duration', 'tween', 'pause'],
'middleClick': ['x', 'y', 'duration', 'tween', 'pause'],
'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
'tripleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
}
func_base_name = func_name.split('.')[-1]
param_names = function_parameters.get(func_base_name, [])
args = {}
for idx, arg in enumerate(parsed_args):
if idx < len(param_names):
param_name = param_names[idx]
try:
arg_value = ast.literal_eval(arg)
args[param_name] = arg_value
except:
pass
try:
for kw in parsed_keywords:
param_name = kw.arg
arg_value = ast.literal_eval(kw.value)
args[param_name] = arg_value
except Exception as e:
logger.error(f"Error parsing keyword arguments: {e}")
continue
updated = False
if 'x' in args and 'y' in args:
try:
x_rel = float(args['x'])
y_rel = float(args['y'])
# Project unconditionally; _coordinate_projection treats already-normalized
# (0..1) coordinates as a special case for the qwen25 type
x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
# Apply coordinate transformation
args['x'] = x_abs
args['y'] = y_abs
updated = True
except (ValueError, TypeError):
pass
if updated:
reconstructed_args = []
for idx, param_name in enumerate(param_names):
if param_name in args:
arg_value = args[param_name]
if isinstance(arg_value, str):
arg_repr = f"'{arg_value}'"
else:
arg_repr = str(arg_value)
reconstructed_args.append(arg_repr)
else:
break
used_params = set(param_names[:len(reconstructed_args)])
for kw in parsed_keywords:
if kw.arg not in used_params:
arg_value = args[kw.arg]
if isinstance(arg_value, str):
arg_repr = f"{kw.arg}='{arg_value}'"
else:
arg_repr = f"{kw.arg}={arg_value}"
reconstructed_args.append(arg_repr)
new_args_str = ', '.join(reconstructed_args)
new_full_call = f"{func_name}({new_args_str})"
new_code = new_code.replace(full_call, new_full_call)
return new_code
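# Worked example (illustrative): with a 1920x1080 screen and the default resize_factor=28,
# smart_resize typically yields a grid of about 1932x1092, so a model coordinate such as
# (966, 546) in that grid projects back to roughly the screen centre via x / width * screen_width:
#   project_coordinate_to_absolute_scale("pyautogui.click(x=966, y=546)", 1920, 1080, "qwen25")
#   -> "pyautogui.click(960, 540)"  (exact values depend on smart_resize rounding)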
def log_messages(messages, prefix="LLM Messages"):
"""Log messages with truncated base64 images"""
try:
log_msgs = []
for msg in messages:
msg_copy = msg.copy()
content = msg.get("content")
if isinstance(content, list):
new_content = []
for item in content:
if isinstance(item, dict) and item.get("type") == "image_url":
item_copy = item.copy()
url = item_copy.get("image_url", {}).get("url", "")
if len(url) > 100:
item_copy["image_url"] = {"url": url[:30] + "...[base64_truncated]..." + url[-10:]}
new_content.append(item_copy)
else:
new_content.append(item)
msg_copy["content"] = new_content
log_msgs.append(msg_copy)
logger.info(f"{prefix}:\n{json.dumps(log_msgs, indent=2, ensure_ascii=False)}")
except Exception as e:
logger.warning(f"Failed to log messages: {e}")


@ -0,0 +1,190 @@
"""
Hosted GBOX Agent Client
Thin HTTP wrapper that calls the hosted GBOX service
"""
import os
import logging
import requests
from typing import Dict, List, Tuple
logger = logging.getLogger("hosted-gbox-agent")
class HostedGboxAgent:
"""
Client wrapper for hosted GBOX service.
Follows the same interface as other OSWorld agents but delegates execution to remote service.
"""
def __init__(
self,
server_url: str,
api_key: str,
vm_ip: str,
platform: str = "ubuntu",
model: str = "claude-sonnet-4-5",
max_steps: int = 15,
**kwargs
):
"""
Initialize hosted agent client
Args:
server_url: URL of hosted GBOX service (e.g., "http://44.201.221.203:8000")
api_key: API key for authentication
vm_ip: IP address of the VM to control
platform: OS platform (ubuntu/windows)
model: Claude model to use
max_steps: Maximum steps per task
"""
self.server_url = server_url.rstrip('/')
self.api_key = api_key
self.vm_ip = vm_ip
self.platform = platform
self.model = model
self.max_steps = max_steps
self.runtime_logger = None
# HTTP client with timeout
self.client = requests.Session()
self.client.headers.update({"X-API-Key": api_key})
logger.info(f"Initialized hosted agent client for VM {vm_ip}")
logger.info(f"Server: {server_url}, Model: {model}")
def reset(self, runtime_logger=None, vm_ip: str = None):
"""
Reset agent state (called by OSWorld before each task)
Args:
runtime_logger: Logger instance for OSWorld runtime logs
vm_ip: Updated VM IP (in case of snapshot revert)
"""
self.runtime_logger = runtime_logger
if vm_ip:
self.vm_ip = vm_ip
if self.runtime_logger:
self.runtime_logger.info(f"[HOSTED] Updated VM IP to {vm_ip}")
if self.runtime_logger:
self.runtime_logger.info(f"[HOSTED] Agent reset for VM {self.vm_ip}")
def predict(self, instruction: str, obs: Dict) -> Tuple[str, List[str]]:
"""
Execute task prediction (one call = full task execution)
Args:
instruction: Task instruction
obs: Observation dict (not used - agent fetches its own screenshots)
Returns:
(reasoning_text, actions_list)
- reasoning_text: Claude's reasoning/explanation
- actions_list: ["DONE"] or ["FAIL"] or PyAutoGUI code
"""
try:
# Prepare request (no screenshot needed - agent fetches its own)
payload = {
"vm_ip": self.vm_ip,
"instruction": instruction,
"platform": self.platform,
"model": self.model,
"max_steps": self.max_steps
}
# Log request
if self.runtime_logger:
self.runtime_logger.info(f"[HOSTED] Sending task to service...")
self.runtime_logger.info(f"[HOSTED] Instruction: {instruction[:100]}...")
# Call hosted service (this may take several minutes)
response = self.client.post(
f"{self.server_url}/execute",
json=payload,
timeout=3600 # 60 minutes timeout for full task execution
)
# Check for errors
if response.status_code == 401:
raise RuntimeError("Authentication failed - invalid API key")
elif response.status_code != 200:
raise RuntimeError(f"Service returned {response.status_code}: {response.text}")
# Parse response
result = response.json()
reasoning = result.get("reasoning", "")
actions = result.get("actions", ["FAIL"])
logs = result.get("logs", "")
session_id = result.get("session_id", "unknown")
# Forward server logs to OSWorld's runtime logger
if logs and self.runtime_logger:
for line in logs.split('\n'):
if line.strip():
self.runtime_logger.info(f"[SERVER] {line}")
# Log results
if self.runtime_logger:
self.runtime_logger.info(f"[HOSTED] Session ID: {session_id}")
self.runtime_logger.info(f"[HOSTED] Actions: {actions}")
self.runtime_logger.info(f"[HOSTED] Reasoning: {reasoning[:200]}...")
return reasoning, actions
except requests.Timeout:
error_msg = "Service timeout (task took longer than 60 minutes)"
logger.error(error_msg)
if self.runtime_logger:
self.runtime_logger.error(f"[HOSTED] {error_msg}")
return f"ERROR: {error_msg}", ["FAIL"]
except requests.ConnectionError as e:
error_msg = f"Cannot connect to service at {self.server_url}: {str(e)}"
logger.error(error_msg)
if self.runtime_logger:
self.runtime_logger.error(f"[HOSTED] {error_msg}")
return f"ERROR: {error_msg}", ["FAIL"]
except Exception as e:
error_msg = f"Hosted agent error: {str(e)}"
logger.error(error_msg, exc_info=True)
if self.runtime_logger:
self.runtime_logger.error(f"[HOSTED] {error_msg}")
return f"ERROR: {error_msg}", ["FAIL"]
def close(self):
"""Close HTTP session"""
self.client.close()
def __del__(self):
"""Cleanup on deletion"""
try:
self.close()
except:
pass
# Factory function for compatibility with OSWorld runner
def create_agent(vm_ip: str, **kwargs) -> HostedGboxAgent:
"""
Factory function to create hosted agent
Expects environment variables:
- GBOX_SERVICE_URL: URL of hosted service
- GBOX_SERVICE_API_KEY: API key for authentication
"""
server_url = os.getenv("GBOX_SERVICE_URL")
api_key = os.getenv("GBOX_SERVICE_API_KEY")
if not server_url:
raise ValueError("GBOX_SERVICE_URL environment variable not set")
if not api_key:
raise ValueError("GBOX_SERVICE_API_KEY environment variable not set")
return HostedGboxAgent(
server_url=server_url,
api_key=api_key,
vm_ip=vm_ip,
**kwargs
)
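# --- Usage sketch (illustrative only) ---
# Assumes the hosted service is reachable and the environment variables are exported:
#   export GBOX_SERVICE_URL="http://<host>:8000"
#   export GBOX_SERVICE_API_KEY="<key>"
#
#   agent = create_agent(vm_ip="10.0.0.5", platform="ubuntu", max_steps=15)
#   reasoning, actions = agent.predict("Open Firefox and load example.com", obs={})
#   # `actions` is ["DONE"], ["FAIL"], or PyAutoGUI code returned by the hosted service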


@ -0,0 +1,3 @@
from mm_agents.opencua.opencua_agent import OpenCUAAgent
__all__ = ["OpenCUAAgent"]


@ -0,0 +1,470 @@
"""
OpenCUA Agent Implementation
This module implements an OpenCUA agent for desktop automation tasks, building upon
existing frameworks and integrating multiple coordinate mapping systems.
Framework and Implementation Sources:
- Main framework structure follows: https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/agent.py
- Agent implementation adapted from: https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/aguvis_agent.py
- Qwen2.5-VL coordinate mapping from: https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
"""
import re
import os
import ast
import time
import math
import httpx
import base64
import backoff
import logging
import traceback
from loguru import logger
from typing import Dict, List, Tuple, Optional
from mm_agents.opencua.utils import (
encode_image,
smart_resize,
)
from mm_agents.opencua.prompts import (
INSTRUTION_TEMPLATE,
STEP_TEMPLATE,
ACTION_HISTORY_TEMPLATE,
THOUGHT_HISTORY_TEMPLATE,
OBSERVATION_HISTORY_TEMPLATE,
# OpenCUA-7B, 32B system prompts
SYSTEM_PROMPT_V1_L1,
SYSTEM_PROMPT_V1_L2,
SYSTEM_PROMPT_V1_L3,
# OpenCUA-72B system prompts
build_sys_prompt,
)
def parse_response_to_cot_and_action(input_string, screen_size, coordinate_type) -> Tuple[str, List[str], dict]:
"""Parse response including Observation, Thought, Action and code block"""
sections = {}
try:
obs_match = re.search(r'^##\s*Observation\s*:?[\n\r]+(.*?)(?=^##\s*Thought:|^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
if obs_match:
sections['observation'] = obs_match.group(1).strip()
thought_match = re.search(r'^##\s*Thought\s*:?[\n\r]+(.*?)(?=^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
if thought_match:
sections['thought'] = thought_match.group(1).strip()
action_match = re.search(r'^##\s*Action\s*:?[\n\r]+(.*?)(?=^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
if action_match:
action = action_match.group(1).strip()
sections['action'] = action.strip()
code_blocks = re.findall(r'```(?:code|python)?\s*(.*?)\s*```', input_string, re.DOTALL | re.IGNORECASE)
if not code_blocks:
logger.error("No code blocks found in the input string")
return f"<Error>: no code blocks found in the input string: {input_string}", ["FAIL"], sections
code_block = code_blocks[-1].strip()
sections['original_code'] = code_block
if "computer.wait" in code_block.lower():
sections["code"] = "WAIT"
return sections['action'], ["WAIT"], sections
elif "computer.terminate" in code_block.lower():
lower_block = code_block.lower()
if ("failure" in lower_block) or ("fail" in lower_block):
sections['code'] = "FAIL"
return code_block, ["FAIL"], sections
elif "success" in lower_block:
sections['code'] = "DONE"
return code_block, ["DONE"], sections
else:
logger.error("Terminate action found but no specific status provided in code block")
return f"<Error>: terminate action found but no specific status provided in code block: {input_string}", ["FAIL"], sections
# corrected_code = correct_pyautogui_arguments(code_block)
corrected_code = code_block
sections['code'] = corrected_code
sections['code'] = project_coordinate_to_absolute_scale(corrected_code, screen_width=screen_size[0], screen_height=screen_size[1], coordinate_type=coordinate_type)
if ('code' not in sections or sections['code'] is None or sections['code'] == "") or ('action' not in sections or sections['action'] is None or sections['action'] == ""):
logger.error("Missing required action or code section")
return f"<Error>: no code parsed: {input_string}", ["FAIL"], sections
return sections['action'], [sections['code']], sections
except Exception as e:
error_message = f"<Error>: parsing response: {str(e)}\nTraceback:\n{traceback.format_exc()}\nInput string: {input_string}"
logger.error(error_message)
return error_message, ['FAIL'], sections
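# Example of a response this parser accepts (illustrative):
#   ## Observation:
#   The desktop shows a Firefox icon in the dock.
#   ## Thought:
#   Opening Firefox is the first step of the task.
#   ## Action:
#   Click the Firefox icon in the dock.
#   ```python
#   pyautogui.click(x=0.031, y=0.52)
#   ```
# With coordinate_type="relative", the (0.031, 0.52) pair is projected to absolute pixels
# by project_coordinate_to_absolute_scale below.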
def project_coordinate_to_absolute_scale(pyautogui_code_relative_coordinates, screen_width, screen_height, coordinate_type="relative"):
"""
Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.
"""
def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
if coordinate_type == "relative":
return int(round(x * screen_width)), int(round(y * screen_height))
elif coordinate_type == "qwen25":
height, width = smart_resize(
height=screen_height,
width=screen_width,
factor=28,
min_pixels=3136,
max_pixels=12845056
)
if 0 <= x <= 1 and 0 <= y <= 1:
# If already normalized, treat like "relative"
return int(round(x * width)), int(round(y * height))
return int(x / width * screen_width), int(y / height * screen_height)
else:
raise ValueError(f"Invalid coordinate type: {coordinate_type}. Expected one of ['relative', 'relative1000', 'absolute', 'qwen25'].")
pattern = r'(pyautogui\.\w+\([^\)]*\))'
matches = re.findall(pattern, pyautogui_code_relative_coordinates)
new_code = pyautogui_code_relative_coordinates
for full_call in matches:
func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
func_match = re.match(func_name_pattern, full_call, re.DOTALL)
if not func_match:
continue
func_name = func_match.group(1)
args_str = func_match.group(2)
try:
parsed = ast.parse(f"func({args_str})").body[0].value
parsed_args = parsed.args
parsed_keywords = parsed.keywords
except SyntaxError:
return pyautogui_code_relative_coordinates
function_parameters = {
'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
'rightClick': ['x', 'y', 'duration', 'tween', 'pause'],
'middleClick': ['x', 'y', 'duration', 'tween', 'pause'],
'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
'tripleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
}
func_base_name = func_name.split('.')[-1]
param_names = function_parameters.get(func_base_name, [])
args = {}
for idx, arg in enumerate(parsed_args):
if idx < len(param_names):
param_name = param_names[idx]
arg_value = ast.literal_eval(arg)
args[param_name] = arg_value
try:
for kw in parsed_keywords:
param_name = kw.arg
arg_value = ast.literal_eval(kw.value)
args[param_name] = arg_value
except Exception as e:
logger.error(f"Error parsing keyword arguments: {e}")
return pyautogui_code_relative_coordinates
updated = False
if 'x' in args and 'y' in args:
try:
x_rel = float(args['x'])
y_rel = float(args['y'])
x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
logger.warning(f"Projecting coordinates: ({x_rel}, {y_rel}) to ({x_abs}, {y_abs}) using {coordinate_type} projection.")
args['x'] = x_abs
args['y'] = y_abs
updated = True
except ValueError:
pass
if updated:
reconstructed_args = []
for idx, param_name in enumerate(param_names):
if param_name in args:
arg_value = args[param_name]
if isinstance(arg_value, str):
arg_repr = f"'{arg_value}'"
else:
arg_repr = str(arg_value)
reconstructed_args.append(arg_repr)
else:
break
used_params = set(param_names[:len(reconstructed_args)])
for kw in parsed_keywords:
if kw.arg not in used_params:
arg_value = args[kw.arg]
if isinstance(arg_value, str):
arg_repr = f"{kw.arg}='{arg_value}'"
else:
arg_repr = f"{kw.arg}={arg_value}"
reconstructed_args.append(arg_repr)
new_args_str = ', '.join(reconstructed_args)
new_full_call = f"{func_name}({new_args_str})"
new_code = new_code.replace(full_call, new_full_call)
return new_code
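# Example (illustrative): with coordinate_type="relative" and a 1920x1080 screen,
#   project_coordinate_to_absolute_scale("pyautogui.click(x=0.5, y=0.5)", 1920, 1080, "relative")
#   -> "pyautogui.click(960, 540)"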
def transform_agnet_action_to_code_block(action):
if any(keyword in action for keyword in ["computer.terminate", "computer.wait", "browser.select_option", "browser.clear"]):
return f"```code\n{action}\n```"
else:
return f"```python\n{action}\n```"
class OpenCUAAgent:
"""
OpenCUA Agent for desktop automation tasks.
This class implements an OpenCUA model-based agent that can observe
desktop environments through screenshots and execute mouse/keyboard actions
via PyAutoGUI to complete automation tasks.
Attributes:
model (str): Name of the language model being used
history_type (str): Type of history recording mechanism
actions (list): History of executed actions
observations (list): History of environment observations
cots (list): Chain of thought reasoning records
"""
def __init__(
self,
model: str, # OpenCUA model name
history_type: str, # History step type: action_history, thought_history, observation_history
max_steps: int, # The max number of steps to finish the task
max_image_history_length: int = 3, # The max number of images in the history
platform: str = "ubuntu", # The platform of the computer
max_tokens: int = 1500, # The max number of tokens in the response
top_p: float = 0.9, # The top p value in the response
temperature: float = 0, # The temperature value in the response
action_space: str = "pyautogui", # The action space: pyautogui
observation_type: str = "screenshot", # The observation type: screenshot
cot_level: str = "l2", # The CoT level: l1, l2, l3
screen_size: Tuple[int, int] = (1920, 1080), # The screen size
coordinate_type: str = "relative", # The coordinate type: relative, absolute, qwen25
use_old_sys_prompt: bool = False, # Whether to use the old system prompt
password="osworld-public-evaluation", # The password for the ubuntu platform
**kwargs
):
assert coordinate_type in ["relative", "absolute", "qwen25"]
assert action_space in ["pyautogui"], "Invalid action space"
assert observation_type in ["screenshot"], "Invalid observation type"
assert history_type in ["action_history", "thought_history", "observation_history"]
assert model is not None, "Model cannot be None"
self.model = model
self.platform = platform
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.observation_type = observation_type
self.history_type = history_type
self.coordinate_type = coordinate_type
self.cot_level = cot_level
self.screen_size = screen_size
self.max_image_history_length = max_image_history_length
self.max_steps = max_steps
self.password = password
if history_type == "action_history":
self.HISTORY_TEMPLATE = ACTION_HISTORY_TEMPLATE
elif history_type == "thought_history":
self.HISTORY_TEMPLATE = THOUGHT_HISTORY_TEMPLATE
elif history_type == "observation_history":
self.HISTORY_TEMPLATE = OBSERVATION_HISTORY_TEMPLATE
else:
raise ValueError(f"Invalid history type: {history_type}")
if use_old_sys_prompt:
if cot_level == "l1":
self.system_prompt = SYSTEM_PROMPT_V1_L1
elif cot_level == "l2":
self.system_prompt = SYSTEM_PROMPT_V1_L2
elif cot_level == "l3":
self.system_prompt = SYSTEM_PROMPT_V1_L3
else:
raise ValueError("Invalid cot_level. Choose from 'l1', 'l2', or 'l3'.")
else:
self.system_prompt = build_sys_prompt(
level=self.cot_level,
password=self.password,
use_random=False
)
self.actions = []
self.observations = []
self.cots = []
def reset(self, _logger=None):
global logger
logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")
self.observations = []
self.cots = []
self.actions = []
def _scale_scroll_for_windows(self, code: str, factor: int = 50) -> str:
""" pyautogui.scroll has a different scale on Ubuntu and Windows, multiple 'factor' when scrolling on Windows system"""
if self.platform.lower() != "windows":
return code
pattern_pos = re.compile(r'(pyautogui\.scroll\()\s*([-+]?\d+)\s*\)')
code = pattern_pos.sub(lambda m: f"{m.group(1)}{int(m.group(2))*factor})", code)
return code
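A standalone sketch (illustration only) of the substitution above with the default factor of 50:

```python
import re

factor = 50
pattern = re.compile(r'(pyautogui\.scroll\()\s*([-+]?\d+)\s*\)')
code = "pyautogui.scroll(-3)"
print(pattern.sub(lambda m: f"{m.group(1)}{int(m.group(2)) * factor})", code))
# -> pyautogui.scroll(-150)
```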
def predict(self, instruction: str, obs: Dict, **kwargs) -> Tuple[str, List[str], Dict]:
"""
Predict the next action(s) based on the current observation.
"""
if "step_idx" in kwargs:
logger.info(f"========= {self.model} Step {kwargs['step_idx']} =======")
else:
logger.info(f"========================== {self.model} ===================================")
logger.info(f"Instruction: \n{instruction}")
messages = []
messages.append({
"role": "system",
"content": self.system_prompt
})
instruction_prompt = INSTRUTION_TEMPLATE.format(instruction=instruction)
history_step_texts = []
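# History handling: only the most recent steps keep their screenshots (at most
# max_image_history_length images, counting the current observation below);
# older steps are collapsed into a single text-only assistant message.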
for i in range(len(self.actions)):
if i > len(self.actions) - self.max_image_history_length:
messages.append({
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}"}
}
]
})
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
observation=self.cots[i].get('observation'),
thought=self.cots[i].get('thought'),
action=self.cots[i].get('action')
)
messages.append({
"role": "assistant",
"content": history_content
})
else:
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
observation=self.cots[i].get('observation'),
thought=self.cots[i].get('thought'),
action=self.cots[i].get('action')
)
history_step_texts.append(history_content)
if i == len(self.actions) - self.max_image_history_length:
messages.append({
"role":"assistant",
"content": "\n".join(history_step_texts)
})
messages.append({
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}"}
},
{
"type": "text",
"text": instruction_prompt
}
]
})
max_retry = 5
retry_count = 0
low_level_instruction = None
pyautogui_actions = None
other_cot = {}
while retry_count < max_retry:
try:
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature if retry_count==0 else max(0.2, self.temperature)
}, self.model)
logger.info(f"Model Output: \n{response}")
if not response:
logger.error("No response found in the response.")
raise ValueError(f"No response found in the response:\n{response}.")
low_level_instruction, pyautogui_actions, other_cot = parse_response_to_cot_and_action(response, self.screen_size, self.coordinate_type)
if "<Error>" in low_level_instruction or not pyautogui_actions:
logger.error(f"Error parsing response: {low_level_instruction}")
raise ValueError(f"Error parsing response: {low_level_instruction}")
break
except Exception as e:
logger.error(f"Error during message preparation: {e}")
retry_count += 1
if retry_count == max_retry:
logger.error("Maximum retries reached. Exiting.")
return str(e), ['FAIL'], other_cot
pyautogui_actions = [
self._scale_scroll_for_windows(code) for code in pyautogui_actions
]
logger.info(f"Action: \n{low_level_instruction}")
logger.info(f"Code: \n{pyautogui_actions}")
self.observations.append(obs)
self.actions.append(low_level_instruction)
self.cots.append(other_cot)
current_step = len(self.actions)
if current_step >= self.max_steps and 'computer.terminate' not in pyautogui_actions[0].lower():
logger.warning(f"Reached maximum steps {self.max_steps}. Forcing termination.")
low_level_instruction = 'Fail the task because the maximum step limit was reached.'
pyautogui_actions = ['FAIL']
other_cot['code'] = 'FAIL'
return response, pyautogui_actions, other_cot
def call_llm(self, payload, model):
"""Call the LLM API"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['OPENCUA_API_KEY']}"
}
for _ in range(20):
response = httpx.post(
f"https://{self.model}.app.msh.team/v1/chat/completions",
headers=headers,
json=payload,
timeout=500,
verify=False
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
logger.error("Retrying...")
time.sleep(5)
else:
response = response.json()
finish_reason = response["choices"][0].get("finish_reason")
if finish_reason is not None and finish_reason == "stop": # for most of the time, length will not exceed max_tokens
return response['choices'][0]['message']['content']
else:
logger.error("LLM did not finish properly, retrying...")
time.sleep(5)


@@ -0,0 +1,349 @@
import random
# System prompt for OpenCUA-7B, OpenCUA-32B
# System prompts used in the training data
SYSTEM_PROMPT_V1_L1 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"\", maximize \"\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}".strip()
SYSTEM_PROMPT_V1_L2 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"\", maximize \"\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}".strip()
SYSTEM_PROMPT_V1_L3 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nObservation:\n - Describe the current computer state based on the full screenshot in detail. \n - Application Context:\n - The active application\n - The active window or page\n - Overall layout and visible interface\n - Key Elements:\n - Menu items and toolbars \n - Buttons and controls\n - Text fields and content\n - Dialog boxes or popups\n - Error messages or notifications\n - Loading states\n - Other key elements\n - Describe any content, elements, options, information or clues that are possibly relevant to achieving the task goal, including their name, content, or shape (if possible).\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"\", maximize \"\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}\n".strip()
# Testing prompt on OSWorld-Verified
SYSTEM_PROMPT_V1_L2 = """You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. The password of the computer is "osworld-public-evaluation". If the task is not possible to do, output the action computer.terminate(status='failure').
For each step, provide your response in this format:
Thought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning
Action:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize "", maximize "", close "X")\n - if the action involves keyboard actions like \'press\', \'write\', \'hotkey\':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions
Finally, output the action as PyAutoGUI code or the following functions:
- {"name": "computer.triple_click", "description": "Triple click on the screen", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The x coordinate of the triple click"}, "y": {"type": "number", "description": "The y coordinate of the triple click"}}, "required": ["x", "y"]}}
- {"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {"type": "object", "properties": {"status": {"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}}, "required": ["status"]}}
""".strip()
# SYSTEM_PROMPT for OpenCUA-72B
general_computer_instructions = [
"""
You are a GUI agent. You are given a task, a screenshot of the screen and your previous interactions with the computer. You need to perform a series of actions to complete the task. The password of the computer is "{password}"; use it when you need sudo rights. You need to **wait** explicitly for installations, website loading or running commands to finish. Don\'t terminate the task unless you are sure the task is finished. If you find that you can\'t finish the task, or the task is not finished exactly as the instruction indicates (you have made progress but not finished the task completely), or the task is impossible to complete, you must report **failure**.
""".strip(),
"""
You are acting as a GUI agent. A task description, a screenshot, and your past interactions will be supplied. Execute the necessary steps to fulfil the task. Whenever sudo operations are required, use the computer's password "{password}". Insert an explicit **wait** after launching any installation, website loading or long-running command to let it finish. Do not output the terminate action unless you are certain the task is complete. If you realise the task cannot be finished or is impossible to do, you should report **failure**.
""".strip(),
"""
Your mission as a GUI agent is to complete the provided task using the current screen image and the history of interactions. For commands requiring elevated privileges, supply "{password}" as the sudo password. Explicitly invoke **wait** after launching any installation or command that may take time to finish. Do not terminate the session unless success is certain. If the task cannot be fully executed, or turns out impossible, you must declare **failure**.
""".strip(),
]
l3_format_instruction = """For each step, provide your response in this format:
# Step: {step number}
## Observation:
{observation}
## Thought:
{thought}
## Action:
{action}
## Code:
{code}"""
l2_format_instruction = """For each step, provide your response in this format:
# Step: {step number}
## Thought:
{thought}
## Action:
{action}
## Code:
{code}"""
l1_format_instruction = """For each step, provide your response in this format:
# Step: {step number}
## Action:
{action}
## Code:
{code}"""
observation_instructions = [
"""For the Observation section, you should include the following parts if helpful:
- Describe the current computer state based on the full screenshot in detail.
- Application Context:
- The active application
- The active window or page
- Overall layout and visible interface
- Key Elements:
- Menu items and toolbars
- Buttons and controls
- Text fields and content
- Dialog boxes or popups
- Error messages or notifications
- Loading states
- Other key elements
- Describe any content, elements, options, information or clues that are possibly relevant to achieving the task goal, including their name, content, or shape (if possible).
""".strip(),
"""In the Observation section, outline everything visible on screen that could influence your next move:
Current system state as seen in the screenshot.
Application context:
- Which application is running in the foreground
- Specific window, tab, or page being displayed
- High-level layout of panels, sidebars, and work areas
Salient interface elements:
- Menus, ribbons, and toolbars
- Actionable buttons, icons, toggles, and controls
- Input areas such as text boxes or code editors
- Pop-up dialogs, modals, alerts, or system notifications
- Progress bars, spinners, or other loading indicators
Any text, labels, shapes, or on-screen cues that might help accomplish the task (cite names or visual traits when available).
""".strip(),
# ── Variant 3 ──────────────────────────────────────────────────────────
"""Write the Observation section as a thorough snapshot of the UI:
- Start with a full-screen description: what the user sees at a glance.
- Give application details: title, active workspace, and structural layout.
- Enumerate critical elements:
* Navigation menus and context bars
* Primary and secondary buttons or icons
* Editable fields, lists, tables, or rich-text areas
* Dialogs, pop-ups, warnings, or confirmations
* Indicators of loading or processing activity
- Note any evidence, hints, or data (textual or visual) that could guide the task toward completion, referencing names, colors, shapes, or positions when explicit identifiers are missing.
""".strip(),
]
thought_instructions = [
"""For the Thought section, you should include the following parts:
- Reflection on the task when there is a previous action:
- Consider the correctness of the previous action and its outcomes
- If the previous action was correct, describe the change in the state of the computer and reason
- If the previous action was incorrect, reflect on what went wrong and why
- Step by Step Progress Assessment:
- Add necessary information according to the history screenshots, former actions and current screenshot.
- Analyze what parts of the task have already been completed and how they contribute to the overall goal.
- Make a plan on how to complete the task based on the history and current screenshot.
- Next Action Prediction:
- Propose the most possible next action and state the reason
- For Text Input Actions:
- Note current cursor position
- Consolidate repetitive actions (specify count for multiple keypresses)
- Describe expected final text outcome
- Use first-person perspective in reasoning
""".strip(),
"""
In the **Thought** block, cover these topics:
1. **Last-Step Reflection** (when a prior action exists)
Was my previous action correct? What evidence shows this?
If it succeeded, what state change occurred and why?
If it failed, where did I go wrong?
2. **Incremental Progress Audit**
Which sub-tasks are completed and how do they advance the mission?
Make a plan to finish the task based on past actions and the current UI state.
3. **Foresight for the Coming Action**
Predict the most logical next step.
State the reason why it is the best choice given the current context.
4. **Guidance for Text Entry**
Note the cursor location
Compress multiple identical keystrokes (e.g., press Backspace ×3)
Clarify the exact text expected after input
Use first-person inner dialogue throughout.
""".strip(),
"""
Compose your **Thought** section as an internal monologue that includes:
- **Retrospective** (if a prior step exists):
* Evaluate the accuracy and effect of the last action.
* If it was successful, reason about the resulting interface change.
* If it was faulty, diagnose the misstep and its cause.
- **Ongoing Progress Evaluation**:
* Outline which parts of the task are done and their impact on the overall objective.
* Suggest a plan to complete the task based on past history and the current screen.
- **Decision Framework for the Next Move**:
* Brainstorm possible next action given the present state.
* Explain why this action is the most logical choice.
- **Special Rules for Keyboard Input**:
* Specify current cursor focus or field.
* Merge repeated keypresses into counts for brevity.
* Describe the intended final text after typing.
Maintain a first-person voice for clarity of reasoning.
""".strip(),
]
action_instructions = [
"""For the action section, you should provide clear, concise, and actionable instructions in one sentence.
- If the action involves interacting with a specific target:
- Describe target explicitly (if multiple elements share that name, you should distinguish the target) without using coordinates
- Specify element names when possible (use original language if non-English)
- Describe features (shape, color, position) if name unavailable
- If the action involves keyboard actions like 'press', 'write', 'hotkey':
- Consolidate repetitive keypresses with count
- Specify expected text outcome for typing actions
""".strip(),
"""
Write the **Action** in one short, direct sentence.
When clicking or otherwise interacting with a UI element:
- Name the element explicitly and, if multiple elements share that name, add a distinguishing detail.
- Do **not** give coordinates.
- Use the element's label (keep original language when it isn't English).
- If unnamed, describe recognisable traits (shape, colour, on-screen position).
When using the keyboard (press, type, hotkey):
- Collapse repeated key presses into counts.
- For typing, specify the text that should appear.
""".strip(),
"""
Provide the **Action** as a single, crisp imperative sentence.
- Mouse/GUI interactions:
* Identify the target by name, and if duplicate names exist, clarify which one you mean.
* Do not supply XY coordinates.
* Preserve non-English labels verbatim.
* If unnamed, describe the element's look or location (colour, shape, relative position).
- Keyboard operations (press, write, hotkey):
* Combine repeated keystrokes with a multiplier.
* State the exact text that will be entered.
""".strip(),
]
code_instrucion = """For the code section, you should output the corresponding code for the action. The code should be either PyAutoGUI code or one of the following functions warped in the code block:
- {"name": "computer.wait", "description": "Make the computer wait for 20 seconds for installation, running code, etc.", "parameters": {"type": "object", "properties": {}, "required": []}}
- {"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {"type": "object", "properties": {"status": {"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}, {"answer": {"type": "string", "description": "The answer of the task"}}, "required": ["status"]}}
Examples for the code section:
```python
pyautogui.click(x=123, y=456)
```
```code
computer.terminate(status="success")
```
```code
computer.terminate(status="success", answer='''text''')
```"""
SYSTEM_PROMPT_V2_L1 = """
{general_computer_instruction}
{format_instruction}
{action_instruction}
{code_instruction}
""".strip()
SYSTEM_PROMPT_V2_L2 = """
{general_computer_instruction}
{format_instruction}
{thought_instruction}
{action_instruction}
{code_instruction}
""".strip()
SYSTEM_PROMPT_V2_L3 = """
{general_computer_instruction}
{format_instruction}
{observation_instruction}
{thought_instruction}
{action_instruction}
{code_instruction}
""".strip()
def build_sys_prompt(level, password="password", use_random=False):
if not use_random:
if level == "l1":
return SYSTEM_PROMPT_V2_L1.format(
general_computer_instruction=general_computer_instructions[0].format(
password=password
),
format_instruction=l1_format_instruction,
action_instruction=action_instructions[0],
code_instruction=code_instrucion,
)
elif level == "l2":
return SYSTEM_PROMPT_V2_L2.format(
general_computer_instruction=general_computer_instructions[0].format(
password=password
),
format_instruction=l2_format_instruction,
thought_instruction=thought_instructions[0],
action_instruction=action_instructions[0],
code_instruction=code_instrucion,
)
elif level == "l3":
return SYSTEM_PROMPT_V2_L3.format(
general_computer_instruction=general_computer_instructions[0].format(
password=password
),
format_instruction=l3_format_instruction,
observation_instruction=observation_instructions[0],
thought_instruction=thought_instructions[0],
action_instruction=action_instructions[0],
code_instruction=code_instrucion,
)
else:
raise ValueError("Invalid level. Choose from 'l1', 'l2', or 'l3'.")
else:
if level == "l1":
return SYSTEM_PROMPT_V2_L1.format(
general_computer_instruction=random.choice(
general_computer_instructions
),
format_instruction=l1_format_instruction,
action_instruction=random.choice(action_instructions),
code_instruction=code_instrucion,
)
elif level == "l2":
return SYSTEM_PROMPT_V2_L2.format(
general_computer_instruction=random.choice(
general_computer_instructions
),
format_instruction=l2_format_instruction,
thought_instruction=random.choice(thought_instructions),
action_instruction=random.choice(action_instructions),
code_instruction=code_instrucion,
)
elif level == "l3":
return SYSTEM_PROMPT_V2_L3.format(
general_computer_instruction=random.choice(
general_computer_instructions
),
format_instruction=l3_format_instruction,
observation_instruction=random.choice(observation_instructions),
thought_instruction=random.choice(thought_instructions),
action_instruction=random.choice(action_instructions),
code_instruction=code_instrucion,
)
else:
raise ValueError("Invalid level. Choose from 'l1', 'l2', or 'l3'.")
# Modeling prompt templates for generating trajectories
STEP_TEMPLATE = "# Step {step_num}:\n"
INSTRUTION_TEMPLATE = "# Task Instruction:\n{instruction}\n\nPlease generate the next move according to the screenshot, task instruction and previous steps (if provided).\n"
ACTION_HISTORY_TEMPLATE = "## Action:\n{action}\n"
THOUGHT_HISTORY_TEMPLATE = "## Thought:\n{thought}\n\n## Action:\n{action}\n"
OBSERVATION_HISTORY_TEMPLATE = "## Observation:\n{observation}\n\n## Thought:\n{thought}\n\n## Action:\n{action}\n"
ACTION_HISTORY_TEMPLATE_WITH_CODE = "## Action:\n{action}\n\n## Code:\n{code}\n"
THOUGHT_HISTORY_TEMPLATE_WITH_CODE = "## Thought:\n{thought}\n\n## Action:\n{action}\n\n## Code:\n{code}\n"
OBSERVATION_HISTORY_TEMPLATE_WITH_CODE = "## Observation:\n{observation}\n\n## Thought:\n{thought}\n\n## Action:\n{action}\n\n## Code:\n{code}\n"

mm_agents/opencua/utils.py

@@ -0,0 +1,483 @@
import re
import base64
import json
import time
import requests
from loguru import logger
from typing import List, Optional
from PIL import Image
from io import BytesIO
import tempfile
import os
import math
def encode_image(image_content):
return base64.b64encode(image_content).decode("utf-8")
def smart_resize(
height: int,
width: int,
factor: int = 28,
min_pixels: int = 56 * 56,
max_pixels: int = 14 * 14 * 4 * 1280,
max_aspect_ratio_allowed: Optional[float] = None,
size_can_be_smaller_than_factor: bool = False,
):
"""Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if not size_can_be_smaller_than_factor and (height < factor or width < factor):
raise ValueError(
f"height:{height} or width:{width} must be larger than factor:{factor} "
f"(when size_can_be_smaller_than_factor is False)"
)
elif (
max_aspect_ratio_allowed is not None
and max(height, width) / min(height, width) > max_aspect_ratio_allowed
):
raise ValueError(
f"absolute aspect ratio must be smaller than {max_aspect_ratio_allowed}, "
f"got {max(height, width) / min(height, width)}"
f"(when max_aspect_ratio_allowed is not None)"
)
h_bar = max(1, round(height / factor)) * factor
w_bar = max(1, round(width / factor)) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = max(1, math.floor(height / beta / factor)) * factor
w_bar = max(1, math.floor(width / beta / factor)) * factor
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
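A worked example (illustration only) with the settings used for the qwen25 coordinate type in this repo (factor=28, min_pixels=3136, max_pixels=12845056):

```python
h, w = smart_resize(height=1080, width=1920, factor=28,
                    min_pixels=3136, max_pixels=12845056)
# round(1080/28)*28 = 1092 and round(1920/28)*28 = 1932; 1092*1932 pixels
# already lies within [min_pixels, max_pixels], so no further rescaling happens.
assert (h, w) == (1092, 1932)
```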
def call_openai_naive(model, payload, address_hint=None):
"""
Naive OpenAI API call using requests.
"""
# Extract fields from payload
model = payload.get("model")
payload["model"] = model.model_id if hasattr(model, "model_id") else "None"
# address_hint not used here
base_url = model.base_url
# logger.warning(f"Base URL: {base_url}, Payload model: {payload['model']}")
url = f"{base_url}/chat/completions"
headers = {
"Content-Type": "application/json",
}
data = {
**payload,
"n": 1,
}
max_retry = 5
chat_completions = None
success = False
while success is False and max_retry > 0:
try:
json_data = json.dumps(data)
response = requests.post(
url, headers=headers, data=json_data, timeout=120, verify=False
)
if response.status_code == 200:
chat_completions = response.json()
try:
finish_reason = chat_completions["choices"][0].get("finish_reason")
if (
finish_reason is not None and finish_reason == "stop"
): # for most of the time, length will not exceed max_tokens
success = True
else:
time.sleep(5)
max_retry -= 1
except Exception as e:
logger.error(f"Error in processing chat completion: {e}")
time.sleep(5)
max_retry -= 1
else:
logger.error(f"Failed to call OpenAI API: {response.text}")
time.sleep(5)
max_retry -= 1
except requests.exceptions.ReadTimeout:
# timeout is normal, don't print trace
max_retry -= 1
logger.warning(f"Timeout in OpenAI API call, left retries: {max_retry}")
time.sleep(5)
except Exception as e:
max_retry -= 1
logger.exception(f"Failed to call OpenAI API: {e}")
time.sleep(5)
if chat_completions is None:
raise RuntimeError("Failed to call OpenAI API, max_retry used up")
try:
infos = {}
if "choices" in chat_completions:
infos["finish_reason"] = chat_completions["choices"][0].get("finish_reason")
infos["n"] = len(chat_completions["choices"])
if "tool_calls" in chat_completions["choices"][0]["message"]:
infos["tool_calls"] = chat_completions["choices"][0]["message"][
"tool_calls"
]
infos["choices"] = chat_completions["choices"] # for the case of n > 1
if "usage" in chat_completions:
infos["usage"] = chat_completions["usage"]
return chat_completions["choices"][0]["message"]["content"], infos
except Exception as e:
logger.error(f"Error in processing chat completion {e}")
return "", {"n": 1, "usage": 0, "finish_reason": f"error {e}"}
def preprocess_for_naive_openai(self, payload):
if isinstance(payload["model"], str):
payload["model"] = getattr(self, "openai_client", None)
return payload
def encoded_img_to_pil_img(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
return Image.open(BytesIO(image_data))
def save_to_tmp_img_file(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_data))
tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
image.save(tmp_img_path)
return tmp_img_path
def bbox_to_center_1000(bbox: str) -> tuple[int, int]:
regex_list = [
r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|>", # '<|box_start|>(576,12),(592,42)<|box_end|>'
r"<\|box_start\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|box_end\|>", # '<|box_start|>[[576, 12, 592, 42]]<|box_end|>'
r"<\|box_start\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]<\|box_end\|>", # '<|box_start|>[[576, 12, 592, 42]<|box_end|>', this is actually wrong format, but we parse it anyway
r"<\|box_start\|>\((\d+),\s*(\d+),\s*(\d+),\s*(\d+)\)<\|box_end\|>", # '<|box_start|>(576, 12, 592, 42)<|box_end|>', this is actually wrong format, but we parse it anyway
r"\((\d+),(\d+)\),\((\d+),(\d+)\)", # Versions without the 'bbox' special tokens
r"\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]",
r"\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]",
r"\((\d+),\s*(\d+),\s*(\d+),\s*(\d+)\)",
]
for regex in regex_list:
match = re.search(regex, bbox)
if match:
break
if not match:
raise ValueError(
f"Bounding box coordinates not found in the input string: {bbox}"
)
x_top_left, y_top_left, x_bottom_right, y_bottom_right = map(int, match.groups())
x_center = (x_top_left + x_bottom_right) // 2
y_center = (y_top_left + y_bottom_right) // 2
return x_center, y_center
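For instance (illustration only), the first pattern above parses a Qwen-style box token and returns its center:

```python
assert bbox_to_center_1000("<|box_start|>(576,12),(592,42)<|box_end|>") == (584, 27)
```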
def bbox_to_center_1(bbox: str) -> tuple[int, int]:
regex_list = [
r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]",
]
for regex in regex_list:
match = re.search(regex, bbox)
if match:
break
if not match:
raise ValueError(
f"Bounding box coordinates not found in the input string: {bbox}"
)
coordinates = tuple(map(float, match.groups()))
coordinates = [int(coord * 1000) for coord in coordinates]
x_center = (coordinates[0] + coordinates[2]) // 2
y_center = (coordinates[1] + coordinates[3]) // 2
return x_center, y_center
def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
if coordinate_type == "relative":
return int(round(x * screen_width)), int(round(y * screen_height))
elif coordinate_type == "absolute":
return x, y
elif coordinate_type == "qwen25":
height, width = smart_resize(
height=screen_height,
width=screen_width,
factor=28,
min_pixels=3136,
max_pixels=12845056,
)
return int(x / width * screen_width), int(y / height * screen_height)
elif coordinate_type == "relative1000":
if screen_width == 0 or screen_height == 0:
raise ValueError(
"Screen width and height must be greater than zero for relative1000 coordinates."
)
x_abs = int(round(x * screen_width / 1000))
y_abs = int(round(y * screen_height / 1000))
return x_abs, y_abs
else:
raise ValueError(f"Unsupported coordinate type: {coordinate_type}")
def rescale_coord(
coord: tuple[int, int],
original_width: int,
original_height: int,
scaled_width=1000,
scaled_height=1000,
) -> tuple[int, int]:
# According to https://huggingface.co/spaces/maxiw/OS-ATLAS/blob/398c3256a4fec409a074e0e4b5ac1d1d5bf7c240/app.py#L36
# It seems that OS-ATLAS model are rescaled to output 1000x1000 images
# So we need to rescale the coordinates back to the original image size
x_scale = original_width / scaled_width
y_scale = original_height / scaled_height
return int(coord[0] * x_scale), int(coord[1] * y_scale)
def _pyautogui_code_to_absolute_coordinates(
pyautogui_code_relative_coordinates,
logical_screen_size,
coordinate_type="relative",
model_input_size=None,
):
"""
Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.
"""
import re
import ast
if coordinate_type not in ["relative", "relative1000", "absolute", "qwen25"]:
raise ValueError(
f"Invalid coordinate type: {coordinate_type}. Expected one of ['relative', 'relative1000', 'absolute', 'qwen25']."
)
screen_width, screen_height = logical_screen_size
if model_input_size is not None:
model_width, model_height = model_input_size
width_scale, height_scale = (
screen_width / model_width,
screen_height / model_height,
)
else:
width_scale, height_scale = 1, 1
pattern = r"(pyautogui\.\w+\([^\)]*\))"
matches = re.findall(pattern, pyautogui_code_relative_coordinates)
new_code = pyautogui_code_relative_coordinates
for full_call in matches:
func_name_pattern = r"(pyautogui\.\w+)\((.*)\)"
func_match = re.match(func_name_pattern, full_call, re.DOTALL)
if not func_match:
continue
func_name = func_match.group(1)
args_str = func_match.group(2)
try:
parsed = ast.parse(f"func({args_str})").body[0].value
parsed_args = parsed.args
parsed_keywords = parsed.keywords
except SyntaxError:
return pyautogui_code_relative_coordinates
function_parameters = {
"click": ["x", "y", "clicks", "interval", "button", "duration", "pause"],
"moveTo": ["x", "y", "duration", "tween", "pause"],
"moveRel": ["xOffset", "yOffset", "duration", "tween", "pause"],
"dragTo": ["x", "y", "duration", "button", "mouseDownUp", "pause"],
"dragRel": [
"xOffset",
"yOffset",
"duration",
"button",
"mouseDownUp",
"pause",
],
"doubleClick": ["x", "y", "interval", "button", "duration", "pause"],
}
func_base_name = func_name.split(".")[-1]
param_names = function_parameters.get(func_base_name, [])
args = {}
for idx, arg in enumerate(parsed_args):
if idx < len(param_names):
param_name = param_names[idx]
arg_value = ast.literal_eval(arg)
args[param_name] = arg_value
try:
for kw in parsed_keywords:
param_name = kw.arg
arg_value = ast.literal_eval(kw.value)
args[param_name] = arg_value
except Exception as e:
logger.error(f"Error parsing keyword arguments: {e}")
return pyautogui_code_relative_coordinates
updated = False
if "x" in args and "y" in args:
try:
x_rel = float(args["x"])
y_rel = float(args["y"])
x_abs, y_abs = _coordinate_projection(
x_rel, y_rel, screen_width, screen_height, coordinate_type
)
# logger.warning(f"Projecting coordinates: ({x_rel}, {y_rel}) to ({x_abs}, {y_abs}) using {coordinate_type} projection.")
args["x"] = x_abs * width_scale
args["y"] = y_abs * height_scale
updated = True
except ValueError:
pass
if "xOffset" in args and "yOffset" in args:
try:
x_rel = float(args["xOffset"])
y_rel = float(args["yOffset"])
x_abs, y_abs = _coordinate_projection(
x_rel, y_rel, screen_width, screen_height, coordinate_type
)
args["xOffset"] = x_abs * width_scale
args["yOffset"] = y_abs * height_scale
updated = True
except ValueError:
pass
if updated:
reconstructed_args = []
for idx, param_name in enumerate(param_names):
if param_name in args:
arg_value = args[param_name]
if isinstance(arg_value, str):
arg_repr = f"'{arg_value}'"
else:
arg_repr = str(arg_value)
reconstructed_args.append(arg_repr)
else:
break
used_params = set(param_names[: len(reconstructed_args)])
for kw in parsed_keywords:
if kw.arg not in used_params:
arg_value = args[kw.arg]
if isinstance(arg_value, str):
arg_repr = f"{kw.arg}='{arg_value}'"
else:
arg_repr = f"{kw.arg}={arg_value}"
reconstructed_args.append(arg_repr)
new_args_str = ", ".join(reconstructed_args)
new_full_call = f"{func_name}({new_args_str})"
new_code = new_code.replace(full_call, new_full_call)
return new_code
def split_args(args_str: str) -> List[str]:
args = []
current_arg = ""
within_string = False
string_char = ""
prev_char = ""
for char in args_str:
if char in ['"', "'"]:
if not within_string:
within_string = True
string_char = char
elif within_string and prev_char != "\\" and char == string_char:
within_string = False
if char == "," and not within_string:
args.append(current_arg)
current_arg = ""
else:
current_arg += char
prev_char = char
if current_arg:
args.append(current_arg)
return args
def correct_pyautogui_arguments(code: str) -> str:
function_corrections = {
"write": {
"incorrect_args": ["text", "content"],
"correct_args": [],
"keyword_arg": "message",
},
"press": {
"incorrect_args": ["key", "button"],
"correct_args": [],
"keyword_arg": None,
},
"hotkey": {
"incorrect_args": ["key1", "key2", "keys"],
"correct_args": [],
"keyword_arg": None,
},
}
lines = code.strip().split("\n")
corrected_lines = []
for line in lines:
line = line.strip()
match = re.match(r"(pyautogui\.(\w+))\((.*)\)", line)
if match:
full_func_call = match.group(1)
func_name = match.group(2)
args_str = match.group(3)
if func_name in function_corrections:
func_info = function_corrections[func_name]
args = split_args(args_str)
corrected_args = []
for arg in args:
arg = arg.strip()
kwarg_match = re.match(r"(\w+)\s*=\s*(.*)", arg)
if kwarg_match:
arg_name = kwarg_match.group(1)
arg_value = kwarg_match.group(2)
if arg_name in func_info["incorrect_args"]:
if func_info["keyword_arg"]:
corrected_args.append(
f"{func_info['keyword_arg']}={arg_value}"
)
else:
corrected_args.append(arg_value)
else:
corrected_args.append(f"{arg_name}={arg_value}")
else:
corrected_args.append(arg)
corrected_args_str = ", ".join(corrected_args)
corrected_line = f"{full_func_call}({corrected_args_str})"
corrected_lines.append(corrected_line)
else:
corrected_lines.append(line)
else:
corrected_lines.append(line)
corrected_code = "\n".join(corrected_lines)
return corrected_code
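As an illustration (not part of the diff) of the corrections table above:

```python
assert correct_pyautogui_arguments("pyautogui.write(text='hello')") == "pyautogui.write(message='hello')"
assert correct_pyautogui_arguments("pyautogui.press(key='enter')") == "pyautogui.press('enter')"
```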
def image_message_from_obs(obs, for_training=False):
if not for_training:
return {
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
"detail": "high",
},
}
else:
return {"type": "image_url", "image_url": {"url": obs["screenshot_path"]}}


@@ -1,736 +0,0 @@
"""
OpenCUA Agent Implementation
This module implements an OpenCUA agent for desktop automation tasks, building upon
existing frameworks and integrating multiple coordinate mapping systems.
Framework and Implementation Sources:
- Main framework structure follows: https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/agent.py
- Agent implementation adapted from: https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/aguvis_agent.py
- Qwen2.5-VL coordinate mapping from: https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
"""
import re
import os
import ast
import time
import math
import httpx
import base64
import backoff
from loguru import logger
from typing import Dict, List, Tuple, Optional
# System prompts used in the training data
AGNET_SYS_PROMPT_L1 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"\", maximize \"\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}".strip()
# AGNET_SYS_PROMPT_L2 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}".strip()
AGNET_SYS_PROMPT_L3 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nObservation:\n - Describe the current computer state based on the full screenshot in detail. \n - Application Context:\n - The active application\n - The active window or page\n - Overall layout and visible interface\n - Key Elements:\n - Menu items and toolbars \n - Buttons and controls\n - Text fields and content\n - Dialog boxes or popups\n - Error messages or notifications\n - Loading states\n - Other key elements\n - Describe any content, elements, options, information or clues that are possibly relevant to achieving the task goal, including their name, content, or shape (if possible).\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"\", maximize \"\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}\n".strip()
# Testing prompt on OSWorld-Verified
AGNET_SYS_PROMPT_L2 = """You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. The password of the computer is "osworld-public-evaluation". If the task is not possible to do, output the action computer.terminate(status='failure').
For each step, provide your response in this format:
Thought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning
Action:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize "", maximize "", close "X")\n - if the action involves keyboard actions like \'press\', \'write\', \'hotkey\':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions
Finally, output the action as PyAutoGUI code or the following functions:
- {"name": "computer.triple_click", "description": "Triple click on the screen", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The x coordinate of the triple click"}, "y": {"type": "number", "description": "The y coordinate of the triple click"}}, "required": ["x", "y"]}}
- {"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {"type": "object", "properties": {"status": {"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}}, "required": ["status"]}}
""".strip()
STEP_TEMPLATE = "# Step {step_num}:\n"
INSTRUTION_TEMPLATE = "# Task Instruction:\n{instruction}\n\nPlease generate the next move according to the screenshot, task instruction and previous steps (if provided).\n"
ACTION_HISTORY_TEMPLATE = "## Action:\n{action}\n"
THOUGHT_HISTORY_TEMPLATE = "## Thought:\n{thought}\n\n## Action:\n{action}\n"
OBSERVATION_HISTORY_TEMPLATE = "## Observation:\n{observation}\n\n## Thought:\n{thought}\n\n## Action:\n{action}\n"
DETAIL_HISTORY_TEMPLATE = "## Thought:\n{thought}\n\n## Action:\n{action}\n\n## Code:\n{code}\n"
def encode_image(image_content):
"""Encode the image to base64"""
return base64.b64encode(image_content).decode('utf-8')
def parse_response_to_cot_and_action(input_string, screen_size, coordinate_type) -> Tuple[str, List[str], dict]:
"""Parse response including Observation, Thought, Action and code block"""
try:
sections = {}
obs_match = re.search(r'^##\s*Observation\s*:?[\n\r]+(.*?)(?=^##\s*Thought:|^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
if obs_match:
sections['observation'] = obs_match.group(1).strip()
thought_match = re.search(r'^##\s*Thought\s*:?[\n\r]+(.*?)(?=^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
if thought_match:
sections['thought'] = thought_match.group(1).strip()
action_match = re.search(r'^##\s*Action\s*:?[\n\r]+(.*?)(?=^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
if action_match:
action = action_match.group(1).strip()
sections['action'] = action.strip()
if "computer.terminate" in input_string.lower():
# Look for code blocks that might contain terminate command
code_blocks = re.findall(r'```(?:code|python)?\s*(.*?)\s*```', input_string, re.DOTALL | re.IGNORECASE)
if code_blocks:
last_code = code_blocks[-1].strip().lower()
if "fail" in last_code:
sections['code'] = "FAIL"
return "FAIL", ["FAIL"], sections
elif "success" in last_code:
sections['code'] = "DONE"
return "DONE", ["DONE"], sections
# Default to DONE if terminate is mentioned but no specific status
sections['code'] = "DONE"
return "DONE", ["DONE"], sections
code_blocks = re.findall(r'```(?:python)\s*(.*?)\s*```', input_string, re.DOTALL)
if code_blocks:
code = code_blocks[-1].strip()
sections['original_code'] = transform_agnet_action_to_code_block(code)
corrected_code = correct_pyautogui_arguments(code)
sections['code'] = corrected_code
sections['code'] = project_coordinate_to_absolute_scale(corrected_code, screen_width=screen_size[0], screen_height=screen_size[1], coordinate_type=coordinate_type)
else:
# No code blocks found
sections['code'] = "WAIT"
return "WAIT", ["WAIT"], sections
if 'code' not in sections:
logger.error("Missing required action or code section")
return None, None, {}
if 'action' not in sections:
sections['action'] = ""
return sections['action'], [sections['code']], sections
except Exception as e:
logger.exception(f"Error parsing response: {str(e)}\nInput string: {input_string}")
return None, None, {}
def correct_pyautogui_arguments(code: str) -> str:
"""Correct the pyautogui arguments"""
function_corrections = {
'write': {
'incorrect_args': ['text', 'content'],
'correct_args': [],
'keyword_arg': 'message'
},
'press': {
'incorrect_args': ['key', 'button'],
'correct_args': [],
'keyword_arg': None
},
'hotkey': {
'incorrect_args': ['key1', 'key2', 'keys'],
'correct_args': [],
'keyword_arg': None
},
}
lines = code.strip().split('\n')
corrected_lines = []
for line in lines:
line = line.strip()
match = re.match(r'(pyautogui\.(\w+))\((.*)\)', line)
if match:
full_func_call = match.group(1)
func_name = match.group(2)
args_str = match.group(3)
if func_name in function_corrections:
func_info = function_corrections[func_name]
args = split_args(args_str)
corrected_args = []
for arg in args:
arg = arg.strip()
kwarg_match = re.match(r'(\w+)\s*=\s*(.*)', arg)
if kwarg_match:
arg_name = kwarg_match.group(1)
arg_value = kwarg_match.group(2)
if arg_name in func_info['incorrect_args']:
if func_info['keyword_arg']:
corrected_args.append(f"{func_info['keyword_arg']}={arg_value}")
else:
corrected_args.append(arg_value)
else:
corrected_args.append(f'{arg_name}={arg_value}')
else:
corrected_args.append(arg)
corrected_args_str = ', '.join(corrected_args)
corrected_line = f'{full_func_call}({corrected_args_str})'
corrected_lines.append(corrected_line)
else:
corrected_lines.append(line)
else:
corrected_lines.append(line)
corrected_code = '\n'.join(corrected_lines)
return corrected_code
def split_args(args_str: str) -> List[str]:
"""Split the arguments string into a list of arguments"""
args = []
current_arg = ''
within_string = False
string_char = ''
prev_char = ''
for char in args_str:
if char in ['"', "'"]:
if not within_string:
within_string = True
string_char = char
elif within_string and prev_char != '\\' and char == string_char:
within_string = False
if char == ',' and not within_string:
args.append(current_arg)
current_arg = ''
else:
current_arg += char
prev_char = char
if current_arg:
args.append(current_arg)
return args
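# Illustrative example: split_args keeps commas that appear inside quoted strings,
# e.g. split_args("'hello, world', interval=0.1") returns ["'hello, world'", " interval=0.1"];
# the caller strips whitespace from each argument afterwards.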
def smart_resize(
height: int,
width: int,
factor: int,
min_pixels: int,
max_pixels: int,
max_aspect_ratio_allowed: Optional[float] = None,
size_can_be_smaller_than_factor: bool = False,
):
"""
The function is modified from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
Qwen2.5-VL-based models need this function to resize screenshots.
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if not size_can_be_smaller_than_factor and (height < factor or width < factor):
raise ValueError(
f"height:{height} or width:{width} must be larger than factor:{factor} "
f"(when size_can_be_smaller_than_factor is False)"
)
elif max_aspect_ratio_allowed is not None and max(height, width) / min(height, width) > max_aspect_ratio_allowed:
raise ValueError(
f"absolute aspect ratio must be smaller than {max_aspect_ratio_allowed}, "
f"got {max(height, width) / min(height, width)}"
f"(when max_aspect_ratio_allowed is not None)"
)
h_bar = max(1, round(height / factor)) * factor
w_bar = max(1, round(width / factor)) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = max(1, math.floor(height / beta / factor)) * factor
w_bar = max(1, math.floor(width / beta / factor)) * factor
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
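# Worked example (illustrative): with the Qwen2.5-VL settings used elsewhere in this file
# (factor=28, min_pixels=3136, max_pixels=12845056), a 1920x1080 screenshot becomes
#   smart_resize(height=1080, width=1920, factor=28, min_pixels=3136, max_pixels=12845056)
#   -> (1092, 1932)
# i.e. both sides are rounded to multiples of 28 while the pixel count stays within bounds.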
def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
"""Project the coordinates to the absolute scale"""
if coordinate_type == "relative":
return int(round(x * screen_width)), int(round(y * screen_height))
elif coordinate_type == "absolute":
return x, y
elif coordinate_type == "qwen25":
if 0 <= x <= 1 and 0 <= y <= 1:
# If already normalized, treat like "relative"
return int(round(x * screen_width)), int(round(y * screen_height))
height, width = smart_resize(
height=screen_height,
width=screen_width,
factor=28,
min_pixels=3136,
max_pixels=12845056 # We use this max_pixels setting in our training data
)
return int(x / width * screen_width), int(y / height * screen_height)
else:
raise ValueError(f"Unsupported coordinate type: {coordinate_type}")
def project_coordinate_to_absolute_scale(pyautogui_code_relative_coordinates, screen_width, screen_height, coordinate_type="relative"):
"""Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size."""
if coordinate_type not in ["relative", "relative1000", "absolute", "qwen25"]:
raise ValueError(f"Invalid coordinate type: {coordinate_type}. Expected one of ['relative', 'relative1000', 'absolute', 'qwen25'].")
pattern = r'(pyautogui\.\w+\([^\)]*\))'
matches = re.findall(pattern, pyautogui_code_relative_coordinates)
new_code = pyautogui_code_relative_coordinates
for full_call in matches:
func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
func_match = re.match(func_name_pattern, full_call, re.DOTALL)
if not func_match:
continue
func_name = func_match.group(1)
args_str = func_match.group(2)
try:
parsed = ast.parse(f"func({args_str})").body[0].value
parsed_args = parsed.args
parsed_keywords = parsed.keywords
except SyntaxError:
return pyautogui_code_relative_coordinates
function_parameters = {
'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
'moveRel': ['xOffset', 'yOffset', 'duration', 'tween', 'pause'],
'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
'dragRel': ['xOffset', 'yOffset', 'duration', 'button', 'mouseDownUp', 'pause'],
'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
}
func_base_name = func_name.split('.')[-1]
param_names = function_parameters.get(func_base_name, [])
args = {}
for idx, arg in enumerate(parsed_args):
if idx < len(param_names):
param_name = param_names[idx]
arg_value = ast.literal_eval(arg)
args[param_name] = arg_value
try:
for kw in parsed_keywords:
param_name = kw.arg
arg_value = ast.literal_eval(kw.value)
args[param_name] = arg_value
except Exception as e:
logger.error(f"Error parsing keyword arguments: {e}")
return pyautogui_code_relative_coordinates
updated = False
if 'x' in args and 'y' in args:
try:
x_rel = float(args['x'])
y_rel = float(args['y'])
x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
logger.warning(f"Projecting coordinates: ({x_rel}, {y_rel}) to ({x_abs}, {y_abs}) using {coordinate_type} projection.")
args['x'] = x_abs
args['y'] = y_abs
updated = True
except ValueError:
pass
if 'xOffset' in args and 'yOffset' in args:
try:
x_rel = float(args['xOffset'])
y_rel = float(args['yOffset'])
x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
args['xOffset'] = x_abs
args['yOffset'] = y_abs
updated = True
except ValueError:
pass
if updated:
reconstructed_args = []
for idx, param_name in enumerate(param_names):
if param_name in args:
arg_value = args[param_name]
if isinstance(arg_value, str):
arg_repr = f"'{arg_value}'"
else:
arg_repr = str(arg_value)
reconstructed_args.append(arg_repr)
else:
break
used_params = set(param_names[:len(reconstructed_args)])
for kw in parsed_keywords:
if kw.arg not in used_params:
arg_value = args[kw.arg]
if isinstance(arg_value, str):
arg_repr = f"{kw.arg}='{arg_value}'"
else:
arg_repr = f"{kw.arg}={arg_value}"
reconstructed_args.append(arg_repr)
new_args_str = ', '.join(reconstructed_args)
new_full_call = f"{func_name}({new_args_str})"
new_code = new_code.replace(full_call, new_full_call)
return new_code
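# Illustrative example: with coordinate_type="relative" and a 1920x1080 screen,
#   project_coordinate_to_absolute_scale("pyautogui.click(0.5, 0.5)", 1920, 1080)
# rewrites the call to "pyautogui.click(960, 540)"; non-coordinate arguments such as
# button='left' are carried over unchanged.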
def extract_positions_and_instructions(code, action) -> list[dict]:
"""
Extracts all `(x, y)` coordinates (both positional and keyword arguments)
and their associated preceding comments as instructions from Python code.
If there are no comments, use the corresponding action instead.
Args:
code (str): The Python code as a string.
action (str): The low-level action as a string.
Returns:
list[dict]: A list of dictionaries with extracted positions and instructions.
- function (str): The pyautogui function name.
- x (int or float): The x-coordinate.
- y (int or float): The y-coordinate.
- instruction (str): The preceding comment as an instruction.
"""
lines = code.splitlines()
extracted = []
preceding_comment = action # To store the preceding comment
for line in lines:
# Check if the line is a comment and store it
if line.strip().startswith("#"):
preceding_comment = line.strip().lstrip("#").strip() # Clean the comment
# Match pyautogui functions with positional arguments
match_positional = re.match(r"(pyautogui\.\w+)\((\d+(\.\d+)?),\s*(\d+(\.\d+)?).*?\)", line)
if match_positional:
extracted.append({
"function": match_positional.group(1), # pyautogui function name
"x": float(match_positional.group(2)) if '.' in match_positional.group(2)\
else int(match_positional.group(2)), # x-coordinate
"y": float(match_positional.group(4)) if '.' in match_positional.group(4)\
else int(match_positional.group(4)), # y-coordinate
"instruction": preceding_comment, # Use the preceding comment
})
preceding_comment = action # Reset to the default action after associating the comment with a call
continue
# Match pyautogui functions with keyword arguments
match_keyword = re.match(r"(pyautogui\.\w+)\(.*?x=(\d+(\.\d+)?),\s*y=(\d+(\.\d+)?).*?\)", line)
if match_keyword:
extracted.append({
"function": match_keyword.group(1), # pyautogui function name
"x": float(match_keyword.group(2)) if '.' in match_keyword.group(2)\
else int(match_keyword.group(2)), # x-coordinate
"y": float(match_keyword.group(4)) if '.' in match_keyword.group(4)\
else int(match_keyword.group(4)), # y-coordinate
"instruction": preceding_comment, # Use the preceding comment
})
preceding_comment = action # Reset to the default action after associating the comment with a call
logger.info(f"Grounding extracted:\n{extracted}")
return extracted
def update_code_with_new_coordinates(code, updated_positions):
"""
Replaces old `(x, y)` coordinates (both positional and keyword arguments)
with updated ones in the code, handling multiple occurrences correctly.
Args:
code (str): The original Python code as a string.
updated_positions (list): A list of dictionaries with updated positions.
Returns:
str: The updated Python code.
"""
lines = code.splitlines()
updated_code_lines = []
position_index = 0 # Tracks which position update to use
for line in lines:
if position_index < len(updated_positions):
# Get the next update position
update = updated_positions[position_index]
function_pattern_positional = rf"{update['function']}\(\d+(\.\d+)?, \d+(\.\d+)?"
function_pattern_keyword = rf"{update['function']}\(.*?x=\d+(\.\d+)?, y=\d+(\.\d+)?"
if re.search(function_pattern_positional, line):
# Replace positional arguments
line = re.sub(
function_pattern_positional,
f"{update['function']}({update['x']}, {update['y']}",
line,
count=1
)
position_index += 1 # Move to the next update
elif re.search(function_pattern_keyword, line):
# Replace keyword arguments
line = re.sub(
function_pattern_keyword,
f"{update['function']}(x={update['x']}, y={update['y']}",
line,
count=1
)
position_index += 1 # Move to the next update
updated_code_lines.append(line)
return "\n".join(updated_code_lines)
def transform_agnet_action_to_code_block(action):
"""Transform the agent action to a code block: not used in agent, for logging only"""
if "computer.terminate" in action or "browser.select_option" in action or "browser.clear" in action:
return f"```code\n{action}\n```"
else:
return f"```python\n{action}\n```"
class OpenCUAAgent:
"""
OpenCUA Agent for desktop automation tasks.
This class implements a OpenCUA Model based agent that can observe
desktop environments through screenshots and execute mouse/keyboard actions
via PyAutoGUI to complete automation tasks.
Attributes:
model (str): Name of the language model being used
history_type (str): Type of history recording mechanism
actions (list): History of executed actions
observations (list): History of environment observations
cots (list): Chain of thought reasoning records
"""
def __init__(
self,
model: str, # OpenCUA model name
history_type: str, # History step type: action_history, thought_history, observation_history
max_image_history_length: int = 3, # The max number of images in the history
platform: str = "ubuntu", # The platform of the computer
max_tokens: int = 1500, # The max number of tokens in the response
top_p: float = 0.9, # The top p value in the response
temperature: float = 0, # The temperature value in the response
action_space: str = "pyautogui", # The action space: pyautogui
observation_type: str = "screenshot", # The observation type: screenshot
cot_level: str = "l2", # The CoT level: l1, l2, l3
screen_size: Tuple[int, int] = (1920, 1080), # The screen size
coordinate_type: str = "relative", # The coordinate type: relative, absolute, qwen25
**kwargs
):
assert coordinate_type in ["relative", "absolute", "qwen25"]
assert action_space in ["pyautogui"], "Invalid action space"
assert observation_type in ["screenshot"], "Invalid observation type"
assert history_type in ["action_history", "thought_history", "observation_history"]
assert model is not None, "Model cannot be None"
self.model = model
self.platform = platform
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.observation_type = observation_type
self.history_type = history_type
self.coordinate_type = coordinate_type
self.cot_level = cot_level
self.screen_size = screen_size
self.max_image_history_length = max_image_history_length
if history_type == "action_history":
self.HISTORY_TEMPLATE = ACTION_HISTORY_TEMPLATE
elif history_type == "thought_history":
self.HISTORY_TEMPLATE = THOUGHT_HISTORY_TEMPLATE
elif history_type == "observation_history":
self.HISTORY_TEMPLATE = OBSERVATION_HISTORY_TEMPLATE
else:
raise ValueError(f"Invalid history type: {history_type}")
if cot_level == "l3":
self.SYSTEM_PROMPT = AGNET_SYS_PROMPT_L3
elif cot_level == "l2":
self.SYSTEM_PROMPT = AGNET_SYS_PROMPT_L2
elif cot_level == "l1":
self.SYSTEM_PROMPT = AGNET_SYS_PROMPT_L1
else:
raise ValueError(f"Invalid COT level: {cot_level}")
self.actions = []
self.observations = []
self.cots = []
def reset(self, _logger=None):
global logger
logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")
self.observations = []
self.cots = []
self.actions = []
def _scale_scroll_for_windows(self, code: str, factor: int = 50) -> str:
""" pyautogui.scroll has a different scale on Ubuntu and Windows, multiple 'factor' when scrolling on Windows system"""
if self.platform.lower() != "windows":
return code
pattern_pos = re.compile(r'(pyautogui\.scroll\()\s*([-+]?\d+)\s*\)')
code = pattern_pos.sub(lambda m: f"{m.group(1)}{int(m.group(2))*factor})", code)
return code
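# Illustrative example: on Windows, "pyautogui.scroll(-3)" becomes "pyautogui.scroll(-150)"
# with the default factor of 50; on other platforms the code is returned unchanged.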
def predict(self, instruction: str, obs: Dict, **kwargs) -> Tuple[str, List[str], Dict]:
"""
Predict the next action(s) based on the current observation.
"""
if "step_idx" in kwargs:
logger.info(f"========= {self.model} Step {kwargs['step_idx']} =======")
else:
logger.info(f"========================== {self.model} ===================================")
logger.info(f"Instruction: \n{instruction}")
messages = []
messages.append({
"role": "system",
"content": self.SYSTEM_PROMPT
})
history_step_texts = []
for i in range(len(self.actions)):
if i > len(self.actions) - self.max_image_history_length:
messages.append({
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}"}
}
]
})
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
observation=self.cots[i].get('observation'),
thought=self.cots[i].get('thought'),
action=self.cots[i].get('action')
)
messages.append({
"role": "assistant",
"content": history_content
})
else:
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
observation=self.cots[i].get('observation'),
thought=self.cots[i].get('thought'),
action=self.cots[i].get('action')
)
history_step_texts.append(history_content)
if i == len(self.actions) - self.max_image_history_length:
messages.append({
"role":"assistant",
"content": "\n".join(history_step_texts)
})
messages.append({
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}"}
},
{
"type": "text",
"text": INSTRUTION_TEMPLATE.format(instruction=instruction)
}
]
})
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
}, self.model)
logger.info(f"Model Output: \n{response}")
if not response:
logger.error("No response found in the response.")
return "ERROR", ["DONE"], {}
low_level_instruction, pyautogui_actions, other_cot = parse_response_to_cot_and_action(response, self.screen_size, self.coordinate_type)
if not pyautogui_actions or len(pyautogui_actions) == 0:
logger.error("No pyautogui actions found in the response.")
return response, ["FAIL"], {}
pyautogui_actions = [
self._scale_scroll_for_windows(code) for code in pyautogui_actions
]
self.observations.append(obs)
logger.info(f"Parsed Low-level Action: \n{low_level_instruction}")
logger.info(f"Parsed pyautogui Action: \n{pyautogui_actions}")
self.actions.append(low_level_instruction)
if 'action' not in other_cot or not other_cot['action'] or 'thought' not in other_cot or not other_cot['thought']:
logger.error("Error! no action/thought in cot")
logger.error(f"response: {response}")
logger.error(f"cot: {other_cot}")
self.cots.append(other_cot)
# Print message structure if needed
# messages_to_print = []
# current_image = 1
# for msg in messages:
# msg_copy = copy.deepcopy(msg)
# if isinstance(msg_copy['content'], list):
# for content in msg_copy['content']:
# if content['type'] == 'image_url':
# content['image_url']['url'] = f'Image {current_image}'
# current_image += 1
# messages_to_print.append(msg_copy)
# messages_to_print.append({
# "new_step_cot": other_cot,
# "response": response
# })
# logger.info(json.dumps(messages_to_print, indent=2))
logger.info(f"New step cot: {other_cot}")
return response, pyautogui_actions, {}
@backoff.on_exception(
backoff.constant,
# Add more model-specific exceptions here as needed, but do not add the bare
# "Exception" type: generic exceptions should be caught by the outer loop so that
# each example does not exceed its time limit.
(
Exception
),
interval=30,
max_tries=10
)
def call_llm(self, payload, model):
"""Call the LLM API"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['OPENCUA_API_KEY']}"
}
for _ in range(30):
response = httpx.post(
os.environ['OPENCUA_URL'],
headers=headers,
json=payload,
timeout=500,
verify=False
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
logger.error("Retrying...")
time.sleep(5)
else:
response = response.json()
finish_reason = response["choices"][0].get("finish_reason")
if finish_reason is not None and finish_reason == "stop": # for most of the time, length will not exceed max_tokens
return response['choices'][0]['message']['content']
else:
logger.error("LLM did not finish properly, retrying...")
time.sleep(5)
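# Minimal usage sketch (illustrative; the model name is a placeholder, and OPENCUA_URL /
# OPENCUA_API_KEY are assumed to be set in the environment):
#   agent = OpenCUAAgent(model="opencua-7b", history_type="thought_history")
#   agent.reset(logger)
#   response, actions, _ = agent.predict(instruction, {"screenshot": screenshot_bytes})
#   # 'actions' is a list of pyautogui code strings (or "DONE" / "FAIL" / "WAIT").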

View File

View File

View File

@ -0,0 +1,350 @@
import logging
from typing import Dict, List, Tuple, Optional
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import call_llm_safe, parse_code_from_string
from mm_agents.os_symphony.core.mllm import LMMAgent
logger = logging.getLogger("desktopenv.coder_agent")
def extract_code_block(action: str) -> Tuple[Optional[str], Optional[str]]:
"""Extract code and determine type from action string."""
if "```python" in action:
code_type = "python"
code = action.split("```python")[1].split("```")[0].strip()
elif "```bash" in action:
code_type = "bash"
code = action.split("```bash")[1].split("```")[0].strip()
elif "```" in action:
code_type = None
code = action.split("```")[1].split("```")[0].strip()
else:
code_type = None
code = None
logger.debug(
f"Extracted code block: type={code_type}, length={len(code) if code else 0}"
)
return code_type, code
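# Illustrative example: for an action string containing a ```python fenced block,
# extract_code_block returns ("python", <code>); a fence without a language tag yields
# (None, <code>), and an action with no fence at all yields (None, None).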
def execute_code(code_type: str, code: str, env_controller) -> Dict:
"""Execute code based on its type."""
# Log the full code being executed (untruncated)
logger.info(f"CODING_AGENT_CODE_EXECUTION - Type: {code_type}\nCode:\n{code}")
try:
if code_type == "bash":
result = env_controller.run_bash_script(code, timeout=30)
elif code_type == "python":
result = env_controller.run_python_script(code)
else:
result = {"status": "error", "error": f"Unknown code type: {code_type}"}
return result
except Exception as e:
logger.error(f"Error executing {code_type} code: {e}")
return {"status": "error", "error": str(e)}
def format_result(result: Dict, step_count: int) -> str:
"""Format execution result into context string."""
if not result:
logger.warning(f"Step {step_count + 1}: No result returned from execution")
return f"""
Step {step_count + 1} Error:
Error: No result returned from execution
"""
status = result.get("status", "unknown")
return_code = result.get("returncode", result.get("return_code", -1))
# Handle different response structures for bash vs python
if "returncode" in result:
# Bash script response
output = result.get("output", "") # Contains both stdout and stderr merged
error = result.get("error", "") # Always empty for bash
else:
# Python script response
output = result.get("output", "") # stdout only
error = result.get("error", "") # stderr only
logger.debug(f"Step {step_count + 1}: Status={status}, Return Code={return_code}")
# Format with better structure for multi-line outputs
result_text = f"Step {step_count + 1} Result:\n"
result_text += f"Status: {status}\n"
result_text += f"Return Code: {return_code}\n"
if output:
result_text += f"Output:\n{output}\n"
if error:
result_text += f"Error:\n{error}\n"
return result_text
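# Illustrative output of format_result for a successful bash step (values hypothetical):
#   Step 3 Result:
#   Status: success
#   Return Code: 0
#   Output:
#   <merged stdout/stderr>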
class CoderAgent:
"""A dedicated agent for executing code with a budget of steps."""
def __init__(self, engine_params: Dict, client_password: str, platform: str = "linux"):
"""Initialize the CodeAgent."""
if not engine_params:
raise ValueError("engine_params cannot be None or empty")
self.engine_params = engine_params
self.budget = engine_params.get("budget", 20)
self.temperature = engine_params.get("temperature", 0.1)
self.agent = None
self.platform = platform
self.client_password = client_password
logger.info(f"CodeAgent initialized with budget={self.budget} and platform={self.platform}")
self.reset()
def reset(self):
"""Reset the code agent state."""
logger.debug("Resetting CodeAgent state")
self.agent = LMMAgent(
engine_params=self.engine_params,
system_prompt=PROCEDURAL_MEMORY.construct_coder_procedural_memory(platform=self.platform, client_password=self.client_password)
)
def execute(self, task_instruction: str, screenshot: str, env_controller) -> Dict:
"""Execute code for the given task with a budget of steps."""
if env_controller is None:
raise ValueError("env_controller is required for code execution")
print(f"\n🚀 STARTING CODE EXECUTION")
print("=" * 60)
print(f"Task: {task_instruction}")
print(f"Budget: {self.budget} steps")
print("=" * 60)
logger.info(f"Starting code execution for task: {task_instruction}")
logger.info(f"Budget: {self.budget} steps")
self.reset()
# Add initial task instruction and screenshot context as user message
context = (
f"Task: {task_instruction}\n\nCurrent screenshot is provided for context."
)
self.agent.add_message(context, image_content=screenshot, role="user")
step_count = 0
execution_history = []
execution_result_history = []
while step_count < self.budget:
logger.info(f"Step {step_count + 1}/{self.budget}")
# Get assistant response (thoughts and code)
response = call_llm_safe(self.agent, temperature=self.temperature)
# Print to terminal for immediate visibility
# print(f"\n🤖 CODING AGENT RESPONSE - Step {step_count + 1}/{self.budget}")
# print("=" * 60)
# print(response)
# print("=" * 60)
# Log the latest message from the coding agent (untruncated)
logger.info(
f"CODING_AGENT_LATEST_MESSAGE - Step {step_count + 1}:\n{response}"
)
# Check if response is None or empty
if not response or response.strip() == "":
error_msg = f"Step {step_count + 1}: LLM returned empty response"
logger.error(error_msg)
raise RuntimeError(error_msg)
# Parse the response to extract action
action = parse_code_from_string(response)
thoughts = response
execution_history.append(
{"step": step_count + 1, "action": action, "thoughts": thoughts}
)
# Check for completion signals
action_upper = action.upper().strip()
if action_upper == "DONE":
print(f"\n✅ TASK COMPLETED - Step {step_count + 1}")
print("=" * 60)
print("Agent signaled task completion")
print("=" * 60)
logger.info(f"Step {step_count + 1}: Task completed successfully")
completion_reason = "DONE"
break
elif action_upper == "FAIL":
print(f"\n❌ TASK FAILED - Step {step_count + 1}")
print("=" * 60)
print("Agent signaled task failure")
print("=" * 60)
logger.info(f"Step {step_count + 1}: Task failed by agent request")
completion_reason = "FAIL"
break
elif action_upper == 'INFEASIBLE':
print(f"\n❌ TASK INFEASIBLE - Step {step_count + 1}")
print("=" * 60)
print("Agent signaled task infeasible")
print("=" * 60)
logger.info(f"Step {step_count + 1}: Task infeasible by agent request")
completion_reason = "INFEASIBLE"
break
# Extract and execute code
code_type, code = extract_code_block(response.split("(Answer)")[-1])
if code:
result = execute_code(code_type, code, env_controller)
execution_result_history.append(
{"step": step_count + 1, "result": result}
)
# Prepare formatted output and error for logging
output = result.get("output", "")
error = result.get("error", "")
message = result.get("message", "")
status = result.get("status", "")
# Print execution result to terminal for immediate visibility
print(f"\n⚡ CODE EXECUTION RESULT - Step {step_count + 1}")
print("-" * 50)
print(f"Status: {status}")
if output:
print(f"Output:\n{output}")
if error:
print(f"Error:\n{error}")
if message and not output and not error:
print(f"Message:\n{message}")
print("-" * 50)
log_lines = [
f"CODING_AGENT_EXECUTION_RESULT - Step {step_count + 1}:",
f"Status: {status}" if status else None,
]
if output:
log_lines.append(
"Output:\n" + ("-" * 40) + f"\n{output}\n" + ("-" * 40)
)
if error:
log_lines.append(
"Error:\n" + ("!" * 40) + f"\n{error}\n" + ("!" * 40)
)
if message and not output and not error:
log_lines.append(
"Message:\n" + ("-" * 40) + f"\n{message}\n" + ("-" * 40)
)
# Remove None entries and join
formatted_log = "\n".join([line for line in log_lines if line])
logger.info(formatted_log)
else:
print(f"\n⚠️ NO CODE BLOCK FOUND - Step {step_count + 1}")
print("-" * 50)
print("Action did not contain executable code")
print("-" * 50)
logger.warning(f"Step {step_count + 1}: No code block found in action")
result = {"status": "skipped", "message": "No code block found"}
logger.info(
f"CODING_AGENT_EXECUTION_RESULT - Step {step_count + 1}:\n"
f"Status: skipped\n"
f"Message:\n{'-' * 40}\n{result['message']}\n{'-' * 40}"
)
# Add assistant's thoughts and code to message history
self.agent.add_message(response, role="assistant")
# Process result and add formatted environment results as user message
result_context = format_result(result, step_count)
self.agent.add_message(result_context, role="user")
step_count += 1
# Handle budget exhaustion
if "completion_reason" not in locals():
print(f"\n⏰ BUDGET EXHAUSTED - {step_count} steps completed")
print("=" * 60)
print(f"Maximum budget of {self.budget} steps reached")
print("=" * 60)
logger.info(f"Budget exhausted after {step_count} steps")
completion_reason = f"BUDGET_EXHAUSTED_AFTER_{step_count}_STEPS"
# Generate final summary
logger.info("Generating execution summary")
summary = self._generate_summary(execution_history, task_instruction)
result = {
"task_instruction": task_instruction,
"completion_reason": completion_reason,
"summary": summary,
"execution_history": execution_history,
"execution_result_history": execution_result_history,
"steps_executed": step_count,
"budget": self.budget
}
logger.info(f"Code execution completed: steps={step_count}")
return result
def _generate_summary(
self, execution_history: List[Dict], task_instruction: str
) -> str:
"""Generate summary of code execution session."""
if not execution_history:
logger.info("No execution history to summarize")
return "No actions were executed."
logger.info(f"Generated summary for {len(execution_history)} steps")
# Build detailed execution context for summary agent
execution_context = f"Task: {task_instruction}\n\nExecution Steps:\n"
for step in execution_history:
step_num = step["step"]
thoughts = step.get("thoughts", "")
action = step.get("action", "")
execution_context += f"\nStep {step_num}:\n"
if thoughts:
execution_context += f"Thoughts: {thoughts}\n"
execution_context += f"Code: {action}\n"
# Create summary prompt with same context as coding agent
summary_prompt = f"""
{execution_context}
Please provide a concise summary of the code execution session. Focus on:
1. The code logic implemented at each step
2. The outputs and results produced by each code execution
3. The progression of the solution approach
Do not make judgments about success or failure. Simply describe what was attempted and what resulted.
Keep the summary under 150 words and use clear, factual language.
"""
# Generate summary using LLM with dedicated summary system prompt
try:
summary_agent = LMMAgent(
engine_params=self.engine_params,
system_prompt=PROCEDURAL_MEMORY.CODE_SUMMARY_AGENT_PROMPT,
)
summary_agent.add_message(summary_prompt, role="user")
summary = call_llm_safe(summary_agent, temperature=self.temperature)
if not summary or summary.strip() == "":
summary = "Summary generation failed - no response from LLM"
logger.warning("Summary generation failed - empty response from LLM")
except Exception as e:
summary = f"Summary generation failed: {str(e)}"
logger.error(f"Error generating summary: {e}")
return summary
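# Minimal usage sketch (illustrative; engine_params keys follow the constructor above, and
# env_controller is assumed to expose run_bash_script / run_python_script):
#   coder = CoderAgent(engine_params={"model": "...", "budget": 10}, client_password="password")
#   result = coder.execute("List all files in ~/Desktop", screenshot_bytes, env_controller)
#   print(result["completion_reason"], result["summary"])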

View File

@ -0,0 +1,109 @@
import re
from typing import Any, Dict, List
import pytesseract
from PIL import Image
import io
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.utils.common_utils import call_llm_safe, smart_resize
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
import logging
logger = logging.getLogger("desktopenv.agent")
class GrounderAgent:
"""
Class for interacting with the GUI, serving as the Grounding Agent and the VLMSearcher
"""
def __init__(self, engine_params: Dict, screen_width: int, screen_height: int):
self.engine_params_for_grounder = engine_params # grounder_params
system_prompt, self.user_message = PROCEDURAL_MEMORY.construct_grounder_procedural_memory(model_name=engine_params["model"])
self.grounding_model = LMMAgent(engine_params, system_prompt=system_prompt)
# Width and height for Grounding Agent!
self.width = engine_params['grounding_width']
self.height = engine_params['grounding_height']
print(f"[Grounder]: initialized width is {self.width}, height is {self.height}")
# Width and height for actual screen!
self.screen_width = screen_width
self.screen_height = screen_height
# Given the state and worker's referring expression, use the grounding model to generate (x,y)
def generate_coords(self, ref_expr: str, obs: Dict, detail=False, expansion_pixels=400, **kwargs) -> List:
cur_screenshot = obs["screenshot"]
# store global offset
global_offset_x = 0
global_offset_y = 0
# final coordinates for output
final_global_x = 0
final_global_y = 0
cur_width, cur_height = self.screen_width, self.screen_height
print(f"[Grounder] start to ground!")
self.grounding_model.reset()
# Configure the context
prompt = self.user_message.replace("REF_EXPR", ref_expr)
# consistent with the system prompt presented in the GTA-1 paper
if 'gta' in self.engine_params_for_grounder['model']:
self.grounding_model.add_system_prompt("You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.")
self.grounding_model.add_message(
text_content=prompt, image_content=cur_screenshot, put_text_last=True, role="user"
)
# Generate and parse coordinates
response = call_llm_safe(self.grounding_model, temperature=0.05, **kwargs)
print(f"[Grounder] prompt: {prompt}\nmodel: {self.engine_params_for_grounder['model']}, \nresponse: {response}")
# 1. highest priority: (x1="...", y1="...", x="...", y="...")
numericals = re.findall(r'(?:x1|y1|x|y)=["\']?(\d+)["\']?', response)
# 2. second-highest priority: formats like <points>653 42</points> or [653, 42]
if len(numericals) < 2:
clean_response = re.sub(r'[xXyY]\d', '', response)
numericals = re.findall(r'\d+', clean_response)
assert len(numericals) >= 2
print(f"[Grounder] the parsed coordinates: {numericals}")
local_x, local_y = self._resize_coordinates([int(numericals[0]), int(numericals[1])], width=cur_width, height=cur_height)
# current global coordinates = local coordinates + global offset
final_global_x = local_x + global_offset_x
final_global_y = local_y + global_offset_y
if detail:
return [cur_screenshot, global_offset_x, global_offset_y]
else:
return [final_global_x, final_global_y]
def dynamic_set_width_height(self, width: int, height: int):
self.width = width
self.height = height
# Resize from grounding model dim into OSWorld dim (1920 * 1080)
def _resize_coordinates(self, coordinates: List[int], width:int, height:int) -> List[int]:
"""
width, height: for current observation
grounding_width, grounding_height: width and height for the grounding model (e.g. 1000x1000 or 1280x800)
"""
grounding_width = self.engine_params_for_grounder["grounding_width"]
grounding_height = self.engine_params_for_grounder["grounding_height"]
grounding_smart_resize = self.engine_params_for_grounder["grounding_smart_resize"]
if not grounding_smart_resize:
return [
round(coordinates[0] * width / grounding_width),
round(coordinates[1] * height / grounding_height),
]
else:
smart_height, smart_width = smart_resize(height, width)
return [
round(coordinates[0] * width / smart_width),
round(coordinates[1] * height / smart_height)
]
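# Illustrative example: with grounding_width=1000, grounding_height=1000 and a 1920x1080
# screen, a model output of (500, 400) maps to (round(500*1920/1000), round(400*1080/1000))
# = (960, 432); with grounding_smart_resize enabled, the smart_resize dimensions are used
# as the source space instead.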

View File

@ -0,0 +1,428 @@
from ast import parse
import logging
import json
from typing import List, Dict, Any, Optional, Tuple
from mm_agents.os_symphony.utils.common_utils import (
call_llm_formatted,
enhance_observation,
parse_code_from_string
)
from functools import partial
from mm_agents.os_symphony.utils.formatters import JSON_ANSWER_FORMATTER
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
import imagehash
import io
import os
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as ssim
logger = logging.getLogger("desktopenv.agent")
class StepBehavior:
"""
Narrative Step Behavior.
Description of each step, cosists of generative agent (main agent)'s output, screenshot (if this step is milestone), and textual description.
The textual description shows that how the agent thought and did, and how the state changes.
"""
def __init__(self, is_milestone: bool, gen_output: str, summary: str, obs: Dict, action_dict: Dict):
self.is_milestone = is_milestone
self.gen_output = gen_output
self.obs = obs
self.summary = summary
self.action_dict = action_dict
# Variables for optimizing the time complexity of loop detection
# --- 1. pHash ---
self.phash = None
# --- 2. SSIM ---
self.ssim_list = []
def _update_phash_ssim(self, history: List):
# Calculate the ssim_list of current obs
# Update pHash
cur_img = Image.open(io.BytesIO(self.obs["screenshot"]))
cur_img_gray = cur_img.convert('L')
cur_img_np = np.array(cur_img_gray)
self.phash = imagehash.phash(cur_img)
# Update ssim_list
for hs in history:
compare_img = Image.open(io.BytesIO(hs.obs["screenshot"]))
compare_img_gray = compare_img.convert('L')
compare_img_np = np.array(compare_img_gray)
self.ssim_list.append(ssim(cur_img_np, compare_img_np, data_range=255)) # full 8-bit grayscale range
class ReflectionMemoryAgent:
"""
Reflection Memory Agent (RMA).
Responsible for maintaining long-term memory, extracting narratives from trajectories,
providing reflections to the Main Agent, and validating task completion status.
"""
def __init__(self, engine_params: Dict):
"""
Initialize the RMA.
Args:
- engine_params:
{
"engine_type": args.provider,
"model": args.model,
"base_url": args.model_url,
"api_key": args.model_api_key,
"temperature": getattr(args, "model_temperature", None),
}
- max_images (read from engine_params): maximum number of screenshots to use in the reflection process
"""
self.engine_params = engine_params
self.max_images = engine_params.get('max_images', 8)
self.memoryer_level = engine_params.get('memoryer_level', 3)
self.reset()
logger.info(f"ReflectionMemoryAgent initialized with:\n {self.engine_params}")
def reset(self):
"""Reset the code agent state."""
logger.debug("Resetting RMA state")
self.instruction = None
self.trajectory: List[StepBehavior] = []
self.knowledge_base: List[str] = []
self.last_code_step_idx = -1
'''
Control the number of tracked screenshots; at most max_images are used.
Update logic: the 0-th screenshot is always retained. If fewer than max_images screenshots are tracked, all are kept; otherwise, starting from index 1, milestone screenshots are managed via FIFO.
'''
self.active_img_idx = []
self.reflection_agent = LMMAgent(
engine_params=self.engine_params,
system_prompt=PROCEDURAL_MEMORY.REFLECTION_SYSTEM_PROMPT,
)
self.behavior_agent = LMMAgent(
engine_params=self.engine_params,
system_prompt=PROCEDURAL_MEMORY.SUMMARIZE_STEP_SYSTEM_PROMPT
)
def add_instruction(self, instruction):
"""
[Interface] Main -> RMA
Main agent set the instruction to RMA.
"""
self.instruction = instruction
def _update_trajectory(self, step_behavior):
self.trajectory.append(step_behavior)
if len(self.active_img_idx) >= self.max_images:
if step_behavior.is_milestone:
self.active_img_idx.append(len(self.trajectory) - 1) # at capacity: only milestone images are added
del self.active_img_idx[1] # FIFO starts from index 1
else:
self.active_img_idx.append(len(self.trajectory) - 1) # below max_images: keep every image
assert len(self.active_img_idx) <= self.max_images, "[RMA] Errors in updating StepBehavior!!"
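# Illustrative example of the FIFO window (assuming max_images=3): with active_img_idx
# [0, 4, 7], a new milestone step 9 yields [0, 7, 9] (index 0 is always kept, the oldest
# non-initial entry is evicted); a non-milestone step is not added once the window is full.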
def _summarize_step_behavior(
self,
generator_output: str,
cur_obs: Dict,
enhanced_obs: bytes | None,
is_milestone: bool,
mode: str = "gui",
code_exec_summary: str = "",
action_dict: Dict = {}
) -> Tuple[StepBehavior, bool]:
"""
[Interface] Main -> RMA
The Main Agent (MA) calls this method to "feed" the information of the just-completed step to the RMA.
RMA will internally process and store this step.
"""
if mode == "search":
is_success = "successful"
# summary is fixed
step_behavior = StepBehavior(
False,
generator_output,
"Search Agent was called last step, and a tutorial has been generated.",
cur_obs,
action_dict
)
elif mode == "code":
self.last_code_step_idx = len(self.trajectory)
is_success = "successful"
# the summary returned by the code agent
step_behavior = StepBehavior(
False,
generator_output,
f"Code Agent was called last step, and the summary of its trajectory is: \n---\n{code_exec_summary}\n---",
cur_obs,
action_dict
)
else: # common gui operation, use LLM to summarize
prev_obs = self.trajectory[-1].obs
text_content = f"""Computer Use Agent's Output: \n{generator_output}"""
self.behavior_agent.reset() # don't need history messages
updated_sys_prompt = (
self.behavior_agent.system_prompt + "\n" + text_content
)
self.behavior_agent.add_system_prompt(updated_sys_prompt)
self.behavior_agent.add_message(
text_content="This is the observation before executing action (attached below).",
image_content=prev_obs['screenshot'],
role="user",
put_text_last=False
)
self.behavior_agent.add_message(
text_content="This is the zoom-in view, which may help you to identify the operational region (attached below).",
image_content=enhanced_obs,
role="user",
put_text_last=False
)
self.behavior_agent.add_message(
text_content="This is the observation after executing action (attached below).",
image_content=cur_obs['screenshot'],
role="user",
put_text_last=False
)
required_fields = ["summary", "evaluation"]
format_checkers = [
partial(JSON_ANSWER_FORMATTER, required_fields)
]
full_response = call_llm_formatted(
self.behavior_agent,
format_checkers,
temperature=self.engine_params.get("temperture", 0.1),
)
response = parse_code_from_string(full_response)
try:
data = json.loads(response)
behavior_summary = data['summary']
is_success = data["evaluation"]
except Exception as e:
print("[RMA] Errors in generating step summary: ", e)
logger.info("Response is not a JSON object or miss required keys!")
behavior_summary = response
is_success = "successful"
step_behavior = StepBehavior(is_milestone, generator_output, behavior_summary, cur_obs, action_dict)
return step_behavior, is_success == "successful"
def get_reflection(
self,
cur_obs: Dict,
generator_output: str,
coordinates: List,
mode: str="gui",
code_exec_summary: str = "",
action_dict: Dict = {}
) -> Dict:
"""
[Interface] RMA -> Main
The Main Agent (MA) calls this method to get RMA's reflection before deciding the next action.
Args:
- cur_obs (Dict): The Main Agent's current observation (o_k).
- generator_output (str): The thoughts, screen analysis and action of Main Agent.
- coordinates (List): coordinates in the last operation step of Main Agent.
- mode(str): [gui, code, search]. Indicate which agent that main agent called last step.
- code_exec_summary: execution summary for code agent.
- action_dict: extracted action from generator output.
Returns:
- reflection_info(Dict): all the info related to reflection
"""
if self.memoryer_level == 0:
return {
"reflection": None,
"reflection_thoughts": None,
"existing_knowledge": None,
"is_milestone": False,
"new_knowledge": None,
"step_summary": None,
"hint": {
"gui_operation_error": False,
"lack_of_tutorial": False,
"code_error": False,
"loop_detection": None,
}
}
reflection = None
reflection_thought = None
if len(self.trajectory) == 0:
step_behavior = StepBehavior(
True,
"The initial screen is provided. No action has been taken yet.",
"The initial screen is provided. No action has been taken yet.",
cur_obs,
action_dict
)
step_behavior._update_phash_ssim(self.trajectory)
self._update_trajectory(step_behavior)
reflection_info = {
"reflection": reflection,
"reflection_thoughts": reflection_thought,
"existing_knowledge": "\n".join(self.knowledge_base),
"is_milestone": True,
"new_knowledge": "",
"step_summary": "",
"loop_detection": None
}
else:
### Step Summary
prev_obs = self.trajectory[-1].obs
enhanced_obs = None
if coordinates:
enhanced_obs, _, _, _, _ = enhance_observation(
prev_obs["screenshot"],
coordinates,
draw=True
)
# generate step behavior
step_behavior, last_gui_check = self._summarize_step_behavior(
generator_output,
cur_obs,
enhanced_obs,
False,
mode,
code_exec_summary,
action_dict
)
step_behavior._update_phash_ssim(self.trajectory)
### make additional hints
additional_hints = []
if not last_gui_check:
additional_hints.append(f"\t- Warning: The last GUI operation is unsuccessful. Careful review is required to avoid GUI Operation Error.")
code_error_hint = False
if self.last_code_step_idx != -1 and len(self.trajectory) - self.last_code_step_idx < 0:
code_error_hint = True
additional_hints.append(f"\t- Warning: The Computer Use Agent might in the verification stage of Code Agent. Careful review is required to avoid Code Error.")
# loop detection
from mm_agents.os_symphony.utils.loop_detection import detect_loop
is_loop, loop_details = detect_loop(full_trajectory=self.trajectory, N=3)
if is_loop and loop_details:
match_sequence_indices = loop_details["match_sequence_indices"]
loop_hint_message = f"\t- Warning: A potential LOOP has been detected between Step {match_sequence_indices[0]} and Step {match_sequence_indices[-1]}. Careful review is required to avoid Repetitive Behavior Error."
additional_hints.append(loop_hint_message)
self.reflection_agent.reset()
updated_sys_prompt = (
PROCEDURAL_MEMORY.REFLECTION_SYSTEM_PROMPT + "\n\n" +
f"---\n- **user instruction**: {self.instruction}\n" +
"- **existing knowledge**: \n" + "\n".join(self.knowledge_base) +
"\n- **additional_hints**: " + "\n".join(additional_hints) + "\n---"
)
# update system prompt
self.reflection_agent.add_system_prompt(updated_sys_prompt)
for i, step in enumerate(self.trajectory):
text_content = f"""### (Step {i}) history:\nsummary: '''\n{step.summary}\n'''"""
if i in self.active_img_idx:
if i == 0:
text_content += f"\ninitial screenshot:"
else:
text_content += f"\nscreenshot (after executing action): (attached below)"
self.reflection_agent.add_message(
text_content=text_content,
image_content=step.obs['screenshot'] if i in self.active_img_idx else None,
role="user",
put_text_last=False
)
text_content = f"""### (Last Step) CUA's output (has been finished):\n---\n{generator_output}\n---\nStep Summary:\n---\n{step_behavior.summary}\n---\nlatest_screenshot: (attached below)"""
self.reflection_agent.add_message(
text_content=text_content,
image_content=cur_obs['screenshot'],
role="user",
put_text_last=False
)
required_fields = ["is_milestone", "reflection", "knowledge"]
format_checkers = [
partial(JSON_ANSWER_FORMATTER, required_fields)
]
full_response = call_llm_formatted(
self.reflection_agent,
format_checkers
)
reflection_thought = full_response
response = parse_code_from_string(full_response)
try:
data = json.loads(response)
reflection = data['reflection']
is_milestone = data["is_milestone"]
knowledge = data['knowledge']
except Exception as e:
print("[RMA] Errors in dealing with reflection: ", e)
logger.info("Response is not a JSON object or miss required keys!")
reflection = response
is_milestone = False
knowledge = ""
if len(knowledge) > 0:
self.knowledge_base.append(knowledge)
if isinstance(is_milestone, str):
is_milestone = True if "true" in is_milestone.lower() else False
# update trajectory and is_milestone
self._update_trajectory(step_behavior)
if mode == "gui": # only gui opration can be considered as milestone
self.trajectory[-1].is_milestone = is_milestone
reflection_info = {
"reflection": reflection,
"reflection_thoughts": reflection_thought,
"existing_knowledge": "\n".join(self.knowledge_base),
"is_milestone": is_milestone,
"new_knowledge": knowledge,
"step_summary": step_behavior.summary,
"hint": {
"gui_operation_error": not last_gui_check,
"lack_of_tutorial": is_loop,
"code_error": code_error_hint,
"loop_detection": loop_details,
}
}
return reflection_info

View File

@ -0,0 +1,178 @@
import re
from io import BytesIO
from typing import Tuple, List, Dict
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pytesseract
from pytesseract import Output
import easyocr
class OCRProcessor:
"""
OCR Processor supports Tesseract and EasyOCR
"""
def __init__(self, use_gpu: bool = False, languages: List[str] = ['en']):
"""
Initialize processor
Args:
use_gpu (bool): whether EasyOCR should use the GPU
languages (List[str]): languages for EasyOCR to load, e.g. ['en', 'ch_sim']
"""
self.use_gpu = use_gpu
self.languages = languages
self.reader = None # lazy-load EasyOCR Reader
def _get_easyocr_reader(self):
if self.reader is None:
print(f"Loading EasyOCR model (GPU={self.use_gpu})...")
self.reader = easyocr.Reader(self.languages, gpu=self.use_gpu)
return self.reader
def get_ocr_elements(self, bytes_image_data: bytes, mode: str = 'tesseract') -> Tuple[str, List[Dict]]:
"""
Runs OCR recognition.
Args:
bytes_image_data (bytes): raw image bytes
mode (str): 'tesseract' (faster) or 'easyocr' (more precise)
Returns:
Tuple[str, List]: (textual table string, list of element details)
"""
try:
image = Image.open(BytesIO(bytes_image_data))
except Exception as e:
print(f"Error decoding or opening image: {e}")
return "", []
if mode == 'tesseract':
return self._process_tesseract(image)
elif mode == 'easyocr':
return self._process_easyocr(image)
else:
raise ValueError(f"Unknown mode: {mode}. Use 'tesseract' or 'easyocr'.")
def _process_tesseract(self, image: Image.Image) -> Tuple[str, List[Dict]]:
"""Tesseract processing"""
data = pytesseract.image_to_data(image, output_type=Output.DICT)
ocr_elements = []
ocr_table = "Text Table (Tesseract):\nWord id\tText\n"
ocr_id = 0
num_boxes = len(data['text'])
for i in range(num_boxes):
# filter text with low confidence
if int(data['conf'][i]) > 0 and data['text'][i].strip():
clean_text = re.sub(r"^[^a-zA-Z0-9\s.,!?;:\-\+]+|[^a-zA-Z0-9\s.,!?;:\-\+]+$", "", data['text'][i])
if not clean_text: continue
ocr_table += f"{ocr_id}\t{clean_text}\n"
ocr_elements.append({
"id": ocr_id,
"text": clean_text,
"mode": "tesseract",
"left": data["left"][i],
"top": data["top"][i],
"width": data["width"][i],
"height": data["height"][i],
"conf": data["conf"][i]
})
ocr_id += 1
return ocr_table, ocr_elements
def _process_easyocr(self, image: Image.Image) -> Tuple[str, List[Dict]]:
"""EasyOCR processing"""
reader = self._get_easyocr_reader()
image_np = np.array(image)
# detail=1 means returning (bbox, text, conf)
results = reader.readtext(image_np, detail=1, paragraph=False, width_ths=0.1)
ocr_elements = []
ocr_table = "Text Table (EasyOCR):\nWord id\tText\n"
ocr_id = 0
for (bbox, text, conf) in results:
clean_text = re.sub(r"^[^a-zA-Z0-9\s.,!?;:\-\+]+|[^a-zA-Z0-9\s.,!?;:\-\+]+$", "", text)
if not clean_text.strip(): continue
# EasyOCR returns [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
# we convert them into left, top, width, height
(tl, tr, br, bl) = bbox
tl = [int(v) for v in tl]
br = [int(v) for v in br]
left = min(tl[0], bl[0])
top = min(tl[1], tr[1])
right = max(tr[0], br[0])
bottom = max(bl[1], br[1])
width = right - left
height = bottom - top
# ---------------
ocr_table += f"{ocr_id}\t{clean_text}\n"
ocr_elements.append({
"id": ocr_id,
"text": clean_text,
"mode": "easyocr",
"left": left,
"top": top,
"width": width,
"height": height,
"conf": float(conf)
})
ocr_id += 1
return ocr_table, ocr_elements
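# Illustrative example: each OCR element is a dict like
#   {"id": 0, "text": "File", "mode": "easyocr", "left": 12, "top": 8,
#    "width": 30, "height": 14, "conf": 0.98}
# and the accompanying ocr_table lists "Word id\tText" rows that downstream agents
# can reference by id.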
@staticmethod
def visualize_ocr_results(image_path: str, ocr_elements: List[Dict], output_path: str):
"""
Draw bounding boxes and IDs on the original image.
"""
try:
image = Image.open(image_path).convert("RGB")
draw = ImageDraw.Draw(image)
try:
font = ImageFont.truetype("arial.ttf", 16)
except IOError:
font = ImageFont.load_default()
for element in ocr_elements:
left, top = element["left"], element["top"]
width, height = element["width"], element["height"]
color = "green" if element.get("mode") == "easyocr" else "red"
draw.rectangle([(left, top), (left + width, top + height)], outline=color, width=2)
text_str = str(element["id"])
if hasattr(draw, "textbbox"):
bbox = draw.textbbox((0, 0), text_str, font=font)
text_w, text_h = bbox[2]-bbox[0], bbox[3]-bbox[1]
else:
text_w, text_h = draw.textsize(text_str, font=font)
label_bg = [left, top - text_h - 4, left + text_w + 4, top]
draw.rectangle(label_bg, fill=color)
draw.text((left + 2, top - text_h - 4), text_str, fill="white", font=font)
image.save(output_path)
print(f"Visualization saved to: {output_path}")
except FileNotFoundError:
print(f"Error: Image {image_path} not found.")
except Exception as e:
print(f"Visualization error: {e}")

View File

@ -0,0 +1,575 @@
import re
from typing import Any, Dict, List, Optional, Tuple
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.utils.common_utils import call_llm_safe
from mm_agents.os_symphony.agents.coder_agent import CoderAgent
from mm_agents.os_symphony.agents.grounder_agent import GrounderAgent
from mm_agents.os_symphony.agents.searcher_agent import SearcherAgent
import logging
from mm_agents.os_symphony.agents.ocr import OCRProcessor
logger = logging.getLogger("desktopenv.agent")
# Agent action decorator
def agent_action(func):
func.is_agent_action = True
return func
# GrounderAgent primitives are parameterized by description, and coordinate generation uses a pretrained grounding model
class OSACI:
def __init__(
self,
env,
search_env,
platform: str,
client_password: str,
engine_params_for_ocr: Dict,
engine_params_for_grounder: Dict,
engine_params_for_coder: Dict,
engine_params_for_searcher: Dict,
screen_width: int = 1920,
screen_height: int = 1080
):
self.env = env
self.platform = platform
self.client_password = client_password
self.result_dir = ""
self.grounder_agent = GrounderAgent(engine_params=engine_params_for_grounder, screen_width=screen_width, screen_height=screen_height)
# Configure text grounding agent
self.ocr_processor = OCRProcessor()
self.text_span_agent = LMMAgent(
engine_params=engine_params_for_ocr,
system_prompt=PROCEDURAL_MEMORY.PHRASE_TO_WORD_COORDS_PROMPT,
)
# Configure code agent
self.coder_agent = CoderAgent(
engine_params=engine_params_for_coder,
platform=self.platform,
client_password=client_password
)
# Configure search agent
self.searcher_agent = SearcherAgent.create(
engine_params=engine_params_for_searcher,
search_env=search_env,
grounder_agent=self.grounder_agent,
platform=self.platform,
client_password=self.client_password
)
# Store task instruction for code agent
self.current_task_instruction = None
self.last_code_agent_result = None
self.last_search_agent_result = None
self.notes: List[str] = []
# Tutorials should be global information rather than local context; TODO: decide how to add them to the global info
self.tutorials = []
def assign_screenshot(self, obs):
self.obs = obs
# Given the state and worker's text phrase, generate the coords of the first/last word in the phrase
def generate_text_coords(
self, phrase: str, obs: Dict, alignment: str = ""
) -> List[int]:
screenshot, global_offset_x, global_offset_y= obs["screenshot"], 0, 0
ocr_table, ocr_elements = self.ocr_processor.get_ocr_elements(screenshot, "easyocr")
alignment_prompt = ""
if alignment == "start":
alignment_prompt = "**Important**: Output the word id of the FIRST word in the provided phrase.\n"
elif alignment == "end":
alignment_prompt = "**Important**: Output the word id of the LAST word in the provided phrase.\n"
# Load LLM prompt
self.text_span_agent.reset()
self.text_span_agent.add_message(
alignment_prompt + "Phrase: " + phrase + "\n" + ocr_table, role="user"
)
self.text_span_agent.add_message(
"Screenshot:\n", image_content=screenshot, role="user"
)
# Obtain the target element
response = call_llm_safe(self.text_span_agent)
print("TEXT SPAN AGENT RESPONSE:", response)
numericals = re.findall(r"\d+", response)
if len(numericals) > 0:
text_id = int(numericals[-1])
else:
text_id = 0
elem = ocr_elements[text_id]
# Compute the element coordinates
# Note: 0.1 * elem["height"] is used to adjust coordinates to select the last character more precisely.
if alignment == "start":
coords = [elem["left"], elem["top"] + (elem["height"] // 2)]
elif alignment == "end":
coords = [elem["left"] + elem["width"] + 0.15 * elem["height"], elem["top"] + (elem["height"] // 2)]
print(f'[OCR] output coordinates: {[coords[0] + global_offset_x, coords[1] + global_offset_y]}')
return [int(coords[0] + global_offset_x), int(coords[1] + global_offset_y)]
def set_task_instruction(self, task_instruction: str):
"""Set the current task instruction for the code agent."""
self.current_task_instruction = task_instruction
@agent_action
def click(
self,
element_description: str,
num_clicks: int = 1,
button_type: str = "left",
hold_keys: List = []
):
"""Click on the element
Args:
element_description:str, a detailed descriptions of which element to click on. This description needs to be VERY unambiguous. If the page contains many similar elements, ensure the description uniquely identifies the target element.
num_clicks:int, number of times to click the element
button_type:str, which mouse button to press; can be "left", "middle", or "right"
hold_keys:List, list of keys to hold while clicking
"""
x, y = self.grounder_agent.generate_coords(element_description, self.obs)
command = "import pyautogui; "
for k in hold_keys:
command += f"pyautogui.keyDown({repr(k)}); "
command += f"""import pyautogui; pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); """
for k in hold_keys:
command += f"pyautogui.keyUp({repr(k)}); "
# Return pyautoguicode to click on the element
action = {"function": "click", "args": {"x": x, "y": y, "button": button_type, "clicks": num_clicks}}
return (command, action)
@agent_action
def open(self, app_or_filename: str):
"""Open any application or file with name app_or_filename. Use this action to open applications or files on the desktop, do not open manually.
Args:
app_or_filename:str, the name of the application or filename to open
**Important**:
Provide only the name of the application or file. Do not include the full path (e.g., "/home/user/Desktop/my_report.docx"). The function works by searching for the name, not by accessing a file path directly.
"""
action = {"function": "open", "args": {"name": app_or_filename}}
if self.platform == "linux":
return (f"import pyautogui; pyautogui.hotkey('win'); time.sleep(1.0); pyautogui.write({repr(app_or_filename)}); time.sleep(1.0); pyautogui.hotkey('enter'); time.sleep(1.0)", action)
elif self.platform == "macos":
return (f"import pyautogui; import time; pyautogui.hotkey('command', 'space', interval=0.5); pyautogui.typewrite({repr(app_or_filename)}); pyautogui.press('enter'); time.sleep(1.0)", action)
elif self.platform == "windows":
return (f"import pyautogui; import time; pyautogui.hotkey('win'); time.sleep(0.5); pyautogui.write({repr(app_or_filename)}); time.sleep(1.0); pyautogui.press('enter'); time.sleep(0.5)", action)
else:
assert (
False
), f"Unsupported platform: {self.platform}. Supported platforms are: darwin, linux, windows."
def _paste(self, is_terminal):
if self.platform == 'macos':
return "pyautogui.hotkey('command', 'v');"
elif self.platform == 'linux':
if is_terminal:
return "pyautogui.hotkey('ctrl', 'shift', 'v');"
else:
return "pyautogui.hotkey('ctrl', 'v');"
elif self.platform == 'windows':
return "pyautogui.hotkey('ctrl', 'v');"
return ""
def _clear_all(self, is_terminal):
"""
Clean the content of current line
"""
# common apps in GUI
if not is_terminal:
if self.platform == 'macos':
# macOS GUI: Command + A -> Backspace
return "pyautogui.hotkey('command', 'a'); pyautogui.press('backspace');"
else:
# Windows/Linux GUI: Ctrl + A -> Backspace
return "pyautogui.hotkey('ctrl', 'a'); pyautogui.press('backspace');"
# terminal
else:
if self.platform == 'windows':
return "pyautogui.press('esc');"
else:
return "pyautogui.hotkey('ctrl', 'e'); pyautogui.hotkey('ctrl', 'u');"
def _type(
self,
text: str,
is_terminal: bool
):
"""
Use copy-and-paste to input non-ASCII text (e.g., Chinese); otherwise type normally.
"""
commands = ""
has_unicode = any(ord(char) > 127 for char in text)
if has_unicode and self.platform != "macos":
commands += (
"original_clipboard = pyperclip.paste();"
f"pyperclip.copy({repr(text)});"
"time.sleep(0.1);"
)
commands += self._paste(is_terminal=is_terminal)
commands += "pyperclip.copy(original_clipboard);"
else:
commands += f"pyautogui.write({repr(text)}, interval=0.1);"
return commands
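# A minimal standalone sketch of the clipboard round-trip used above for non-ASCII input
# (assumptions: pyautogui and pyperclip are installed and the target field already has focus;
# the helper name paste_text is illustrative, not part of this module):
import time
import pyautogui
import pyperclip

def paste_text(text: str) -> None:
    # Save the user's clipboard, copy the target text, paste it, then restore the clipboard.
    original = pyperclip.paste()
    pyperclip.copy(text)
    time.sleep(0.1)
    pyautogui.hotkey("ctrl", "v")  # use ("command", "v") on macOS
    pyperclip.copy(original)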
@agent_action
def type(
self,
element_description: str,
text: str = "",
overwrite: bool = False,
enter: bool = False,
is_terminal: bool = False
):
"""Type text/unicode into a specific element
Args:
element_description: str, a detailed description of which element to enter text in. If provided, the agent will click on this element before typing.
text:str, the text to type
overwrite:bool, Default is False; assign it to True if the text should overwrite the whole existing text. Setting it to True clears all existing text in the element before typing.
enter:bool, Assign it to True if the enter key should be pressed after typing all the text, otherwise assign it to False.
is_terminal:bool, (MANDATORY) You MUST set this to True whenever the target you will type into is a terminal.
"""
commands = (
"import os;"
"import pyautogui;"
"import pyperclip;"
"import subprocess;"
"import time;"
)
if self.platform == "linux":
commands += (
"p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
"p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
"proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
f"subprocess.run(f'echo \"{self.client_password}\" | sudo -S {{proxy_prefix}}apt-get install -y xclip xsel', shell=True, check=True);"
)
x, y = None, None
if element_description is not None:
x, y = self.grounder_agent.generate_coords(element_description, self.obs)
commands += (
f"pyautogui.click({x}, {y}, clicks=2);"
f"time.sleep(1.0);"
f"pyautogui.click({x}, {y});"
)
if overwrite:
commands += self._clear_all(is_terminal=is_terminal)
commands += self._type(text=text, is_terminal=is_terminal)
if enter:
commands += "pyautogui.press('enter');"
if element_description is not None:
action = {"function": "type", "args": {"x": x, "y": y, "text": text}}
else:
action = {"function": "type", "args": {"text": text}}
return (commands, action)
@agent_action
def drag_and_drop(
self, starting_description: str, ending_description: str, hold_keys: List = []
):
"""Drag from the starting description to the ending description
Args:
starting_description:str, a very detailed description of where to start the drag action. This description should be at least a full sentence.
ending_description:str, a very detailed description of where to end the drag action. This description should be at least a full sentence.
hold_keys:List list of keys to hold while dragging
"""
x1, y1 = self.grounder_agent.generate_coords(starting_description, self.obs)
x2, y2 = self.grounder_agent.generate_coords(ending_description, self.obs)
command = "import pyautogui; "
command += f"pyautogui.moveTo({x1}, {y1}); "
# TODO: make the drag duration configurable?
for k in hold_keys:
command += f"pyautogui.keyDown({repr(k)}); "
command += f"pyautogui.dragTo({x2}, {y2}, duration=3., button='left'); pyautogui.mouseUp(); "
for k in hold_keys:
command += f"pyautogui.keyUp({repr(k)}); "
# Return pyautoguicode to drag and drop the elements
action = {"function": "drag", "args": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}}
return (command, action)
@agent_action
def highlight_text_span(
self,
starting_phrase: str,
ending_phrase: str,
button: str = "left",
text: Optional[str] = None
):
"""Highlight a text span between a provided starting phrase and ending phrase. Use this to highlight words, lines, and paragraphs.
Args:
starting_phrase: str, the sequence of words that marks the beginning of the text span. Provide a unique sequence of 5 to 10 words.
ending_phrase: str, the sequence of words that marks the end of the text span. Provide a unique sequence of 5 to 10 words.
button:str, the button to use to highlight the text span. Defaults to "left". Can be "left", "right", or "middle".
text: str | None, The text to overwrite the highlighted span with. Providing text here ensures the replacement happens immediately after selection, preventing focus loss.
"""
x1, y1 = self.generate_text_coords(
starting_phrase, self.obs, alignment="start"
)
x2, y2 = self.generate_text_coords(
ending_phrase, self.obs, alignment="end"
)
command = "import pyautogui; import time;"
command += f"pyautogui.moveTo({x1}, {y1}); "
# Click in advance to simulate selecting the text box.
command += (
f"pyautogui.click({x1}, {y1}, clicks=2);"
f"time.sleep(1.0); pyautogui.click({x1}, {y1}); time.sleep(1.0);"
)
command += f"pyautogui.dragTo({x2}, {y2}, duration=5., button='{button}'); time.sleep(0.5); pyautogui.mouseUp(); "
if text:
if self.platform == "linux":
command += "subprocess.run('echo \"password\" | sudo -S apt-get install -y xclip xsel', shell=True, check=True, env={\"http_proxy\": \"http://10.1.8.5:23128\", \"https_proxy\": \"http://10.1.8.5:23128\"});"
command += (
"original_clipboard = pyperclip.paste();"
f"pyperclip.copy({repr(text)});"
)
command += self._paste(is_terminal=False)
command += "pyperclip.copy(original_clipboard);"
# Return pyautoguicode to drag and drop the elements
action = {"function": "drag", "args": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}}
return (command, action)
@agent_action
def locate_cursor(
self,
phrase: str,
start_or_end: str="start",
text: Optional[str] = None
):
"""Click at the beginning or end of a specific text phrase to precisely control cursor positioning. Please prefer using the "click" action in general situations, and use this action only in text-intensive software such as libreoffice_writer, impress, etc.
Args:
phrase: str, The text phrase where you want to position the cursor. Provide a unique sequence of 5 to 10 words. Do NOT use single words unless the total text is extremely short.
start_or_end: str, Whether to click at the "start" (beginning) or "end" (trailing edge) of the identified text phrase. Use "start" to position before the text, "end" to position after it.
text: str | None, The text to enter immediately after positioning the cursor. Use this parameter instead of a separate 'type' action to ensure precise input.
"""
x, y = self.generate_text_coords(
phrase, self.obs, alignment=start_or_end
)
command = (
"import pyautogui;"
"import time;"
"import subprocess;"
"import pyperclip;"
f"pyautogui.click({x}, {y}, button='left', clicks=2);"
"time.sleep(1.0);"
f"pyautogui.click({x}, {y}, button='left');"
)
if text:
if self.platform == "linux":
command += "subprocess.run('echo \"password\" | sudo -S apt-get install -y xclip xsel', shell=True, check=True, env={\"http_proxy\": \"http://10.1.8.5:23128\", \"https_proxy\": \"http://10.1.8.5:23128\"});"
command += self._type(text=text, is_terminal=False)
if text:
action = {"function": "type", "args": {"x": x, "y": y, "text": text}}
else:
action = {"function": "click", "args": {"x": x, "y": y, "clicks": 1, "button": "left"}}
return (command, action)
@agent_action
def call_code_agent(self, task: str):
"""Calls the code agent to execute a well-defined, self-contained goal that can be completed with code.
Args:
task: str, A specific, self-contained goal that the code agent can work on until completion.
**🚨 CRITICAL GUIDELINES:**
**Decompose the Main Objective into Logical Goals:**
- You **MUST** break down the overall mission into distinct, logical goals or stages.
- Your role is to define *what* needs to be done for a specific stage. The code agent's role is to figure out *how* to do it with code.
- Pass only one logical goal at a time. The `task` parameter is **REQUIRED**.
**Define a Self-Contained, Continuous Goal:**
- The `task` you provide should be a single, continuous goal. The code agent is capable of handling a multi-step process internally (e.g., opening a file, processing its data, and then saving it) to achieve this one goal.
- **Crucially, do not pass a task that combines multiple distinct objectives.** For example, instead of passing "Analyze the sales data, AND email the result," you should first pass the self-contained goal: "Analyze the sales data." After that goal is complete, you can proceed with the next logical goal (e.g., emailing the result) in a subsequent step.
- **If unsure, err on the side of caution.** If a task feels like it has two separate parts, break it down and pass only the first part.
- Your instruction must describe the desired end-state, NOT the recipe to get there. Do not specify any solution!
**Goal Purity is Essential:**
- **NEVER** rephrase, paraphrase, or modify the subtask instruction you have decided on. Pass the exact, original wording of the subtask to prevent instruction drift and hallucination.
Use this for tasks that can be fully accomplished through code execution, particularly for:
- Spreadsheet applications: data processing, filtering, sorting, calculations, formulas, data analysis
- Document editors: text processing, content editing, formatting, document manipulation
- Code editors: code editing, file processing, text manipulation, configuration
- Data analysis tools: statistical analysis, data transformation, reporting
- File management: bulk operations, file processing, content extraction
- System utilities: configuration, setup, automation
"""
logger.info("=" * 50)
logger.info("ACI: Calling Code Agent")
logger.info("=" * 50)
task_to_execute = task
logger.info(f"Executing SUBTASK: {task_to_execute}")
print("obs keys: ", self.obs.keys())
screenshot = self.obs.get("screenshot", "") if self.obs else ""
logger.info(f"Screenshot available: {'Yes' if screenshot else 'No'}")
logger.info("Executing code agent...")
result = self.coder_agent.execute(
task_to_execute, screenshot, self.env.controller
)
# Store the result for the worker to access
self.last_code_agent_result = result
logger.info("Code agent execution completed")
logger.info(f"Result - Completion reason: {result['completion_reason']}")
logger.info(f"Steps executed: {result['steps_executed']}")
logger.info(f"Summary: {result['summary']}")
logger.info("=" * 50)
logger.info("GROUNDING AGENT: Code Agent Call Finished")
logger.info("=" * 50)
action = {"function": "call_code_agent", "args": {"query": task, "result": True if result["completion_reason"] == "DONE" else False}}
# Return code to be executed in the environment
return ("import time; time.sleep(2.222)", action)
@agent_action
def scroll(self, element_description: str, clicks: int, shift: bool = False):
"""Scroll the element in the specified direction
Args:
element_description:str, a very detailed description of which element to scroll. This description should be at least a full sentence.
clicks:int, the number of clicks to scroll; can be positive (up) or negative (down).
shift:bool, whether to use shift+scroll for horizontal scrolling
"""
x, y = self.grounder_agent.generate_coords(element_description, self.obs)
action = {"function": "scroll", "args": {"x": x, "y": y, "clicks": clicks, "shift": shift}}
if shift:
return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.hscroll({clicks})", action)
else:
return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.vscroll({clicks})", action)
@agent_action
def hotkey(self, keys: List):
"""Press a hotkey combination (can press a single key as well)
Args:
keys:List, the keys to press in combination in a list format (e.g. ['ctrl', 'c'], ['enter'])
"""
# add quotes around the keys
keys = [f"'{key}'" for key in keys]
keys_string = " ".join(keys)
action = {"function": "key", "args": {"keys": keys_string}}
return (f"import pyautogui; pyautogui.hotkey({', '.join(keys)});", action)
@agent_action
def hold_and_press(self, hold_keys: List, press_keys: List):
"""Hold a list of keys and press a list of keys
Args:
hold_keys:List, list of keys to hold
press_keys:List, list of keys to press in a sequence
"""
press_keys_str = "[" + ", ".join([f"'{key}'" for key in press_keys]) + "]"
command = "import pyautogui; "
for k in hold_keys:
command += f"pyautogui.keyDown({repr(k)}); "
command += f"pyautogui.press({press_keys_str}); "
for k in hold_keys:
command += f"pyautogui.keyUp({repr(k)}); "
hold_keys_string = " ".join(hold_keys)
press_keys_string = " ".join(press_keys)
action = {"function": "key", "args": {"keys": hold_keys_string + ";" + press_keys_string}}
return (command, action)
@agent_action
def wait(self, time: float):
"""Wait for a specified amount of time
Args:
time:float, the amount of time to wait in seconds
"""
return (f"""import time; time.sleep({time});""", {"function": "wait", "args": {}})
@agent_action
def done(
self,
):
"""
End the current task with a success. Use this when you believe the entire task has been fully completed. You must ensure all visual information aligns with the user's true intent.
"""
return ("""DONE""", {"function": "done", "args": {}})
@agent_action
def fail(self):
"""End the current task with a failure. Use this when you believe the entire task is impossible to complete."""
return ("""FAIL""", {"function": "fail", "args": {}})
@agent_action
def call_search_agent(
self,
query: str,
):
"""
Calls a specialized 'Searcher Agent' to find a detailed, step-by-step tutorial on the internet for a specific GUI action.
Args:
query:str, the search phrase or question for the tutorial. The formulation of this query is critical for success and must follow the guidelines below.
**Query Formulation Guidelines:**
Your query must be a well-defined question targeting a **single, specific action** within a **specific application**. To get the best results, adhere to these rules:
1. **Start with "How to":** Your query must begin with the phrase "How to" to frame it as a request for instructions.
2. **Include the Application Name:** Always specify the name of the software you are working in (e.g., "GIMP", "Google Chrome", "Libreoffice Writer").
3. **Focus on a Single Intent:** The query should represent one clear goal. Do not combine multiple steps or tasks into one query.
4. **Be Specific, Not Abstract:** Ask a concrete question. Avoid repeating the user's high-level or abstract instructions.
5. **Decompose Complex Tasks:** If the user's overall instruction involves multiple actions (e.g., "download a file and then email it"), and you are stuck on one part, search *only for that specific part*.
**Examples:**
* **User's Overall Instruction:** "Please help me download my latest bank statement and then send it to my accountant."
* **Correct Query (if stuck on downloading):** "How to download a bank statement from the Bank of America website?"
* **Correct Query (if stuck on attaching a file):** "How to attach a file to an email in Gmail?"
* **Incorrect Query:** "Download my bank statement and email it to my accountant" *(This query is too broad, contains multiple sub-tasks, and does not start with "How to".)*
"""
logger.info("=" * 50)
logger.info(f"ACI: Calling Search Agent(query={query})")
logger.info("=" * 50)
self.searcher_agent.result_dir = self.result_dir
result = self.searcher_agent.search(query=query, main_obs=self.obs)
self.last_search_agent_result = result
if result["completion_reason"] == "DONE":
self.tutorials.append(result["final_answer"])
action = {"function": "call_search_agent", "args": {"query": query, "result": True if result["completion_reason"] == "DONE" else False}}
return ("import time; time.sleep(2.222)", action)

View File

@ -0,0 +1,61 @@
import logging
import platform
from typing import Dict, List, Tuple
from mm_agents.os_symphony.agents.os_aci import OSACI
from mm_agents.os_symphony.agents.searcher_agent import VLMSearcherAgent
from mm_agents.os_symphony.agents.worker import Worker
logger = logging.getLogger("desktopenv.agent")
class OSSymphony:
def __init__(
self,
engine_params_for_orchestrator: Dict,
engine_params_for_memoryer: Dict,
os_aci: OSACI,
platform: str = platform.system().lower(),
client_password: str = "",
max_trajectory_length: int = 8,
enable_reflection: bool = True,
):
"""
Args:
engine_params_for_orchestrator: Configuration parameters for the orchestrator (worker) agent.
engine_params_for_memoryer: Configuration parameters for the reflection/memory agent.
os_aci: Instance of the OSACI class used for UI interaction
platform: Operating system platform (darwin, linux, windows)
max_trajectory_length: Maximum number of image turns to keep
enable_reflection: Creates a reflection agent to assist the worker agent
"""
self.engine_params_for_orchestrator = engine_params_for_orchestrator
self.engine_params_for_memoryer = engine_params_for_memoryer
self.os_aci: OSACI = os_aci
self.platform = platform
self.client_password = client_password
self.max_trajectory_length = max_trajectory_length
self.enable_reflection = enable_reflection
def reset(self, result_dir) -> None:
"""Reset agent state and initialize components"""
# Reset the search time per task
self.os_aci.result_dir = result_dir
self.executor = Worker(
engine_params_for_orchestrator=self.engine_params_for_orchestrator,
engine_params_for_memoryer=self.engine_params_for_memoryer,
os_aci=self.os_aci,
platform=self.platform,
client_password=self.client_password,
max_trajectory_length=self.max_trajectory_length,
enable_reflection=self.enable_reflection,
)
def predict(self, instruction: str, observation: Dict, is_last_step: bool) -> Tuple[Dict, List[str]]:
# Ask the worker for the next action
executor_info, actions = self.executor.generate_next_action(
instruction=instruction, obs=observation, is_last_step=is_last_step
)
# Package the executor info
info = dict(executor_info or {})
return info, actions
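# A minimal wiring sketch (assumptions: `aci` is an already-constructed OSACI bound to an
# OSWorld-style environment `env`, and the engine parameter dicts follow the same schema used
# elsewhere in this module; the instruction and step budget below are illustrative):
agent = OSSymphony(
    engine_params_for_orchestrator={"engine_type": "openai", "model": "gpt-4o"},
    engine_params_for_memoryer={"engine_type": "openai", "model": "gpt-4o"},
    os_aci=aci,
    platform="linux",
    client_password="password",
)
agent.reset(result_dir="./results/example_task")
for step in range(15):
    obs = env._get_obs()  # screenshot bytes are expected under obs["screenshot"]
    info, actions = agent.predict(
        instruction="Open GIMP and crop the image",
        observation=obs,
        is_last_step=(step == 14),
    )
    if actions and actions[0] in ("DONE", "FAIL"):
        break
    for action in actions:
        obs, *_ = env.step(action, 5)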

View File

@ -0,0 +1,478 @@
import logging
import urllib.parse
from typing import Any, Dict, List, Optional
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import (
draw_coordinates,
call_llm_formatted,
parse_code_from_string,
create_pyautogui_code
)
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.agents.grounder_agent import GrounderAgent
import os
import time
import json
logger = logging.getLogger("desktopenv.searcher_agent")
# Agent action decorator
def searcher_agent_action(func):
func.is_searcher_agent_action = True
return func
# --- Abstract Base Class and Factory ---
class SearcherAgent:
def __init__(self, engine_params: Dict, platform: str):
self.engine_params = engine_params
self.result_dir = ""
self.tutorial_or_hint = ""
self.tutorial_notes = []
self.max_trajectory_length = 8
self.platform = platform
self.budget = engine_params.get("budget", 20)
@staticmethod
def create(engine_params: Dict, search_env, grounder_agent: GrounderAgent, platform: str, client_password: str="password"):
searcher_type = engine_params.get("type", "vlm")
if searcher_type == "vlm":
return VLMSearcherAgent(engine_params=engine_params, search_env=search_env, grounder_agent=grounder_agent, platform=platform, client_password=client_password)
else:
raise NotImplementedError
def _get_search_time(self) -> int:
"""for the name of result directory"""
if not self.result_dir: return 1
search_times: list[int] = []
try:
if not os.path.exists(self.result_dir): return 1
for item_name in os.listdir(self.result_dir):
full_path = os.path.join(self.result_dir, item_name)
if os.path.isdir(full_path) and item_name.startswith("search_"):
try:
time_val = int(item_name.split('_', 1)[1])
search_times.append(time_val)
except (ValueError, IndexError):
continue
except Exception:
return 1
if not search_times: return 1
return max(search_times) + 1
def search(self, query: str, obs) -> str:
"""
Args:
query: Format like "How to xxxx?", must be a detailed subtask
obs: Current screenshot
"""
raise NotImplementedError("Subclasses must implement the 'search' method")
class VLMSearcherAgent(SearcherAgent):
"""
Starts a new, isolated VM and opens Chrome in advance.
"""
def __init__(self, engine_params: Dict, search_env, grounder_agent: GrounderAgent, platform: str, client_password: str):
SearcherAgent.__init__(self, engine_params=engine_params, platform=platform)
self.grounder_agent = grounder_agent
self.client_password = client_password
self.env = search_env
self.use_thinking = engine_params.get("model", "") in [
"claude-opus-4-20250514",
"claude-sonnet-4-20250514",
"claude-3-7-sonnet-20250219",
"claude-sonnet-4-5-20250929",
]
self.engine = engine_params.get("engine", "google")
# Reuse OSWorld's initialization script to set up Chrome, then perform a search directly with the query; the URL below is a placeholder that is substituted at reset time.
self.task_config = {
"id": "searcher",
"instruction": "searcher",
"config": [
{
"type": "launch",
"parameters": {
"command": [
"google-chrome",
"--remote-debugging-port=1337"
]
}
},
{
"type": "launch",
"parameters": {
"command": [
"socat",
"tcp-listen:9222,fork",
"tcp:localhost:1337"
]
}
},
{
"type": "chrome_open_tabs",
"parameters": {
"urls_to_open": [
"GOOGLE_SEARCH_URL"
]
}
},
{
"type": "activate_window",
"parameters": {
"window_name": "Google Chrome"
}
}
],
"proxy": True
}
self.obs = None
def reset(self, query):
# When the search function is invoked, a new agent is created; the environment is instantiated only upon the first call, but it must be reset on every invocation.
self.tutorial_notes = []
self.tutorial_or_hint = ""
self.system_prompt = PROCEDURAL_MEMORY.construct_vlm_searcher_procedural_memory(
agent_class=type(self)
).replace("CURRENT_OS", self.platform).replace("QUERY", query)
self.searcher_agent = LMMAgent(
engine_params=self.engine_params,
system_prompt=self.system_prompt
)
self.env.start()
# config URL and initialize search environment (google/duckduckgo)
search_url = f"https://www.google.com/search?q=" + urllib.parse.quote_plus(query) if self.engine == "google" else f"https://www.duckduckgo.com/?q=" + urllib.parse.quote_plus(query)
self.task_config["config"][2]["parameters"]["urls_to_open"][0] = search_url
self.env.reset(task_config=self.task_config)
print("[Searcher] sleeping...")
time.sleep(5)
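# Illustrative substitution (query chosen for the example): for
# query = "How to crop an image in GIMP" and engine == "google", the GOOGLE_SEARCH_URL
# placeholder becomes
#   https://www.google.com/search?q=How+to+crop+an+image+in+GIMP
# before the environment is reset with the updated task_config.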
def flush_messages(self):
"""Flush messages based on the model's context limits.
This method ensures that the agent's message history does not exceed the maximum trajectory length.
Side Effects:
- Modifies the messages of generator, reflection, and bon_judge agents to fit within the context limits.
"""
engine_type = self.engine_params.get("engine_type", "")
# Flush strategy for long-context models: keep all text, only keep latest images
if engine_type in ["anthropic", "openai", "gemini"]:
max_images = self.max_trajectory_length
for agent in [self.searcher_agent]:
if agent is None:
continue
# keep latest k images
# keep the first main-agent image (message indices 0 and 1 are never trimmed)
img_count = 0
for i in range(len(agent.messages) - 1, 1, -1):
for j in range(len(agent.messages[i]["content"]) - 1, -1, -1):
if "image" in agent.messages[i]["content"][j].get("type", ""):
img_count += 1
if img_count > max_images:
del agent.messages[i]["content"][j]
# Flush strategy for non-long-context models: drop full turns
else:
# generator msgs are alternating [user, assistant], so 2 per round
if len(self.searcher_agent.messages) > 2 * self.max_trajectory_length + 1:
self.searcher_agent.messages.pop(1)
self.searcher_agent.messages.pop(1)
def assign_screenshot(self, obs):
self.obs = obs
def search(self, query: str, main_obs):
# only create vm when search is called
self.reset(query=query) # reset
search_result_dir = os.path.join(self.result_dir, f"search_{self._get_search_time()}")
os.makedirs(search_result_dir, exist_ok=True)
obs = self.env._get_obs() # Get the initial observation
step_idx = 0
initial_state_text = (
"This screenshot shows the current visual context of the main GUI Agent you are assisting. "
"Use this image to understand the application, the current view, and the overall environment. "
"Your primary goal is to find a tutorial that is highly relevant and well-aligned with this specific context, "
"ensuring the instructions you find are applicable to what the main agent is currently seeing."
)
self.searcher_agent.add_message(
text_content=initial_state_text,
image_content=main_obs["screenshot"],
role="user"
)
execution_history = []
completion_reason = ""
final_answer = ""
while step_idx < self.budget:
# update system_prompt dynamically
tutorial_notes_str = ""
if len(self.tutorial_notes) > 0:
for i, note in enumerate(self.tutorial_notes, 1):
tutorial_notes_str += f"Tutorial Note {i}: {note}\n\n"
if step_idx == self.budget - 1:
# eager mode
self.system_prompt = PROCEDURAL_MEMORY.construct_searcher_eager_mode_procedural_memory(
agent_class=type(self)
).replace("CURRENT_OS", self.platform).replace("QUERY", query)
system_prompt = self.system_prompt.replace("TUTORIAL_PLACEHOLDER", tutorial_notes_str)
self.searcher_agent.add_system_prompt(system_prompt=system_prompt)
# start a new turn
self.assign_screenshot(obs=obs)
generator_message = ""
self.searcher_agent.add_message(
generator_message, image_content=obs["screenshot"], role="user"
)
format_checkers = []
# predict action
plan = call_llm_formatted(
self.searcher_agent,
format_checkers,
temperature=self.engine_params.get("temperture", 0.1),
use_thinking=self.use_thinking,
)
self.searcher_agent.add_message(plan, role="assistant")
execution_history.append(plan)
logger.info("SEARCHER PLAN:\n %s", plan)
plan_code = parse_code_from_string(plan)
try:
assert plan_code, "Plan code should not be empty"
# exec_code e.g. import pyautogui; pyautogui.click(1, 2);
exec_code, coords = create_pyautogui_code(self, plan_code, obs)
except Exception as e:
logger.error(
f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}"
)
exec_code = self.wait(1.333)  # Skip a turn if the code cannot be evaluated
coords = None
self.flush_messages()
# execute action
action = exec_code
logger.info("Step %d: %s", step_idx + 1, action)
# Save screenshot and trajectory information
with open(os.path.join(search_result_dir, f"step_{step_idx + 1}.png"),
"wb") as _f:
_f.write(obs['screenshot'])
if coords is not None and isinstance(coords, list):
draw_coordinates(
image_bytes=obs['screenshot'],
coordinates=coords,
save_path=os.path.join(search_result_dir, f"step_{step_idx + 1}_draw.png")
)
with open(os.path.join(search_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps({
"query": query,
"step_num": step_idx + 1,
"action": action,
"response": {
"plan": plan,
"plan_code": plan_code,
"coordinates": coords
},
"screenshot_file": f"step_{step_idx + 1}.png"
}, ensure_ascii=False))
f.write("\n")
with open(os.path.join(search_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
json.dump({
"query": query,
"step_num": step_idx + 1,
"action": action,
"response": {
"plan": plan,
"plan_code": plan_code,
"coordinates": coords
},
"screenshot_file": f"step_{step_idx + 1}.png"
}, f, indent=4, ensure_ascii=False)
if exec_code in ["DONE", "FAIL"]:
# terminate loop
completion_reason = exec_code
final_answer = self.tutorial_or_hint
break
else:
obs, _, _, _ = self.env.step(action, 5)
step_idx += 1
if completion_reason == "":
completion_reason = "BUDGET_EXHAUSTED"
final_answer = "Sorry, can't get the useful tutorial about the GUI task you provided."
return {
"query": query,
"completion_reason": completion_reason,
"tutorial_notes": self.tutorial_notes,
"execution_history": execution_history,
"steps_executed": step_idx,
"budget": self.budget,
"final_answer": final_answer,
}
@searcher_agent_action
def click(
self,
element_description: str,
num_clicks: int = 1,
button_type: str = "left",
):
"""Click on the element
Args:
element_description:str, a detailed description of which element to click on. This description should be at least a full sentence.
num_clicks:int, number of times to click the element
button_type:str, which mouse button to press; can be "left", "middle", or "right"
"""
x, y = self.grounder_agent.generate_coords(element_description, self.obs)
command = "import pyautogui; "
command += f"""import pyautogui; pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); """
# Return pyautoguicode to click on the element
return (command, [x, y])
@searcher_agent_action
def type(
self,
element_description: Optional[str] = None,
text: str = "",
overwrite: bool = True,
enter: bool = False
):
"""Type text/unicode into a specific element
Args:
element_description:str, a detailed description of which element to enter text in. This description should be at least a full sentence.
text:str, the text to type
overwrite:bool, Default is True; assign it to False if the text should not overwrite the existing text. When True, all existing text in the element is cleared before typing.
enter:bool, Assign it to True if the enter key should be pressed after typing the text, otherwise assign it to False.
"""
commands = (
"import os;"
"import pyautogui;"
"import pyperclip;"
"import subprocess;"
"import time;"
"p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
"p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
"proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
f"subprocess.run(f'echo \"{self.client_password}\" | sudo -S {{proxy_prefix}}apt-get install -y xclip xsel', shell=True, check=True);"
)
click_coords = None
if element_description is not None:
x, y = self.grounder_agent.generate_coords(element_description, self.obs)
click_coords = [x, y]
commands += f"pyautogui.click({x}, {y});"
if overwrite:
commands += (
f"pyautogui.hotkey('ctrl', 'a');"
"pyautogui.press('backspace');"
)
# use paste to input
commands += (
"original_clipboard = pyperclip.paste();"
f"pyperclip.copy({repr(text)});"
"pyautogui.hotkey('ctrl', 'v');"
"pyperclip.copy(original_clipboard);"
)
if enter:
commands += "pyautogui.press('enter');"
if click_coords is not None:
return (commands, click_coords)
else:
return commands
@searcher_agent_action
def scroll(self, element_description: str, clicks: int, shift: bool = False):
"""Scroll the element in the specified direction
Args:
element_description:str, a very detailed description of which element to scroll. This description should be at least a full sentence.
clicks:int, the number of clicks to scroll; can be positive (up) or negative (down).
shift:bool, whether to use shift+scroll for horizontal scrolling
"""
x, y = self.grounder_agent.generate_coords(element_description, self.obs)
if shift:
return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.hscroll({clicks})", [x, y])
else:
return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.vscroll({clicks})", [x, y])
@searcher_agent_action
def hotkey(self, keys: List):
"""Press a hotkey combination (can press a single key as well)
Args:
keys: List, the keys to press in combination in a list format (e.g. ['ctrl', 'c'], ['enter'])
"""
# add quotes around the keys
keys = [f"'{key}'" for key in keys]
return f"import pyautogui; pyautogui.hotkey({', '.join(keys)})"
@searcher_agent_action
def save_to_tutorial_notes(self, text: str):
"""Save high quality and useful information to a long-term knowledge bank for reuse during this search task.
Args:
text:str, the text to save to the tutorial notes
"""
self.tutorial_notes.append(text)
return """WAIT"""
@searcher_agent_action
def wait(self, time: float):
"""Wait for a specified amount of time
Args:
time:float, the amount of time to wait in seconds
"""
return f"""import time; time.sleep({time})"""
@searcher_agent_action
def done(
self,
tutorial: str
):
"""End the current task with a success. Use this when you believe the entire task has been fully completed.
Args:
tutorial:str, A detailed, step-by-step tutorial compiled from the search results to be passed to the main agent.
"""
self.tutorial_or_hint = tutorial
return """DONE"""
@searcher_agent_action
def fail(
self,
hint: str
):
"""End the current task with a failure. Use this when you believe the entire task is impossible to complete.
Args:
hint:str, A hint or reason explaining why the search failed, or what kind of information was missing.
"""
self.tutorial_or_hint = hint
return """FAIL"""

View File

@ -0,0 +1,340 @@
from functools import partial
import logging
from typing import Dict, List, Tuple
from mm_agents.os_symphony.agents.memoryer_agent import ReflectionMemoryAgent
from mm_agents.os_symphony.agents.os_aci import OSACI
from mm_agents.os_symphony.core.module import BaseModule
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import (
call_llm_formatted,
extract_coords_from_action_dict,
parse_action_from_string,
parse_code_from_string,
create_pyautogui_code,
)
from mm_agents.os_symphony.utils.formatters import (
SINGLE_ACTION_FORMATTER,
CODE_VALID_FORMATTER,
)
logger = logging.getLogger("desktopenv.agent")
class Worker(BaseModule):
def __init__(
self,
engine_params_for_orchestrator: Dict,
engine_params_for_memoryer: Dict,
os_aci: OSACI,
platform: str,
client_password: str,
max_trajectory_length: int = 8,
enable_reflection: bool = True,
):
"""
Worker receives the main task and generates actions, without the need for hierarchical planning
Args:
engine_params_for_orchestrator: Dict
Parameters for the orchestrator (worker) agent
engine_params_for_memoryer: Dict
Parameters for the reflection/memory agent
os_aci: OSACI
The grounding agent to use
platform: str
OS platform the agent runs on (darwin, linux, windows)
max_trajectory_length: int
The number of image turns to keep
enable_reflection: bool
Whether to enable reflection
"""
super().__init__(platform=platform)
self.client_password = client_password
self.temperature = engine_params_for_orchestrator.get("temperature", 0.0)
self.tool_config = engine_params_for_orchestrator.get("tool_config", "")
self.use_thinking = engine_params_for_orchestrator.get("model", "") in [
"claude-opus-4-20250514",
"claude-sonnet-4-20250514",
"claude-3-7-sonnet-20250219",
"claude-sonnet-4-5-20250929",
]
self.engine_params_for_orchestrator = engine_params_for_orchestrator
self.engine_params_for_memoryer = engine_params_for_memoryer
self.os_aci: OSACI = os_aci
self.max_trajectory_length = max_trajectory_length if not self.engine_params_for_orchestrator.get("keep_first_image", False) else max_trajectory_length - 1
self.enable_reflection = enable_reflection
self.reset()
def reset(self):
# set_cell_values is Linux-only; the fail option is also unavailable in the other benchmarks
if self.platform in ["windows", "macos"]:
skipped_actions = ["set_cell_values", "fail"]
else:
skipped_actions = []
# Hide code agent action entirely if no env/controller is available
if not getattr(self.os_aci, "env", None) or not getattr(
getattr(self.os_aci, "env", None), "controller", None
):
skipped_actions.append("call_code_agent")
self.orchestrator_sys_prompt = PROCEDURAL_MEMORY.construct_simple_worker_procedural_memory(
agent_class=type(self.os_aci),
skipped_actions=skipped_actions,
tool_config=self.tool_config,
platform=self.platform
).replace("CURRENT_OS", self.platform).replace("CLIENT_PASSWORD", self.client_password)
# Worker contains orchestrator and reflection agent
self.orchestrator_agent = self._create_agent(
engine_params=self.engine_params_for_orchestrator,
system_prompt=self.orchestrator_sys_prompt
)
self.memoryer_agent = ReflectionMemoryAgent(self.engine_params_for_memoryer)
self.instruction = None
self.turn_count = 0
self.worker_history = []
self.coords_history = []
# For loop detection
self.action_dict_history = []
def flush_messages(self):
"""Flush messages based on the model's context limits.
This method ensures that the agent's message history does not exceed the maximum trajectory length.
Side Effects:
- Modifies the messages of generator, reflection, and bon_judge agents to fit within the context limits.
"""
engine_type = self.engine_params_for_orchestrator.get("engine_type", "")
# Flush strategy for long-context models: keep all text, only keep latest images
if engine_type in ["anthropic", "openai", "gemini", "vllm"]:
max_images = self.max_trajectory_length
# for agent in [self.generator_agent, self.reflection_agent]:
for agent in [self.orchestrator_agent]:
if agent is None:
continue
# keep latest k images
img_count = 0
stop_idx = 1 if self.engine_params_for_orchestrator.get("keep_first_image", False) else -1
for i in range(len(agent.messages) - 1, stop_idx, -1):
# for j in range(len(agent.messages[i]["content"])):
for j in range(len(agent.messages[i]["content"]) - 1, -1, -1):
if "image" in agent.messages[i]["content"][j].get("type", ""):
img_count += 1
if img_count > max_images:
del agent.messages[i]["content"][j]
# Flush strategy for non-long-context models: drop full turns
else:
# generator msgs are alternating [user, assistant], so 2 per round
if len(self.orchestrator_agent.messages) > 2 * self.max_trajectory_length + 1:
self.orchestrator_agent.messages.pop(1)
self.orchestrator_agent.messages.pop(1)
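# A minimal standalone sketch of the image-trimming strategy above (the toy message structure
# mirrors the multimodal content lists used by LMMAgent; the function name is illustrative):
def trim_images(messages, max_images, keep_first_image=False):
    """Walk messages from newest to oldest and drop image parts beyond max_images."""
    stop_idx = 1 if keep_first_image else -1
    img_count = 0
    for i in range(len(messages) - 1, stop_idx, -1):
        content = messages[i]["content"]
        for j in range(len(content) - 1, -1, -1):
            if "image" in content[j].get("type", ""):
                img_count += 1
                if img_count > max_images:
                    del content[j]
    return messages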
def generate_next_action(self, instruction: str, obs: Dict, is_last_step: bool) -> Tuple[Dict, List]:
"""
Predict the next action(s) based on the current observation.
"""
print("=" * 30, f"Turn {self.turn_count + 1}", "=" * 30)
print("=" * 10)
print(instruction)
print("=" * 10)
self.os_aci.assign_screenshot(obs)
self.os_aci.set_task_instruction(instruction)
generator_message = (
""
if self.turn_count > 0
else "The initial screen is provided. No action has been taken yet."
)
# Load the task into the system prompt
if is_last_step:
# Eager mode: must decide done / fail
prompt_with_instructions = PROCEDURAL_MEMORY.construct_eager_mode_procedural_memory(agent_class=type(self.os_aci)).replace(
"TASK_DESCRIPTION", instruction
).replace(
"CURRENT_OS", self.platform
)
print(f'Eager Mode Started, Instruction: {prompt_with_instructions}')
self.orchestrator_agent.add_system_prompt(prompt_with_instructions)
generator_message += "Note: 'EAGER MODE' is enabled. You must determine whether the task is done or fail in this step!!!"
else:
tutorials = ""
for idx, t in enumerate(self.os_aci.tutorials, start=1):
tutorials += f"### Tutorial {idx}:\n {t}\n"
prompt_with_instructions = self.orchestrator_sys_prompt.replace(
"TASK_DESCRIPTION", instruction
).replace(
"TUTORIAL_PLACEHOLDER", tutorials
)
self.orchestrator_agent.add_system_prompt(prompt_with_instructions)
# print(self.orchestrator_agent.system_prompt)
### Reflection Part
reflection_info = {}
if self.enable_reflection:
# set instruction to memory agent
self.memoryer_agent.add_instruction(instruction)
reflection = None
# Differentiate the operation mode of last step
last_code_summary = ""
mode = "gui"
if (
hasattr(self.os_aci, "last_code_agent_result")
and self.os_aci.last_code_agent_result is not None
):
# If code agent is called last step, we use its execution result as step behavior.
code_result = self.os_aci.last_code_agent_result
mode = "code"
last_code_summary += f"Subtask Instruction: {code_result['task_instruction']}\nSteps Completed: {code_result['steps_executed']}\nCompletion Reason: {code_result['completion_reason']}\nExec Summary: {code_result['summary']}\n"
if (
hasattr(self.os_aci, "last_search_agent_result")
and self.os_aci.last_search_agent_result is not None
):
mode = "search"
# retrieve reflection!!!
reflection_info = self.memoryer_agent.get_reflection(
cur_obs=obs,
# only use the string after "(next action)" in orchestrator's output
generator_output=parse_action_from_string(self.worker_history[-1]) if self.turn_count != 0 else "",
coordinates=self.coords_history[-1] if self.turn_count != 0 else [],
mode=mode,
code_exec_summary=last_code_summary,
action_dict=self.action_dict_history[-1] if self.turn_count != 0 else {}
)
reflection = reflection_info['reflection']
logger.info(f'[Reflection]: {reflection}')
if reflection:
generator_message += f"REFLECTION: You MUST use this reflection on the latest action:\n{reflection}\n"
else:
generator_message += "You should go on with your plan.\n"
else:
generator_message += "You should go on with your plan.\n"
# Add code agent result from previous step if available (from full task or subtask execution)
if (
hasattr(self.os_aci, "last_code_agent_result")
and self.os_aci.last_code_agent_result is not None
):
code_result = self.os_aci.last_code_agent_result
generator_message += f"\nCODE AGENT RESULT:\n"
generator_message += (
f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
)
generator_message += f"Steps Completed: {code_result['steps_executed']}\n"
generator_message += f"Max Steps: {code_result['budget']}\n"
generator_message += (
f"Completion Reason: {code_result['completion_reason']}\n"
)
generator_message += f"Summary: {code_result['summary']}\n"
generator_message += "\n"
# Reset the code agent result after adding it to context
self.os_aci.last_code_agent_result = None
if (
hasattr(self.os_aci, "last_search_agent_result")
and self.os_aci.last_search_agent_result is not None
):
# Retrieve the result dictionary
search_result = self.os_aci.last_search_agent_result
# Add a clear, distinct header for this section in the prompt
generator_message += f"\nSEARCH AGENT RESULT:\n"
# Add contextual metadata from the search task
generator_message += f"Search Query: {search_result['query']}\n"
generator_message += f"Search Completion Reason: {search_result['completion_reason']}\n"
generator_message += "Search Result: "
# Add the most important part: the tutorial found by the agent.
# This is given a prominent sub-header so the LLM knows to pay close attention.
if search_result["completion_reason"] == "DONE":
generator_message += 'Search completed; the tutorial it found has already been added to your system prompt.\n'
elif search_result["completion_reason"] == "FAIL":
generator_message += f"Search failed; the failure reason or hint is as follows: {search_result['final_answer']}\n"
# CRITICAL: Reset the search agent result after adding it to the context.
# This prevents it from being added to the prompt again in the next turn.
self.os_aci.last_search_agent_result = None
# Finalize the generator message
self.orchestrator_agent.add_message(
generator_message, image_content=obs["screenshot"], role="user", put_text_last=True
)
# Generate the plan and next action
format_checkers = [
SINGLE_ACTION_FORMATTER,
partial(CODE_VALID_FORMATTER, self.tool_config),
]
plan = call_llm_formatted(
self.orchestrator_agent,
format_checkers,
temperature=self.engine_params_for_orchestrator.get("temperature", 0.1),
use_thinking=self.use_thinking,
)
self.worker_history.append(plan)
self.orchestrator_agent.add_message(plan, role="assistant")
logger.info("PLAN:\n %s", plan)
# Extract the next action from the plan
# The plan code at this point, e.g. agent.click('xxxxx', 1)
plan_code = parse_code_from_string(plan)
action_dict, coordinates = None, None
try:
assert plan_code, "Plan code should not be empty"
# exec_code e.g. import pyautogui; pyautogui.click(1, 2);
exec_code, action_dict = create_pyautogui_code(self.os_aci, plan_code, obs)
coordinates = extract_coords_from_action_dict(action_dict)
except Exception as e:
logger.error(
f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}"
)
exec_code, action_dict = self.os_aci.wait(
1.333
) # Skip a turn if the code cannot be evaluated
self.action_dict_history.append(action_dict)
executor_info = {
"refined_instruction": self.instruction,
"plan": plan,
"plan_code": plan_code,
"exec_code": exec_code,
"coordinates": coordinates,
"reflection": reflection_info,
"code_agent_output": (
self.os_aci.last_code_agent_result
if hasattr(self.os_aci, "last_code_agent_result")
and self.os_aci.last_code_agent_result is not None
else None
),
"search_agent_output": (
self.os_aci.last_search_agent_result
if hasattr(self.os_aci, "last_search_agent_result")
and self.os_aci.last_search_agent_result is not None
else None
)
}
self.turn_count += 1
self.coords_history.append(coordinates)
self.flush_messages()
return executor_info, [exec_code]
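# Illustrative plan format (hypothetical content): the orchestrator reply is expected to contain
# a "(next action)" section and a code snippet such as
#   agent.click("the blue 'Save' button in the toolbar", 1)
# parse_code_from_string extracts that call, and create_pyautogui_code turns it into the
# executable pyautogui command string plus the action dict used for reflection and loop detection.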

View File

View File

@ -0,0 +1,480 @@
import copy
import json
import logging
import os
import base64
import backoff
from anthropic import Anthropic
from openai import (
APIConnectionError,
APIError,
AzureOpenAI,
OpenAI,
RateLimitError,
)
logger = logging.getLogger("desktopenv.agents.engine")
class LMMEngine:
pass
class LMMEngineOpenAI(LMMEngine):
def __init__(
self,
base_url=None,
api_key=None,
model=None,
rate_limit=-1,
temperature=None,
organization=None,
**kwargs,
):
assert model is not None, "model must be provided"
self.model = model
self.base_url = base_url
self.api_key = api_key
self.organization = organization
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
self.llm_client = None
self.temperature = temperature # Can force temperature to be the same (in the case of o3 requiring temperature to be 1)
@backoff.on_exception(
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
)
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
if api_key is None:
raise ValueError(
"An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY"
)
organization = self.organization or os.getenv("OPENAI_ORG_ID")
# H-cluster authentication; remove before release!!!
if self.model.lower().startswith("ui") or self.model.lower().startswith("qwen") or self.model.lower().startswith("scale") or self.model.lower().startswith("holo"):
custom_headers = {
"Authorization": "Basic NWFkMzQxMDBlZTA1NWE0YmFlNjYzNzBhNWU2ODNiYWM6NjA3ZGU4MjQ5NjU3YTNiM2JkMDM2ZGM5NmQ0YzBiMmY="
}
else:
custom_headers = {}
if not self.llm_client:
if not self.base_url:
self.llm_client = OpenAI(
api_key=api_key,
organization=organization,
default_headers=custom_headers
)
else:
self.llm_client = OpenAI(
base_url=self.base_url,
api_key=api_key,
organization=organization,
default_headers=custom_headers
)
# print(**kwargs)
payload_size = len(json.dumps(messages)) / 1024 / 1024
logger.info(f"Payload size: {payload_size:.2f} MB")
if payload_size > 30:
logger.info("Payload size exceeds 30MB!!!")
result = self.llm_client.chat.completions.create(
model=self.model,
messages=messages,
# max_completion_tokens=max_new_tokens if max_new_tokens else 4096,
temperature=(
temperature if self.temperature is None else self.temperature
),
**kwargs,
)
usage = result.usage
response = result.choices[0].message.content
return (response, usage)
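# A minimal usage sketch (assumptions: OPENAI_API_KEY is set in the environment and the message
# shape mirrors the content-list format these engines consume; the model name is illustrative):
engine = LMMEngineOpenAI(model="gpt-4o")
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a GUI agent."}]},
    {"role": "user", "content": [{"type": "text", "text": "Say hello."}]},
]
response, usage = engine.generate(messages, temperature=0.0)
print(response, usage.total_tokens)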
class LMMEngineAnthropic(LMMEngine):
def __init__(
self,
base_url=None,
api_key=None,
model=None,
thinking=False,
temperature=None,
**kwargs,
):
assert model is not None, "model must be provided"
self.model = model
self.thinking = thinking
self.api_key = api_key
self.llm_client = None
self.temperature = temperature
@backoff.on_exception(
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
)
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
api_key = self.api_key or os.getenv("ANTHROPIC_API_KEY")
if api_key is None:
raise ValueError(
"An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
)
self.llm_client = Anthropic(api_key=api_key)
# Use the instance temperature if set, otherwise the temperature passed to the call
temp = self.temperature if self.temperature is not None else temperature
if self.thinking:
full_response = self.llm_client.messages.create(
system=messages[0]["content"][0]["text"],
model=self.model,
messages=messages[1:],
max_tokens=8192,
thinking={"type": "enabled", "budget_tokens": 4096},
**kwargs,
)
thoughts = full_response.content[0].thinking
return full_response.content[1].text
return (
self.llm_client.messages.create(
system=messages[0]["content"][0]["text"],
model=self.model,
messages=messages[1:],
max_tokens=max_new_tokens if max_new_tokens else 4096,
temperature=temp,
**kwargs,
)
.content[0]
.text
)
@backoff.on_exception(
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
)
# Compatible with Claude-3.7 Sonnet thinking mode
def generate_with_thinking(
self, messages, temperature=0.0, max_new_tokens=None, **kwargs
):
"""Generate the next message based on previous messages, and keeps the thinking tokens"""
api_key = self.api_key or os.getenv("ANTHROPIC_API_KEY")
if api_key is None:
raise ValueError(
"An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
)
self.llm_client = Anthropic(api_key=api_key)
full_response = self.llm_client.messages.create(
system=messages[0]["content"][0]["text"],
model=self.model,
messages=messages[1:],
max_tokens=8192,
thinking={"type": "enabled", "budget_tokens": 4096},
**kwargs,
)
thoughts = full_response.content[0].thinking
answer = full_response.content[1].text
full_response = (
f"<thoughts>\n{thoughts}\n</thoughts>\n\n<answer>\n{answer}\n</answer>\n"
)
return full_response
class LMMEngineGemini(LMMEngine):
def __init__(
self,
base_url=None,
api_key=None,
model=None,
rate_limit=-1,
temperature=None,
**kwargs,
):
assert model is not None, "model must be provided"
self.model = model
self.base_url = base_url
self.api_key = api_key
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
self.llm_client = None
self.temperature = temperature
@backoff.on_exception(
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
)
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
api_key = self.api_key or os.getenv("GEMINI_API_KEY")
if api_key is None:
raise ValueError(
"An API Key needs to be provided in either the api_key parameter or as an environment variable named GEMINI_API_KEY"
)
base_url = self.base_url or os.getenv("GEMINI_ENDPOINT_URL")
if base_url is None:
raise ValueError(
"An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named GEMINI_ENDPOINT_URL"
)
if not self.llm_client:
self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
# Use the instance temperature if set, otherwise the temperature passed to generate
temp = self.temperature if self.temperature is not None else temperature
return (
self.llm_client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=max_new_tokens if max_new_tokens else 4096,
temperature=temp,
**kwargs,
)
.choices[0]
.message.content
)
class LMMEngineOpenRouter(LMMEngine):
def __init__(
self,
base_url=None,
api_key=None,
model=None,
rate_limit=-1,
temperature=None,
**kwargs,
):
assert model is not None, "model must be provided"
self.model = model
self.base_url = base_url
self.api_key = api_key
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
self.llm_client = None
self.temperature = temperature
@backoff.on_exception(
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
)
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
api_key = self.api_key or os.getenv("OPENROUTER_API_KEY")
if api_key is None:
raise ValueError(
"An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENROUTER_API_KEY"
)
base_url = self.base_url or os.getenv("OPEN_ROUTER_ENDPOINT_URL")
if base_url is None:
raise ValueError(
"An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named OPEN_ROUTER_ENDPOINT_URL"
)
if not self.llm_client:
self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
# Use self.temperature if set, otherwise use the temperature argument
temp = self.temperature if self.temperature is not None else temperature
return (
self.llm_client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=max_new_tokens if max_new_tokens else 4096,
temperature=temp,
**kwargs,
)
.choices[0]
.message.content
)
class LMMEngineAzureOpenAI(LMMEngine):
def __init__(
self,
base_url=None,
api_key=None,
azure_endpoint=None,
model=None,
api_version=None,
rate_limit=-1,
temperature=None,
**kwargs,
):
assert model is not None, "model must be provided"
self.model = model
self.api_version = api_version
self.api_key = api_key
self.azure_endpoint = azure_endpoint
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
self.llm_client = None
self.cost = 0.0
self.temperature = temperature
@backoff.on_exception(
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
)
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
if api_key is None:
raise ValueError(
"An API Key needs to be provided in either the api_key parameter or as an environment variable named AZURE_OPENAI_API_KEY"
)
api_version = self.api_version or os.getenv("OPENAI_API_VERSION")
if api_version is None:
raise ValueError(
"api_version must be provided either as a parameter or as an environment variable named OPENAI_API_VERSION"
)
azure_endpoint = self.azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
if azure_endpoint is None:
raise ValueError(
"An Azure API endpoint needs to be provided in either the azure_endpoint parameter or as an environment variable named AZURE_OPENAI_ENDPOINT"
)
if not self.llm_client:
self.llm_client = AzureOpenAI(
azure_endpoint=azure_endpoint,
api_key=api_key,
api_version=api_version,
)
# Use self.temperature if set, otherwise use the temperature argument
temp = self.temperature if self.temperature is not None else temperature
completion = self.llm_client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=max_new_tokens if max_new_tokens else 4096,
temperature=temp,
**kwargs,
)
total_tokens = completion.usage.total_tokens
self.cost += 0.02 * ((total_tokens + 500) / 1000)
return completion.choices[0].message.content
class LMMEnginevLLM(LMMEngine):
def __init__(
self,
base_url=None,
api_key=None,
model=None,
rate_limit=-1,
temperature=None,
**kwargs,
):
assert model is not None, "model must be provided"
self.model = model
self.api_key = api_key
self.base_url = base_url
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
self.llm_client = None
self.temperature = temperature
@backoff.on_exception(
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
)
def generate(
self,
messages,
temperature=0.0,
top_p=0.8,
repetition_penalty=1.05,
max_new_tokens=4096,
**kwargs,
):
api_key = self.api_key or os.getenv("vLLM_API_KEY")
if api_key is None:
raise ValueError(
"A vLLM API key needs to be provided in either the api_key parameter or as an environment variable named vLLM_API_KEY"
)
base_url = self.base_url or os.getenv("vLLM_ENDPOINT_URL")
if base_url is None:
raise ValueError(
"An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named vLLM_ENDPOINT_URL"
)
if not self.llm_client:
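            # NOTE: the basic-auth credentials below are hardcoded in this file;
            # in a typical deployment they would come from configuration or
            # environment variables rather than source code.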
USERNAME = "5ad34100ee055a4bae66370a5e683bac"
PASSWORD = "607de8249657a3b3bd036dc96d4c0b2f"
auth_string = f"{USERNAME}:{PASSWORD}".encode("utf-8")
basic_auth_encoded = base64.b64encode(auth_string).decode("utf-8")
basic_auth_header = f"Basic {basic_auth_encoded}"
self.llm_client = OpenAI(base_url=base_url, api_key=api_key, default_headers={"Authorization": basic_auth_header},)
# Use self.temperature if set, otherwise use the temperature argument
temp = self.temperature if self.temperature is not None else temperature
completion = self.llm_client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=max_new_tokens if max_new_tokens else 4096,
temperature=temp,
top_p=top_p,
extra_body={"repetition_penalty": repetition_penalty},
)
usage = completion.usage
response = completion.choices[0].message.content
return (response, usage)
class LMMEngineHuggingFace(LMMEngine):
def __init__(self, base_url=None, api_key=None, rate_limit=-1, **kwargs):
self.base_url = base_url
self.api_key = api_key
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
self.llm_client = None
@backoff.on_exception(
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
)
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
api_key = self.api_key or os.getenv("HF_TOKEN")
if api_key is None:
raise ValueError(
"A HuggingFace token needs to be provided in either the api_key parameter or as an environment variable named HF_TOKEN"
)
base_url = self.base_url or os.getenv("HF_ENDPOINT_URL")
if base_url is None:
raise ValueError(
"HuggingFace endpoint must be provided as base_url parameter or as an environment variable named HF_ENDPOINT_URL."
)
if not self.llm_client:
self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
return (
self.llm_client.chat.completions.create(
model="tgi",
messages=messages,
max_tokens=max_new_tokens if max_new_tokens else 4096,
temperature=temperature,
**kwargs,
)
.choices[0]
.message.content
)
class LMMEngineParasail(LMMEngine):
def __init__(
self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
):
assert model is not None, "Parasail model id must be provided"
self.base_url = base_url
self.model = model
self.api_key = api_key
self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
self.llm_client = None
@backoff.on_exception(
backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
)
def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
api_key = self.api_key or os.getenv("PARASAIL_API_KEY")
if api_key is None:
raise ValueError(
"A Parasail API key needs to be provided in either the api_key parameter or as an environment variable named PARASAIL_API_KEY"
)
        base_url = self.base_url or os.getenv("PARASAIL_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "Parasail endpoint must be provided as base_url parameter or as an environment variable named PARASAIL_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(
                base_url=base_url,
                api_key=api_key,
            )
return (
self.llm_client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=max_new_tokens if max_new_tokens else 4096,
temperature=temperature,
**kwargs,
)
.choices[0]
.message.content
)
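
# Illustrative usage of these engine classes (a minimal sketch, not part of the
# original module; the model name and message below are assumptions):
#
#   engine = LMMEngineOpenRouter(model="anthropic/claude-sonnet-4", rate_limit=-1)
#   reply = engine.generate(
#       messages=[{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
#       max_new_tokens=512,
#   )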

View File

@ -0,0 +1,308 @@
import base64
import numpy as np
from mm_agents.os_symphony.core.engine import (
LMMEngineAnthropic,
LMMEngineAzureOpenAI,
LMMEngineHuggingFace,
LMMEngineOpenAI,
LMMEngineOpenRouter,
LMMEngineParasail,
LMMEnginevLLM,
LMMEngineGemini,
)
class LMMAgent:
def __init__(self, engine_params: dict, system_prompt=None, engine=None):
if engine is None:
if engine_params is not None:
engine_type = engine_params.get("engine_type")
if engine_type == "openai":
self.engine = LMMEngineOpenAI(**engine_params)
elif engine_type == "anthropic":
self.engine = LMMEngineAnthropic(**engine_params)
elif engine_type == "azure":
self.engine = LMMEngineAzureOpenAI(**engine_params)
elif engine_type == "vllm":
self.engine = LMMEnginevLLM(**engine_params)
elif engine_type == "huggingface":
self.engine = LMMEngineHuggingFace(**engine_params)
elif engine_type == "gemini":
self.engine = LMMEngineGemini(**engine_params)
elif engine_type == "open_router":
self.engine = LMMEngineOpenRouter(**engine_params)
elif engine_type == "parasail":
self.engine = LMMEngineParasail(**engine_params)
else:
raise ValueError(f"engine_type '{engine_type}' is not supported")
else:
raise ValueError("engine_params must be provided")
else:
self.engine = engine
self.messages = [] # Empty messages
self.agent_name = engine_params.get("agent_name")
if system_prompt:
self.add_system_prompt(system_prompt)
else:
self.add_system_prompt("You are a helpful assistant.")
def encode_image(self, image_content):
        # If image_content is a file path, read and base64-encode the file; otherwise assume raw image bytes.
if isinstance(image_content, str):
with open(image_content, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
else:
return base64.b64encode(image_content).decode("utf-8")
    def reset(self):
self.messages = [
{
"role": "system",
"content": [{"type": "text", "text": self.system_prompt}],
}
]
def add_system_prompt(self, system_prompt):
self.system_prompt = system_prompt
if len(self.messages) > 0:
self.messages[0] = {
"role": "system",
"content": [{"type": "text", "text": self.system_prompt}],
}
else:
self.messages.append(
{
"role": "system",
"content": [{"type": "text", "text": self.system_prompt}],
}
)
def remove_message_at(self, index):
"""Remove a message at a given index"""
if index < len(self.messages):
self.messages.pop(index)
def replace_message_at(
self, index, text_content, image_content=None, image_detail="high"
):
"""Replace a message at a given index"""
if index < len(self.messages):
self.messages[index] = {
"role": self.messages[index]["role"],
"content": [{"type": "text", "text": text_content}],
}
if image_content:
base64_image = self.encode_image(image_content)
self.messages[index]["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
"detail": image_detail,
},
}
)
def add_message(
self,
text_content,
image_content=None,
role=None,
image_detail="high",
put_text_last=True,
):
"""Add a new message to the list of messages"""
# API-style inference from OpenAI and AzureOpenAI
if isinstance(
self.engine,
(
LMMEngineOpenAI,
LMMEngineAzureOpenAI,
LMMEngineHuggingFace,
LMMEngineGemini,
LMMEngineOpenRouter,
LMMEngineParasail,
),
):
# infer role from previous message
if role != "user":
if self.messages[-1]["role"] == "system":
role = "user"
elif self.messages[-1]["role"] == "user":
role = "assistant"
elif self.messages[-1]["role"] == "assistant":
role = "user"
message = {
"role": role,
"content": [{"type": "text", "text": text_content}],
}
if isinstance(image_content, np.ndarray) or image_content:
# Check if image_content is a list or a single image
if isinstance(image_content, list):
# If image_content is a list of images, loop through each image
for image in image_content:
base64_image = self.encode_image(image)
message["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
"detail": image_detail,
},
}
)
else:
# If image_content is a single image, handle it directly
base64_image = self.encode_image(image_content)
message["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image}",
"detail": image_detail,
},
}
)
# Rotate text to be the last message if desired
if put_text_last:
text_content = message["content"].pop(0)
message["content"].append(text_content)
self.messages.append(message)
# For API-style inference from Anthropic
elif isinstance(self.engine, LMMEngineAnthropic):
# infer role from previous message
if role != "user":
if self.messages[-1]["role"] == "system":
role = "user"
elif self.messages[-1]["role"] == "user":
role = "assistant"
elif self.messages[-1]["role"] == "assistant":
role = "user"
message = {
"role": role,
"content": [{"type": "text", "text": text_content}],
}
if image_content:
# Check if image_content is a list or a single image
if isinstance(image_content, list):
# If image_content is a list of images, loop through each image
for image in image_content:
base64_image = self.encode_image(image)
message["content"].append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": base64_image,
},
}
)
else:
# If image_content is a single image, handle it directly
base64_image = self.encode_image(image_content)
message["content"].append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": base64_image,
},
}
)
self.messages.append(message)
# Locally hosted vLLM model inference
elif isinstance(self.engine, LMMEnginevLLM):
# infer role from previous message
if role != "user":
if self.messages[-1]["role"] == "system":
role = "user"
elif self.messages[-1]["role"] == "user":
role = "assistant"
elif self.messages[-1]["role"] == "assistant":
role = "user"
message = {
"role": role,
"content": [{"type": "text", "text": text_content}],
}
if image_content:
# Check if image_content is a list or a single image
if isinstance(image_content, list):
# If image_content is a list of images, loop through each image
for image in image_content:
base64_image = self.encode_image(image)
message["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image;base64,{base64_image}"
},
}
)
else:
# If image_content is a single image, handle it directly
base64_image = self.encode_image(image_content)
message["content"].append(
{
"type": "image_url",
"image_url": {"url": f"data:image;base64,{base64_image}"},
}
)
if put_text_last:
text_content = message["content"].pop(0)
message["content"].append(text_content)
self.messages.append(message)
else:
raise ValueError("engine_type is not supported")
def get_response(
self,
user_message=None,
messages=None,
temperature=0.0,
max_new_tokens=32168,
use_thinking=False,
**kwargs,
):
"""Generate the next response based on previous messages"""
if messages is None:
messages = self.messages
if user_message:
messages.append(
{"role": "user", "content": [{"type": "text", "text": user_message}]}
)
# Regular generation
# if use_thinking:
# return self.engine.generate_with_thinking(
# messages,
# temperature=temperature,
# max_new_tokens=max_new_tokens,
# **kwargs,
# )
return self.engine.generate(
messages,
temperature=temperature,
max_new_tokens=max_new_tokens,
**kwargs,
)
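
# Illustrative usage of LMMAgent (a minimal sketch, not taken from the original
# repository; the engine_params values, image path, and prompt are assumptions):
#
#   agent = LMMAgent(
#       engine_params={"engine_type": "openai", "model": "gpt-4o"},
#       system_prompt="You are a helpful assistant.",
#   )
#   agent.add_message("Describe this screenshot.", image_content="screenshot.png")
#   response = agent.get_response(temperature=0.0, max_new_tokens=1024)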

View File

@ -0,0 +1,17 @@
from typing import Dict, Optional
from mm_agents.os_symphony.core.mllm import LMMAgent
class BaseModule:
def __init__(self, engine_params: Dict = None, platform: str = "Linux"):
self.engine_params = engine_params
self.platform = platform
def _create_agent(
self, system_prompt: str = None, engine_params: Optional[Dict] = None
) -> LMMAgent:
"""Create a new LMMAgent instance"""
agent = LMMAgent(engine_params or self.engine_params)
if system_prompt:
agent.add_system_prompt(system_prompt)
return agent
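
# Illustrative subclass (a minimal sketch; `PlannerModule` and its prompt are
# hypothetical, not part of the original repository):
#
#   class PlannerModule(BaseModule):
#       def plan(self, task: str) -> str:
#           agent = self._create_agent(system_prompt="You are a planning module.")
#           agent.add_message(task)
#           return agent.get_response()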

View File

View File

@ -0,0 +1,995 @@
import inspect
import textwrap
import yaml
class PROCEDURAL_MEMORY:
FORMATTING_FEEDBACK_PROMPT = textwrap.dedent(
"""
Your previous response was not formatted correctly. You must respond again to replace your previous response. Do not make reference to this message while fixing the response. Please address the following issues below to improve the previous response:
FORMATTING_FEEDBACK
"""
)
@staticmethod
def construct_eager_mode_procedural_memory(
agent_class
):
procedural_memory = textwrap.dedent(
f"""
You are an expert in graphical user interfaces. Your budget for this task is now EXHAUSTED.
This is your FINAL opportunity to act. You must make a definitive judgment.
You are responsible for executing the task: `TASK_DESCRIPTION`.
You are working in CURRENT_OS.
# GUIDELINES
## Final Judgment Mode
1. **Analyze the final state**: Carefully examine the current screenshot and your action history.
2. **Make a decision**: Determine if the task has been successfully and fully completed.
3. **Choose one of two actions**: You can ONLY use `agent.done()` or `agent.fail()`. No other actions are permitted.
### END OF GUIDELINES
You are provided with:
1. The final screenshot of the UI.
2. The complete history of your previous interactions.
3. Access to ONLY the following two methods for your final decision:
class Agent:
"""
)
eager_tools = ["done", "fail"]
for tool_name in eager_tools:
attr = getattr(agent_class, tool_name, None)
if not (attr and callable(attr) and hasattr(attr, "is_agent_action")):
raise AttributeError(f"Eager mode requires the method '{tool_name}' to be defined in '{agent_class.__name__}' and decorated with @agent_action.")
signature = inspect.signature(attr)
procedural_memory += textwrap.dedent(f"""
def {tool_name}{signature}:
'''{attr.__doc__}'''
""")
procedural_memory += textwrap.dedent(
"""
Your response must be formatted like this:
(Final State Analysis)
Closely examine the screenshot and your history. Describe whether the final state of the UI confirms that the task `TASK_DESCRIPTION` is complete. Provide your reasoning.
(Final Judgment)
State your final decision in natural language. For example: "The task is complete because the file has been saved and closed." or "The task has failed because the required text is not present."
(Grounded Action)
Translate your final judgment into ONE of the two available commands.
**CRITICAL**: You MUST choose one of the following two actions. No other actions are allowed.
- If the task is fully completed, use `agent.done()`.
- If the task is not completed or has failed, use `agent.fail()`.
Example for success:
```python
agent.done()
```
Example for failure:
```python
agent.fail()
```
"""
)
return procedural_memory.strip()
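    # Illustrative call (a minimal sketch; `WorkerAgent` is a hypothetical class
    # whose `done` and `fail` methods are decorated with @agent_action):
    #
    #   prompt = PROCEDURAL_MEMORY.construct_eager_mode_procedural_memory(WorkerAgent)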
@staticmethod
def construct_simple_worker_procedural_memory(
agent_class,
skipped_actions,
tool_config,
platform = "linux"
):
procedural_memory = textwrap.dedent(
f"""\
You are an expert in graphical user interfaces, web search and Python code. You are responsible for executing the task using the provided actions.
The TASK DESCRIPTION: `TASK_DESCRIPTION`.
The OS you are working in: CURRENT_OS.
# 1. **AGENT WORKFLOW & TOOLS**
            You have at most three tool agents: GUI, Code, and Search. You must choose the correct one for the job. You also have a reflection agent that provides useful feedback at each step; follow its feedback and adjust your plan.
---
"""
)
# Load tool yaml config
try:
with open(tool_config, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
except Exception as e:
raise Exception(f"Tool config isn't loaded successfully, error: {e}")
# has_code_agent = "call_code_agent" in config.get("tools", {}).keys()
# if has_code_agent:
has_search_agent = "call_search_agent" in config.get("tools", {}).keys() and config["tools"]["call_search_agent"].get("enabled", False)
has_code_agent = "call_code_agent" in config.get("tools", {}).keys() and config["tools"]["call_code_agent"].get("enabled", False)
gui_section = textwrap.dedent(
f"""
## 1.1 GUI Agent
* **Use for**: All direct UI interactions (clicking, typing, dragging). Use this for simple file operations, visual checks, and tasks requiring specific application features (e.g., charts, pivot tables, print settings, and **other visual elements**).
"""
)
search_section = textwrap.dedent(
f"""
## 1.2 Search Agent
You have access to a search agent that can browse the web to find tutorials.
* **Use for**: Use the Search Agent **when you are unsure how to perform a GUI-based task**. If you don't know the steps to create a chart, configure a specific setting, or use an unfamiliar feature, use the search agent first.
* **Usage Strategy**:
* **CRITICAL**: Call the search agent with a clear, concise "how-to" query. For example: `agent.call_search_agent("How to create a pivot table in LibreOffice Calc?")`.
* **CRITICAL**: Before searching, evaluate if a tutorial is likely to exist. Well-documented software features always have tutorials. In contrast, tasks with a specific website's unique design (e.g., booking a flight, purchasing an item) typically do not have formal, universal tutorials.
* **Result Interpretation**:
* **DONE**: The Search Agent finds a step-by-step and **complete** tutorial, often starting from the very beginning. This means the returned guide may contain steps you have already completed. It is **your responsibility** to analyze the tutorial in conjunction with your current screen context to determine the correct step to begin with. **Do not blindly follow the tutorial from step 1.**
* **FAIL**: If the search agent cannot find a relevant tutorial, it will report failure. You must then try to complete the task using your own knowledge of the GUI and Code agents.
* **Search Agent Verification**: If the result is DONE, it is highly recommended to follow the tutorial with **GUI operations** in the next several steps to verify the tutorial's validation.
"""
) if has_search_agent else ""
code_section = textwrap.dedent(
f"""
## 1.3 Code Agent
You have access to a code agent that can execute python/bash code in the task environment.
* **Use for**: Complex, non-UI tasks. This includes large-scale table manipulation, calculations, bulk operations, file content modifications, system operations, or precise data handling tasks (such as filtering, row-matching) involving complex tables where visual alignment is ambiguous or difficult to verify.
* **Usage Strategy**:
                * **Subtask**: Use `agent.call_code_agent("specific subtask")` for focused data tasks. Please refer to the argument descriptions of the `call_code_agent` function.
* **When To Use**:
* **Spreadsheet Automation (Strongly Recommended)**: For LibreOffice Calc or Excel tasks, specifically when filling entire rows/columns, performing batch data entry, or running calculations.
* **Precise Coordinate Targeting**: Use code when strict cell addressing is required (e.g., writing specifically to cell D2). The GUI agent often struggles to visually distinguish between adjacent cells or columns in dense grids. Code actions ensure 100% address accuracy.
* **When NOT to Use**: NEVER use the code agent for charts, graphs, **pivot tables**, or visual elements. Always use the GUI for those.
* **Code Agent Verification (MANDATORY)**
* The code agent works in the background. You CANNOT trust its output report alone. Your job is to verify its work via the GUI.
* **Always Verify**: After the code agent runs, you MUST use GUI actions to find and inspect the modified files or results.
* **MANDATORY RESTART**: Files modified by the code agent will not show changes in already-open applications. You **MUST close and reopen the entire application** to verify changes. Reloading the file or page is NOT sufficient.
* **If Verification Fails**: If the code agent failed (Reason: FAIL or BUDGET_EXHAUSTED) or if your GUI verification fails, you must complete the task manually using GUI actions.
* **Infeasible Tasks**: Sometimes the code agent will report the task is impossible to solve. Under this case, if you have verified it's correct, just call `agent.fail()`!
"""
) if has_code_agent else ""
reflection_section = textwrap.dedent(
f"""
## 1.4 Reflection Agent (Handling Feedback)
* **Use for**: The `Reflection` input is your primary source for error correction and guidance. You **MUST** read it first at every step and adjust your plan accordingly.
* **Usage Strategy**:
                * **If `Off-Track` (GUI Operation Error)**: The reflection indicates your last action failed (e.g., a bad click or type). Your next action should usually be to retry that operation with a more specific description (e.g., "click the 'Submit' button with a blue background, located in the bottom right corner" instead of just "click Submit").
* **If `Off-Track` (Lack of Tutorial)**: The reflection indicates you are stuck, looping, or don't know the steps. You are missing information. You'd better call the search agent.
* **If `Off-Track` (Code Error)**: It indicates the code agent fails to finish the task, so you need to recover from potential errors or side effects caused by the failed code execution and continue doing the task by GUI operations.
* **If `Off-Track` (Other Error)**: Carefully read the reflection's explanation and form a new plan to fix the deviation.
* **If `On-Track`**: Continue with your original plan.
* **If `Task Completed` / `Task Infeasible`**: Maybe you need to call `agent.done()` or `agent.fail()`.
"""
)
first_section = gui_section + search_section + code_section + reflection_section
procedural_memory += first_section
if platform == "linux":
procedural_memory += textwrap.dedent(
f"""\
---
# 2. ACTION RULES
## 2.1 Core Execution Constraints
- **Use One Provided Action at a Time**: Execute only one grounded action per turn. Only use the methods provided in the Agent class. Do not invent new methods.
- **No Interaction with User**: You MUST complete the task individually. There is **NO** additional input from someone else.
- **Password**: Your sudo password is "CLIENT_PASSWORD".
- **User**: Your username is "user".
- **Home**: Your home path is "/home/user".
## 2.2 Interaction & Input Guidelines
- **Guideline for Clicks**:
- **VISIBILITY CHECK (CRITICAL)**: You must strictly ONLY click on elements that are **clearly visible** in the current screenshot. Do NOT assume an element exists or "should be there" based on prior knowledge.
- The `element_description` for `agent.click()` must be unambiguous. If similar elements exist, be specific to avoid confusion. Describe the target using its appearance, position, and your purpose.
- **Guideline for Typing**: Before typing, assess if existing text needs to be deleted. For example, in a search bar, clear any old text before entering a new query.
- **Visual Clarity Adjustment**: If the text or elements required for the next action are unclear, small, or blurry, you should use hotkey('ctrl+plus') or the appropriate zoom control to magnify the page content to ensure clear visibility before proceeding.
## 2.3 Efficiency & Tool Usage
- **Efficiency is Key**:
- Prefer `agent.hotkey()` over mouse clicks for shortcuts.
                  - Prefer the software's built-in FEATURES (LibreOffice, etc.) over executing a series of complex steps.
- **Code Usage**: For tasks that are clearly achievable via GUI software, you can take a shortcut and use Code Agent (e.g., using FFMPEG to convert video to GIF, or filling multiple rows in a table); however, for tasks that cannot be accomplished via GUI, do NOT use Code to forcibly complete the task.
                  - You MUST use the Code agent when filling tables (LibreOffice Calc), instead of manually clicking and typing in spreadsheets.
                  - You MUST use the Code agent when modifying VS Code settings JSON files or code files such as Python, to minimize syntax errors!
"""
)
elif platform == "windows":
procedural_memory += textwrap.dedent(
f"""\
---
# 2. ACTION RULES
## 2.1 Core Execution Constraints
- **Use One Provided Action at a Time**: Execute only one grounded action per turn. Only use the methods provided in the Agent class. Do not invent new methods.
- **No Interaction with User**: You MUST complete the task individually. There is **NO** additional input from someone else.
- **User**: Your username is "Docker".
- **Home**: Your home path is "C:\\Users\\Docker"
## 2.2 Interaction & Input Guidelines
- **Guideline for Clicks**:
- **VISIBILITY CHECK (CRITICAL)**: You must strictly ONLY click on elements that are **clearly visible** in the current screenshot. Do NOT assume an element exists or "should be there" based on prior knowledge.
- The `element_description` for `agent.click()` must be unambiguous. If similar elements exist, be specific to avoid confusion. Describe the target using its appearance, position, and your purpose.
- **Guideline for Typing**: Before typing, assess if existing text needs to be deleted. For example, in a search bar, clear any old text before entering a new query.
- **Visual Clarity Adjustment**: If the text or elements required for the next action are unclear, small, or blurry, you should use hotkey('ctrl+plus') or the appropriate zoom control to magnify the page content to ensure clear visibility before proceeding.
## 2.3 Efficiency & Tool Usage
- **Efficiency is Key**:
- Prefer `agent.hotkey()` over mouse clicks for shortcuts.
                  - Prefer the software's built-in FEATURES (LibreOffice, etc.) over executing a series of complex steps.
- **Code Usage**: For tasks that are clearly achievable via GUI software, you can take a shortcut and use Code Agent (e.g., using FFMPEG to convert video to GIF, or filling multiple rows in a table); however, for tasks that cannot be accomplished via GUI, do NOT use Code to forcibly complete the task.
                  - You MUST use the Code agent when filling tables (LibreOffice Calc), instead of manually clicking and typing in spreadsheets.
                  - You MUST use the Code agent when modifying VS Code settings JSON files or code files such as Python, to minimize syntax errors!
"""
)
elif platform == "macos":
procedural_memory += textwrap.dedent(
f"""\
---
# 2. ACTION RULES
## 2.1 Core Execution Constraints
- **Use One Provided Action at a Time**: Execute only one grounded action per turn. Only use the methods provided in the Agent class. Do not invent new methods.
- **No Interaction with User**: You MUST complete the task individually. There is **NO** additional input from someone else.
- **User**: Your username is "pipiwu".
- **Password**: Your password is "1234".
- **Home**: Your home path is "/Users/pipiwu"
## 2.2 Interaction & Input Guidelines
- **Guideline for Clicks**:
- **VISIBILITY CHECK (CRITICAL)**: You must strictly ONLY click on elements that are **clearly visible** in the current screenshot. Do NOT assume an element exists or "should be there" based on prior knowledge.
- The `element_description` for `agent.click()` must be unambiguous. If similar elements exist, be specific to avoid confusion. Describe the target using its appearance, position, and your purpose.
- **Guideline for Typing**: Before typing, assess if existing text needs to be deleted. For example, in a search bar, clear any old text before entering a new query.
- **Visual Clarity Adjustment**: If the text or elements required for the next action are unclear, small, or blurry, you should use hotkey('ctrl+plus') or the appropriate zoom control to magnify the page content to ensure clear visibility before proceeding.
## 2.3 Efficiency & Tool Usage
- **Efficiency is Key**:
- Prefer `agent.hotkey()` over mouse clicks for shortcuts.
                  - Prefer the software's built-in FEATURES (LibreOffice, etc.) over executing a series of complex steps.
                  - You MUST use the Code agent when filling tables (LibreOffice Calc), instead of manually clicking and typing in spreadsheets.
- **Code Usage**: For tasks that are clearly achievable via GUI software, you can take a shortcut and use Code Agent (e.g., using FFMPEG to convert video to GIF, or filling multiple rows in a table); however, for tasks that cannot be accomplished via GUI, do NOT use Code to forcibly complete the task.
"""
)
else:
pass
procedural_memory += textwrap.dedent(
"""
- **Search Usage**: When the overall execution logic appears flawed, or if you are unable to accomplish the task after multiple attempts (indicating a lack of specific know-how), or if the Reflection Agent reports a "Lack of Tutorial" error, invoke the Search Agent to retrieve detailed online tutorials for further guidance.
"""
) if has_search_agent else ""
procedural_memory += textwrap.dedent(
"""
## 2.4 Task Flow & Verification
            - **Task Initial State**: The file you need to operate on is usually already open. Align the screenshot with the task description. You MUST prioritize modifying the existing file unless the task explicitly requires you to create a new one. Avoid creating new files unnecessarily.
- **Default Sheet Names**: If creating a new sheet and no name is specified, use default names (e.g., "Sheet1", "Sheet2").
- **Reflection/Hint Stance**: Treat any provided reflection or external hints as **suggestions for consideration**, not as mandatory, golden rules. Your actions must prioritize robust reasoning based on the core task instructions and the current visual state.
            - **Infeasible**: Use `agent.fail()` if the task is infeasible (e.g., a required file is missing, or the OS/software lacks a feature necessary to complete the task).
- **Completion**: Only use `agent.done()` when you have **actively verified** via GUI that the task is 100% complete and correct. **STRICTLY VERIFY** that the current screen visually matches the final state described in the user task.
- **Error Recovery (Application Missteps)**: If a misoperation occurs in file editing software (e.g., LibreOffice), first attempt recovery using **hotkey('ctrl+z')**. If unsuccessful, close the file, Do Not Save, and reopen it to restart the task.
- You should proactively save the file after completing file modification tasks and verify that the save was successful.
---
# 3. INPUT & OUTPUT FORMAT
You are provided with:
1. A screenshot of the current time step.
2. The history of your previous interactions with the UI.
3. A text reflection generated by a Reflection Agent.
4. Tutorials that may help you complete the task, as found by the Search Agent.
--- TUTORIALS START ---
TUTORIAL_PLACEHOLDER
--- TUTORIALS END ---
5. Access to the following class and methods to interact with the UI. You MUST select only one action to execute at a time.
class Agent:
"""
)
for tool_name, tool_config in config.get('tools', {}).items():
            # Skip tools that are explicitly disabled
if tool_config and tool_config.get('enabled') is False:
continue
if tool_name in skipped_actions:
continue
attr = getattr(agent_class, tool_name, None)
if callable(attr) and hasattr(attr, "is_agent_action"):
# Use inspect to get the full function signature
signature = inspect.signature(attr)
procedural_memory += textwrap.dedent(f"""
def {tool_name}{signature}:
'''{attr.__doc__}'''
""")
procedural_memory += textwrap.dedent(
"""
**Your response should be formatted like this**:
(Previous action verification)
Carefully analyze based on the screenshot if the previous action was successful. If the previous action was not successful, provide a reason for the failure.
(Screenshot Analysis)
Closely examine and describe the current state of the desktop along with the currently open applications.
(Next Action)
Based on the current screenshot and the history of your previous interaction with the UI, decide on the next action in natural language to accomplish the given task.
(Grounded Action)
Translate the next action into code using the provided API methods. Format the code like this:
```python
agent.click("The menu button at the top right of the window", 1, "left")
```
"""
)
return procedural_memory.strip()
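    # Illustrative call (a minimal sketch; the agent class, skipped action list,
    # and YAML path are assumptions, not values from the original repository):
    #
    #   prompt = PROCEDURAL_MEMORY.construct_simple_worker_procedural_memory(
    #       WorkerAgent,
    #       skipped_actions=["call_code_agent"],
    #       tool_config="tools.yaml",
    #       platform="linux",
    #   )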
REWRITE_GUI_INSTRUCTION = textwrap.dedent(
"""
You are an expert instruction refiner. Your task is to transform verbose, conversational user requests for GUI tasks into clear, direct, and unambiguous **high-level commands** that capture the user's ultimate goal.
You will be given both the user's text request and a screenshot of the application's initial state. Your primary goal is to synthesize information from both sources to produce a command that states the final objective with as much specificity and context as possible.
### The Core Distinction: Goal vs. Procedure
This is the most important rule. The rewritten command must describe the **WHAT** (the user's final objective). It must **NOT** describe the **HOW** (the specific sequence of clicks, menu openings, or keyboard shortcuts to achieve that objective).
* **User's Goal:** "I want to change the font for all text boxes to 'Liberation Sans Narrow'."
* **Correct (Goal-Oriented) Command:** "For the presentation `note-taking-strategies.pptx` in LibreOffice Impress, change the font for all text boxes to 'Liberation Sans Narrow'."
* **Incorrect (Procedural) Command:** "Open the Master Slide view, go to Styles, right-click 'Default', select 'Modify', go to the Font tab, choose 'Liberation Sans Narrow', and click OK."
Your output should always be the **Correct (Goal-Oriented) Command**.
### Core Principles:
1. **Focus on the Objective:** The final command must be a statement of the end goal. Eliminate all procedural steps.
2. **Eliminate Conversational Filler:** Remove all polite expressions, greetings, questions, and personal anecdotes (e.g., "Please," "Could you," "I need to," "Thank you").
3. **Enrich with Visual Context:** Analyze the screenshot to add critical context to the goal, making it specific and unambiguous.
* **Identify the Operating Context:** State the application name (`LibreOffice Impress`), file name (`document.docx`), or website (`github.com`) visible in the screenshot.
* **Specify the Target:** If the user says "delete it" and the screenshot shows a file named `report_v2.pdf` is selected, the command should be "Delete the selected file, `report_v2.pdf`."
* **Clarify Ambiguous Parameters:** Use the screenshot to translate vague user intent into specific parameters available in the UI. If the user says "make it cheap" and the UI has a "Sort by: Price - Low to High" option, the command is "Sort the results by 'Price: Low to High'."
4. **Preserve All Essential Details:** Extract and retain every specific detail related to the *goal* itself from the user's text (e.g., file names like `export.jpg`, values like `512 pixels`, font names like `'Liberation Sans Narrow'`).
5. **Use Imperative (Command) Language:** Start the command with a direct action verb that describes the overall goal (e.g., "Change," "Sort," "Search," "Export").
6. **Do Not Invent Unjustified Information:** Do not add details or parameters that cannot be inferred from either the user's text or the screenshot.
### Examples
**Example 1:**
* **Original Request:** "On next Monday, look up a flight from Mumbai to Stockholm."
* **Provided Context:** A screenshot of an airline website showing "Round-trip" selected by default.
* **Rewritten Command:** "Search for a one-way flight from Mumbai to Stockholm for next Monday."
* **Reasoning:** The user's request implies a "one-way" trip. The rewritten command states this as a parameter of the search goal, rather than instructing the AI to "click the one-way button."
**Example 2:**
* **Original Request:** "Help me update my profile."
* **Provided Context:** A screenshot of a user's profile page on `github.com`.
* **Rewritten Command:** "On `github.com`, update the user profile."
* **Reasoning:** The command states the high-level goal and adds the application context from the screenshot. It does not say "Click the 'Edit Profile' button."
**Example 3:**
* **Original Request:** "Find me some cheap headphones."
* **Provided Context:** A screenshot of an e-commerce site's search results page with a "Sort by" dropdown.
* **Rewritten Command:** "Sort the search results by 'Price: Low to High'."
* **Reasoning:** The user's vague intent ("cheap") is translated into a specific, high-level command using the explicit option visible in the UI.
Now, apply these principles to the user requests and screenshots I provide. Your output should **only** be the final, goal-oriented command.
"""
)
##### Reflection Memory Agent Part!!!!!
REFLECTION_SYSTEM_PROMPT = textwrap.dedent(
"""
You are an expert "Memory & Reflection Agent." Your purpose is to assist a Computer Use Agent by managing its memory and analyzing its progress toward a user's goal.
You will perform three tasks:
1. **Extract Knowledge**: Identify and save new, useful information.
        2. **Reflect & Recall**: Provide trajectory feedback and recall saved knowledge when needed.
        3. **Evaluate Milestone**: Determine if the most recent action was a significant "milestone."
**Inputs**:
- user_instruction (Text): The high-level, ultimate goal the agent is trying to achieve (e.g., "Find the phone number and address for 'The French Laundry' and put it in 'contacts.xlsx'").
- history (List of Objects): A sequence of past steps. Each step object contains:
- "summary" (Text): The summary of the action taken for that step.
- "screenshot" (Image, Optional): The screenshot *after* the action. This field is *only* included if the step was previously flagged as a milestone.
- latest_agent_output: (Text) The output from the Computer Use Agent on the last step, containing the Agent's screen analysis, thought process, and action.
- IMPORTANT: This action has been DONE!
- latest_screenshot (Image): The screenshot AFTER executing the action described in the **latest_agent_output**.
- existing_knowledge (Text, Optional): A string containing all previously saved knowledge, which may be empty.
- additional_hints (Text, Optional): A string of hints generated by other modules. **Treat these as strong indicators!**.
---
**Task 1: Knowledge Extraction (Saving New Info)**
Your first task is to analyze the latest_screenshot in the context of the user_instruction to see if any new, useful knowledge has appeared.
- **Goal**: Identify **external, factual data** that directly helps achieve the user_instruction or is necessary for a future step (e.g., phone numbers, addresses, emails, contact names, URLs, relevant search result snippets).
- **Crucial Rules**: What NOT to Extract. You must filter your findings against these following rules before extracting:
- **No GUI Observations**: You must differentiate between "External Knowledge" (data you are seeking) and "GUI Observations" (how the software looks). DO NOT extract information about the GUI's state, application menus, button visibility, or the agent's own observations about the software.
- **No Duplicates**: Check the existing_knowledge input. DO NOT extract any information that is already present. Your goal is to find new information only.
- **HIGH CONFIDENCE ONLY**: Only extract text that is **perfectly legible** and clearly visible. **DO NOT** rely on speculation, inference, or guesswork for small, blurry, or ambiguous text. If you lack complete certainty, you must omit the information.
- Action: If you find **new**, relevant, **external** knowledge, you will prepare it for the knowledge output field.
- Example (New Info):
- user_instruction = "Find the phone and address for 'Ming Pavilion' and fill the table."
- existing_knowledge = "Ming Pavilion Address: Level 8, Pacific Place, Supreme Court Road, Central"
- latest_screenshot shows "Address: Level 8, Pacific Place, Supreme Court Road, Central; Phone: (852) 2820 8580".
- Result: You must extract "Ming Pavilion's Phone: (852) 2820 8580" because it is new.
- Example (Duplicate Info):
- user_instruction = "Find the email of 'Tao Yu'."
- existing_knowledge = "Tao Yu's email: tao.yu.nlp@gmail.com"
- latest_screenshot shows "Contact me: tao.yu.nlp [AT] gmail.com".
- Result: You must extract nothing because it is NOT new.
---
**Task 2: Reflection & Knowledge Recall**
Then, you must generate a reflection on the **entire history and current state (last_agent_output and last_screenshot)** in the context of the user_instruction. Your reflection must be one of the four cases below.
You must check the cases in this order: 1, 2, 3, then 4.
- Case 1. **Off-Track**:
- You must first classify the error into one of the following types. Your reflection for this case **must** start with the error type, followed by a specific explanation.
- **Format**: `The trajectory is not going according to plan. [Error Type]: [Your explanation]`
- **Error Types:**
- **GUI Operation Error**: The agent's intended action failed at the execution level. It usually occurs when `additional_hints` contain "Warning: The last GUI operation is unsuccessful".
                    - *Examples*: the CUA intended to click a non-existent element (hallucination), clicked at the wrong coordinates for an existing element (grounding issue), or made a typing error (e.g., entering new text without clearing the old content, or significant typos).
- *Tip*: Do NOT check the action `agent.locate_cursor()`, since it must be correct.
- **Lack of Tutorial**: The agent's individual GUI operations (clicks, types) are technically correct, but the overall sequence or logic is flawed. The agent seems not to know *how* to accomplish the task.
- *Examples*: The agent is clicking randomly, or appears "stuck" and is stubbornly repeating a fixed set of actions *without* making progress (loop detected).
- **Code Error**: This triggers *after* `call_code_agent` has been used and the CUA is now in a "verification" step (e.g., has opened the file that the Code Agent was supposed to modify). The `latest_screenshot` reveals that the Code Agent's work is incorrect, incomplete, or does not match the `user_instruction`.
- *Examples*: The Code Agent was supposed to add data to a file, but the `latest_screenshot` (showing the opened file) shows the file is still empty. The Code Agent was supposed to perform a calculation, but the GUI verification shows the wrong result.
- **Other Error**: The trajectory is off-track for a reason not covered above. Here are some examples:
- CUA is deviating from the goal,
- CUA is filling in wrong information that conflicts with knowledge,
- Screenshot shows an obvious bug or error (pay attention when editing code or json file)...
- **Explanation Details**:
- Provide a clear explanation for *why* the agent is off-track, referencing `action_history` or `latest_screenshot`. But DON'T give any advice!
- **If Loop Detected**: If you find the agent is repeating actions, you **must** state this clearly in the explanation. (e.g., "...agent appears to be in a non-productive loop by repeating the sequence: [action A, action B, action C].")
- **Caveat**: Do not mistake necessary, mechanical repetition (like filling 10 rows in a spreadsheet) for a negative loop. A loop is repetitive action *without progress*.
- Case 2. **Task Completed**: **You must have high confidence and sufficient evidence that the high-level `user_instruction` has been successfully and completely fulfilled.** You must verify task completion based on the following:
- **Visual Alignment Verification**: **Always verify that the `latest_screenshot` visually and explicitly demonstrates the expected final successful state**. If the action summary suggests the goal is achieved but the **expected screen change is not observed**, the task is **NOT** finished.
- **"Outcome over Action" Rule**: You must strictly distinguish between **Action Execution** (e.g., clicking 'Submit', typing text) and **State Change** (e.g., a 'Success' banner appears, page redirects, file's format changes).
- **CRITICAL**: The agent clicking a correct button is **NOT** evidence of completion. Buttons can fail, be unresponsive, or trigger errors.
- **Requirement**: You must observe the **consequence** of the click in the `latest_screenshot`. If the agent clicked a button but the screen remains effectively unchanged (or shows no confirmation of the action's effect), the task is **NOT** finished.
- Case 3. **Task Infeasible**: You are **highly certain** the task cannot be completed. In this case, tell the agent to choose "fail" action. This may be due to:
- **Factual Errors**: Such as requesting to install a non-existent software version, or the OS/software lacking a feature necessary to complete the task.
- **Missing Prerequisites**: Such as attempting to edit a file that does not exist and cannot be found.
- Case 4. **On-Track**: (If Cases 1, 2, and 3 do not apply) The CUA is going according to plan. Now, you must perform a sub-check to see if Knowledge Recall is needed.
- **Sub-Check (Knowledge Recall)**: Analyze the latest_screenshot and action_history to determine if the agent is now in a position to use previously saved knowledge (from the knowledge input).
- **Triggers for Recall**: The agent has opened the target Excel/spreadsheet, a browser with a search bar, or the action_history clearly shows an intent to "write down" or "fill in" the info.
- **Format**: "You are on track. [Summary of past actions]. [ (Optional) Content from existing_knowledge input]"
Rules for Feedback (Cases 1-4):
- **Your output MUST be based on one of the case options above**.
- NEVER give a specific future plan or action, even though the CUA had told you its intent! Your job is NOT to give suggestions!
- Be very certain for Case 4 (DANGEROUS case).
- Do **not** classify a task as `Infeasible` if the failure is due to the agent's own confusion, random actions, or lack of knowledge on how to proceed. That is **`Case 1 (Lack of Tutorial)`**. `Infeasible` means the task is *externally* impossible (e.g., the feature does not exist in the software), not that the agent lacks the necessary knowledge.
- Pay attention to the latest summary, especially the **screenshot change** part. It may help you analyze the screen.
        - When the CUA has just used `call_search_agent` or `call_code_agent`, simply consider it on-track.
- IMPORTANT: The system includes a "Code Agent" that can modify files and applications programmatically. When you see:
- Files with different content than expected.
- Applications being closed and reopened.
- Documents with fewer lines or modified content.
...these are likely LEGITIMATE results of those agents' work, not errors. Do not classify the trajectory as "off-plan" just because of these programmatic changes.
---
**Task 3: Milestone Evaluation**
After formulating your reflection, you must determine if the latest step qualifies as a "milestone."
1. **What IS a "Milestone"?** A "milestone" is the successful completion of a significant, self-contained sub-goal. It represents a major step forward.
- Examples of Milestones:
- Successfully landing on a key page.
- Successfully completing a multi-step form (e.g., submitting the flight search, adding an item to the cart).
- Successfully downloading a required file.
- Successfully arriving at the final piece of information requested (e.g., the screen now shows the weather in London).
2. **What is NOT a "Milestone"?** Most successful actions are not milestones. They are just small, incremental steps towards a milestone.
- Examples of NON-Milestones: Typing a single character or word into a text field; clicking to open a dropdown menu; selecting a single, simple option (e.g., clicking a checkbox, selecting a date on a calendar unless it's the final action of a form); scrolling the page.
---
**Output Format**: Please format your response as follows below. On (Answer) part, you must output a valid JSON object wrapped by ```json and ```.
(Thought)
[
Your detailed reasoning.
Screenshot Analysis: I will first examine and analyze the whole screen VERY carefully.
        Knowledge Extraction: Did the latest screenshot reveal new, relevant info (like a phone number or address) based on the user instruction? Is that info really new? Check the existing knowledge and decide! If so, what is it?
        Reflection & Recall: I will first review the history and the latest agent output to understand what the agent has done. I will then formulate my reflection based on the rules in the **Task 2** section. But I should NOT give any advice about the next step.
Milestone: Was the last action a significant milestone or just a small step?
]
(Answer)
```json
{
"is_milestone": true / false,
"reflection": "(Fill in the reflection here)",
"knowledge": "(Fill in any newly extracted knowledge from Task 1. If no new knowledge was found in this step, this MUST be an empty string)"
}
```
Here's your input:
"""
)
SUMMARIZE_STEP_SYSTEM_PROMPT = textwrap.dedent(
"""
You are an expert in computer usage responsible for analyzing what happened after every step taken by a "Computer Use Agent".
**Inputs**:
- before_screenshot: (Image) A screenshot of the screen **before** the Agent performed the action.
- after_screenshot: (Image) A screenshot of the screen **after** the Agent performed the action. This is your ONLY source for judging the outcome.
- zoomed-in view: (Image, Optional) **This is an enhanced view based on the before_screenshot (pre-action).**
* **Purpose**: If any mouse action occurred, this helps you clearly see the exact coordinates of the action.
* **CRITICAL WARNING**: This image reflects the state **before** the action. **NEVER** mistake it for the result of the action. Ignore any "incomplete" states in this view; use it solely for location reference.
- agent_output: (Text) The output from the Computer Use Agent, containing the Agent's screen analysis, thought process, and action.
**Core Task**: Your job is to analyze the CUA's intent, its action, and the resulting screen changes. Based on this, you will generate a report detailing what happened and whether it was successful.
**Reasoning Guidelines:**
1. **Analyze Intent vs. Outcome**: First, understand the CUA's thought process and `Grounded Action` from the agent_output. Next, Analyze what agent intended to do for this SINGLE-STEP. (Be careful not to confuse the intention of a single step with the overall intention). Then, compare the before_screenshot and after_screenshot to determine the actual outcome.
2. **Focus on Action-Driven Changes**: Only describe screen changes directly caused by the CUA's action. Ignore irrelevant changes (e.g., the system clock).
3. **Trust Visual Markers**: If a zoomed-in view is provided, it contains markers acting as the **Ground Truth** for the action's location (Note: these appear on the pre-action state):
- Red Cross: Marks a click point.
- Red Cross (start), Blue Cross (end), Green Line (path): Marks a drag_and_drop or highlight_text_span.
4. **Verify Success (Strict Criteria)**: **You must apply strict success criteria to check if there is any GUI operation error.** You must examine the `after_screenshot` very carefully.
* **Check Single-Step**: Your duty is just to give a feedback based on the LATEST step of CUA. NOT the whole task or process.
            * **Substantial Expectation**: Always verify the `after_screenshot` visually. The screen state in the after_screenshot must match the **expected outcome** of the operation, not just the physical feedback of the action.
**Output Fields**:
1. Summary: You need to output a comprehensive summary of the CUA's step. It must include:
- CUA's Thought: What did the agent think?
- CUA's Action: What action did it perform?
- Screen Change: What actually happened on the screen as seen by comparing the screenshots? What didn't change?
2. Evaluation: An assessment of whether the step was successful. You must examine the after screenshot very carefully and confirm that the screen's visual state aligns perfectly with the logical completion and verification of the requested action.
**Additional Tips**:
- Your role is to record history, not to guide the future. Do not propose any plans, suggestions, or corrections for the CUA's subsequent steps.
- **Ambiguity Handling**: For actions such as `highlight_text_span`, `locate_cursor`, or operations involving "Select All", "Underline", etc., where visual changes are subtle or not obvious: if you cannot make a clear visual judgment, **default to evaluating them as 'successful'!!**.
**Output Format**: Please format your response as follows below. On (Answer) part, you must output a valid JSON object wrapped by ```json and ```.
(Thoughts)
[Your detailed reasoning. First, state the CUA's thought process and intended action. Second, analyze the screenshots (using the zoomed-in view to confirm the action **location**, and the after_screenshot to confirm the **result**) to identify all visual changes and what remains the same. Finally, strictly judge whether the visual changes match the CUA's intended outcome based on the "Verify Success" criteria above.]
(Answer)
```json
{
"summary": "A summary of the CUA's step. See the rules above.",
"evaluation": "fail / successful"
}
```
"""
)
PHRASE_TO_WORD_COORDS_PROMPT = textwrap.dedent(
"""
You are an expert in graphical user interfaces. Your task is to process a phrase of text, and identify the most relevant word on the computer screen.
        You are provided with a phrase, a table with all the text on the screen, and a screenshot of the computer screen. You will identify the single word id that is best associated with the provided phrase.
This single word must be displayed on the computer screenshot, and its location on the screen should align with the provided phrase.
Each row in the text table provides 2 pieces of data in the following order. 1st is the unique word id. 2nd is the corresponding word.
To be successful, it is very important to follow all these rules:
1. First, think step by step and generate your reasoning about which word id to click on.
2. Then, output the unique word id. Remember, the word id is the 1st number in each row of the text table.
3. If there are multiple occurrences of the same word, use the surrounding context in the phrase to choose the correct one. Pay very close attention to punctuation and capitalization.
"""
)
@staticmethod
def construct_coder_procedural_memory(platform: str = "linux", client_password: str = ""):
# 1. Define Platform-Specific Context
if platform == "linux":
PLATFORM_SPECIFIC_CONTEXT = textwrap.dedent(
"""\
# 2. Environment & Execution
* **Platform:** Linux
* **User:** "user"
* **Home:** "/home/user"
* **Shell:** Bash
* **Sudo:** Use `echo '{client_password}' | sudo -S [COMMAND]`
* **Packages:** Install missing packages as needed.
* **Ignored Errors:** Ignore "sudo: /etc/sudoers.d is world writable".
* **Note:** Code execution might not be visible on screen immediately. GUI actions (like reopening files) may be needed to see changes.
"""
)
PLATFORM_SPECIFIC_CONTEXT = PLATFORM_SPECIFIC_CONTEXT.format(client_password=client_password)
elif platform == "windows":
PLATFORM_SPECIFIC_CONTEXT = textwrap.dedent(
"""\
# 2. Environment & Execution
* **Platform:** Windows
* **User:** "Docker"
* **Home:** "C:\\Users\\Docker"
* **Shell:** PowerShell
* **Packages:** Install missing packages as needed.
* **Path Separators:** Use backslashes `\\` for file paths.
* **Note:** Code execution might not be visible on screen immediately. GUI actions (like reopening files) may be needed to see changes.
"""
)
elif platform == "macos":
# Placeholder for macOS (Darwin) specific instructions
PLATFORM_SPECIFIC_CONTEXT = textwrap.dedent(
"""\
# 2. Environment & Execution
* **Platform:** MacOS(Darwin)
* **User:** "pipiwu"
* **Password:** "1234"
* **Home:** "/Users/pipiwu"
* **Shell:** Bash
* **Packages:** Install missing packages as needed.
* **Note:** Code execution might not be visible on screen immediately. GUI actions (like reopening files) may be needed to see changes.
* **Note:** You have sudo privileges. It is recommended to use sudo when performing Bash actions.
"""
)
# 2. Define Common Instructions (Universal)
COMMON_INSTRUCTIONS = textwrap.dedent(
"""\
You are a code execution agent. Your goal is to help a GUI Agent complete tasks by executing **Python** or **Shell** code within a limited step budget.
# 1. Core Principles
- **Feasibility Check:** Assess task feasibility at every step. Do not attempt impossible tasks.
- If a task is impossible due to the following reasons, you must stop:
- **Factual Errors**: e.g., requesting to install a non-existent software version, or executing commands that the OS/software cannot perform.
- **Missing Critical Prerequisites**: e.g., attempting to edit a file that does not exist and cannot be found. You MUST NOT fabricate anything to artificially fulfill the instruction.
- In your (Thought) block, **clearly explain WHY** the task is infeasible.
                - In your (Answer) block, return INFEASIBLE.
- **Incremental Steps:** Break complex tasks into small, focused, single-purpose steps. Do not write large, multi-step scripts in one block. Code **does not persist** between steps. Each code block you write MUST be a complete, standalone snippet.
{platform_context}
# 3. Core Workflow:
1. **Find:** Locate the target file. The screenshot context may show which file is currently open and should be modified.
2. **Inspect:** **ALWAYS** read and inspect file contents, data types, and formatting *before* modifying.
3. **Modify:**
* **Priority:** Modify existing open files IN-PLACE (use screenshot context). Only create new files when explicitly required by the task.
* **Strategy:** Perform **COMPLETE OVERWRITES**, not appends. For text files, write the full new content. For .docx/.xlsx, replace all paragraphs/sheets with new content.
* **Libraries:** Use appropriate libraries (e.g., `python-docx`, `openpyxl`).
* **Preservation:** **PRESERVE** all original formatting, headers (column headers and row headers), styles, file names and directory structure unless explicitly told to change them. The document's visual presentation should remain the same.
4. **Verify:** After modifying, inspect the file again to confirm the changes were applied correctly. If verification fails, return to Step 3 and retry the modification.
5. **Result Visualization**: At the final step before completing the task (the step before you return DONE), you MUST print out the contents of any files you modified. Use appropriate commands to display the final state of modified files:
* For text files (Linux/Mac): `cat filename` or `head -n 50 filename`
* For text files (Windows): `Get-Content filename -TotalCount 50` or `type filename`
* For Python files: `cat filename.py` (Linux/Mac) or `type filename.py` (Windows)
* For any other file type: use appropriate viewing commands.
6. **Verification Instructions**: When you complete a task that modifies files, you MUST provide clear verification instructions including specific details about what the GUI agent should check:
* Which files were modified and their expected final state (number of lines, key data points, etc.).
* How to verify the changes are correct.
* Whether the task is complete or if additional GUI actions are needed.
# 4. Response Format:
You MUST respond using exactly this format:
(Thought)
Your step-by-step reasoning about what needs to be done and how to approach the current step. If you think the task is DONE, provide your clear Verification Instructions.
(Answer)
Return EXACTLY ONE of the following options. For all the options, you MUST wrap your answer by ```. The Options are:
For Python code:
```python
your_python_code_here
```
For Bash/PowerShell commands:
```bash
your_shell_commands_here
```
For task completion:
```
DONE
```
For task failure:
```
FAIL
```
For impossible tasks (factual errors or missing prerequisites):
```
INFEASIBLE
```
"""
)
# 3. Combine and Return
CODE_AGENT_PROMPT = COMMON_INSTRUCTIONS.format(platform_context=PLATFORM_SPECIFIC_CONTEXT)
return CODE_AGENT_PROMPT
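# Usage sketch (illustrative, not part of the original source): the platform block is
# formatted first and then substituted into the common template, e.g.
#   prompt = PROCEDURAL_MEMORY.construct_coder_procedural_memory(
#       platform="linux", client_password="osworld-public-evaluation")
# which yields the instructions above with the sudo command pre-filled.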
CODE_SUMMARY_AGENT_PROMPT = textwrap.dedent(
"""\
You are a code execution summarizer. Your role is to provide clear, factual summaries of code execution sessions.
Key responsibilities:
- Summarize the code logic and approach used at each step
- Describe the outputs and results produced by code execution
- Explain the progression of the solution approach
- Use neutral, objective language without making judgments about success or failure
- Focus on what was attempted and what resulted
- Keep summaries concise and well-structured
CRITICAL: Include verification instructions for the GUI agent
- If files were modified, provide specific verification guidance:
* What files were changed and their expected final state
* What the GUI agent should look for when verifying
* How to verify the changes are correct
* Whether the task appears complete or if additional GUI actions are needed
- This helps the GUI agent understand what to expect and verify your work properly
Always maintain a factual, non-judgmental tone.
"""
)
@staticmethod
def construct_vlm_searcher_procedural_memory(
agent_class: type
) -> str:
"""
Dynamically constructs the procedural memory (prompt) for the Searcher Agent.
"""
# The prompt is updated to focus on contextual alignment.
procedural_memory = textwrap.dedent(
f"""
You are a Searcher Agent, a specialized expert in graphical user interfaces. Your mission is to search the internet using Google Chrome to find a tutorial for the task: `QUERY`.
You are working in CURRENT_OS. Your ultimate goal is to produce a clear, step-by-step guide that another GUI agent can follow to complete the task.
# GUIDELINES
## Your Role and Goal
You are a research assistant. You will be given a "how to" query and an initial screenshot showing the current screen of the main agent you are assisting. Your job is to use the Chrome browser to find the best possible tutorial that is well-aligned with the provided visual context.
## Leveraging Initial Context
1. **Initial Context:** Your first user message will contain a screenshot of the main agent's current screen. This is a key piece of information.
2. **Contextual Understanding:** Use this screenshot to understand the main agent's environment (e.g., which application is open, what menu is visible).
3. **Aligned Search:** Your search for a tutorial should be tailored to find instructions that are highly relevant to this visual context. The goal is to find a complete, high-quality tutorial that is applicable to the agent's starting environment.
## Constraints
1. **Strictly use Google Chrome.** You must perform all your actions within the Chrome browser window.
2. **Be Thorough.** Explore different websites and articles to find the most accurate and comprehensive instructions.
3. **Be Cautious.** The information you provide will directly guide another agent. If you are not confident in the accuracy of a step, do not include it.
4. **Always rely on verified tutorials.** Use only tutorials that you have personally found and reviewed, rather than relying solely on your internal knowledge.
## Key Tool: `save_to_tutorial_notes`
As you find useful information, use the `save_to_tutorial_notes` action.
1. **Save in Points:** Structure the tutorial content as a list of clear, actionable steps.
2. **Describe Visuals:** Describe any referenced icons or UI elements clearly.
3. **Record URLs:** Always save the URL of the source page.
## Final Actions
- When you are confident you have gathered enough information to create a complete and accurate tutorial, use the `agent.done()` action. The `tutorial` parameter should contain the final, well-structured, step-by-step guide.
- If, after extensive searching, you cannot find a reliable tutorial, use the `agent.fail()` action. Provide a hint explaining why the search was unsuccessful.
**You are provided with**:
1. A screenshot of the current time step.
2. The history of your previous interactions with the UI.
3. Tutorial notes you have already found.
--- TUTORIAL NOTES START ---
TUTORIAL_PLACEHOLDER
--- TUTORIAL NOTES END ---
4. Access to the following class and methods to interact with the UI. You must only use these actions.
class Agent:
"""
)
for tool_name in dir(agent_class):
if tool_name.startswith("_"):
continue
attr = getattr(agent_class, tool_name)
if callable(attr) and hasattr(attr, "is_searcher_agent_action"):
signature = inspect.signature(attr)
docstring = inspect.getdoc(attr) or "No description available."
procedural_memory += textwrap.dedent(f"""
def {tool_name}{signature}:
'''{docstring}'''
""")
procedural_memory += textwrap.dedent(
"""
# RESPONSE FORMAT
Your response must follow this exact format:
(Previous action verification)
Carefully analyze the screenshot to verify if your last action was successful. If it failed, explain why.
(Screenshot Analysis)
Examine the current state of the Chrome browser. Describe the current webpage, any open tabs, and visible UI elements relevant to your search.
(Next Action)
In natural language, decide the next logical step to find the tutorial. This could be refining your search query, clicking a link, scrolling down, or saving a note.
(Grounded Action)
Translate your "Next Action" into a single line of Python code using the `agent` methods provided above.
```python
agent.type(element_description="the search bar at the top of the Google page", text="how to create a pivot table in excel", enter=True)
```
Note for the grounded action:
1. Only perform one action at a time.
2. You must use only the available methods provided above. Do not invent new methods.
3. Return with `agent.done()` immediately after you have compiled the complete tutorial, or `agent.fail()` if it cannot be completed.
4. Prefer hotkeys (`agent.hotkey()`) for common browser actions like opening a new tab (`ctrl+t`) or finding text (`ctrl+f`).
5. Generate `agent.fail()` if you are exhaustively stuck and believe the task is impossible.
6. Generate `agent.done()` when you believe the task is fully complete and you have a high-quality tutorial.
"""
)
return procedural_memory
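# Sketch (hypothetical helper, added for exposition): only methods carrying the
# `is_searcher_agent_action` marker are injected into the prompt above. A decorator on
# the SearcherAgent tool methods is assumed to set that flag, roughly:
#
#   def searcher_agent_action(func):
#       func.is_searcher_agent_action = True
#       return func
#
#   class SearcherAgent:
#       @searcher_agent_action
#       def done(self, tutorial: str): ...
#
# dir()/inspect above then pick up each decorated method's signature and docstring.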
@staticmethod
def construct_searcher_eager_mode_procedural_memory(
agent_class: type
):
"""
Constructs the procedural memory for a Searcher Agent in "Eager Mode" (final attempt).
This prompt is designed for the scenario where the agent has exhausted its step budget.
It restricts the agent to only two possible actions: `done()` or `fail()`, forcing a final,
decisive judgment based on the information gathered so far.
"""
# 1. Set the specific "last chance" introductory text.
# This combines the urgency of the planner's eager mode with the Searcher's specific mission.
procedural_memory = textwrap.dedent(
f"""
You are a Searcher Agent, a specialized expert in graphical user interfaces. Your operational budget is now EXHAUSTED.
This is your FINAL opportunity to act. You must make a definitive judgment on the task: `QUERY`.
You are working in CURRENT_OS.
# GUIDELINES
## Final Judgment Mode
1. **Analyze Your Notes:** Carefully review all the information you have gathered using `save_to_tutorial_notes`.
2. **Make a Final Decision:** Based on your notes, decide if you have enough high-quality information to construct a complete and reliable step-by-step tutorial.
3. **Choose One of Two Actions:** You can ONLY use `agent.done()` or `agent.fail()`. No other actions are permitted.
- **If you choose `agent.done()`:** You MUST provide the complete, well-structured tutorial in the `tutorial` parameter. Compile all your useful notes into a final guide. Do NOT use `done` unless you are highly confident in the tutorial's accuracy and completeness.
- **If you choose `agent.fail()`:** Use this if you could not find enough information, or if the information you found is contradictory, unreliable, or incomplete. Provide a reason in the `hint` parameter.
**You are provided with**:
1. A screenshot of the current time step.
2. The history of your previous interactions with the UI.
3. Tutorial notes you have already found.
--- TUTORIAL NOTES START ---
TUTORIAL_PLACEHOLDER
--- TUTORIAL NOTES END ---
4. Access to the following class and methods to interact with the UI. You must only use these two actions.
class Agent:
"""
)
# 2. Strictly inject only the 'done' and 'fail' methods.
# This logic is adapted from the planner's eager mode constructor.
eager_tools = ["done", "fail"]
for tool_name in eager_tools:
attr = getattr(agent_class, tool_name, None)
# We check for 'is_searcher_agent_action' to be consistent with the SearcherAgent's decorators.
if attr and callable(attr) and hasattr(attr, "is_searcher_agent_action"):
signature = inspect.signature(attr)
docstring = inspect.getdoc(attr) or "No description available."
procedural_memory += textwrap.dedent(f"""
def {tool_name}{signature}:
'''{docstring}'''
""")
# 3. Provide the specific response format for this final decision.
procedural_memory += textwrap.dedent(
"""
# RESPONSE FORMAT
Your response must follow this exact format:
(Final Analysis and Tutorial Compilation)
Review your collected notes and the final screenshot. State whether you have sufficient information to create a definitive tutorial. Summarize your reasoning.
(Final Decision)
In natural language, declare your final choice. For example: "The search is successful, and I have compiled a complete tutorial." or "The search has failed because no reliable sources were found for this specific software version."
(Grounded Action)
Translate your final decision into a single line of Python code using the `agent` methods provided above.
**Example**:
```python
agent.done(tutorial="xxxx")
```
```python
agent.fail(hint="xxxx")
```
**CRITICAL**: You MUST choose one of the following two actions. No other actions are allowed.
"""
)
return procedural_memory.strip()
@staticmethod
def construct_grounder_procedural_memory(model_name: str):
system_prompt, user_message = None, f"Query:REF_EXPR\nOutput only the coordinate of one point in your response.\n"
if "scalecua" in model_name.lower():
user_message = "REF_EXPR"
system_prompt = textwrap.dedent(
'''
You are an autonomous GUI agent capable of operating on desktops, mobile devices, and web browsers. Your primary function is to analyze screen captures and perform appropriate UI actions to complete assigned tasks.
## Action Space
def click(
x: float | None = None,
y: float | None = None,
clicks: int = 1,
button: str = "left",
) -> None:
"""Clicks on the screen at the specified coordinates. The `x` and `y` parameter specify where the mouse event occurs. If not provided, the current mouse position is used. The `clicks` parameter specifies how many times to click, and the `button` parameter specifies which mouse button to use ('left', 'right', or 'middle')."""
pass
def doubleClick(
x: float | None = None,
y: float | None = None,
button: str = "left",
) -> None:
"""Performs a double click. This is a wrapper function for click(x, y, 2, 'left')."""
pass
def rightClick(x: float | None = None, y: float | None = None) -> None:
"""Performs a right mouse button click. This is a wrapper function for click(x, y, 1, 'right')."""
pass
def moveTo(x: float, y: float) -> None:
"""Move the mouse to the specified coordinates."""
pass
def dragTo(
x: float | None = None, y: float | None = None, button: str = "left"
) -> None:
"""Performs a drag-to action with optional `x` and `y` coordinates and button."""
pass
def swipe(
from_coord: tuple[float, float] | None = None,
to_coord: tuple[float, float] | None = None,
direction: str = "up",
amount: float = 0.5,
) -> None:
"""Performs a swipe action on the screen. The `from_coord` and `to_coord` specify the starting and ending coordinates of the swipe. If `to_coord` is not provided, the `direction` and `amount` parameters are used to determine the swipe direction and distance. The `direction` can be 'up', 'down', 'left', or 'right', and the `amount` specifies how far to swipe relative to the screen size (0 to 1)."""
pass
def long_press(x: float, y: float, duration: int = 1) -> None:
"""Long press on the screen at the specified coordinates. The `duration` specifies how long to hold the press in seconds."""
pass
## Input Specification
- Screenshot of the current screen + task description
## Output Format
<action>
[A set of executable action command]
</action>
## Note
- Avoid action(s) that would lead to invalid states.
- The generated action(s) must exist within the defined action space.
- The generated action(s) should be enclosed within <action></action> tags.'''
)
return system_prompt, user_message
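# Example (illustrative): construct_grounder_procedural_memory("scalecua-2b") returns the
# ScaleCUA system prompt above together with user_message "REF_EXPR"; for any other model
# name it returns (None, "Query:REF_EXPR\nOutput only the coordinate of one point in your response.\n").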


@ -0,0 +1,48 @@
description: "The config of all tools"
tools:
click:
enabled: true
type:
enabled: true
scroll:
enabled: true
drag_and_drop:
enabled: true
highlight_text_span:
enabled: true
locate_cursor:
enabled: true
call_code_agent:
enabled: true
call_search_agent:
enabled: true
hotkey:
enabled: true
hold_and_press:
enabled: true
wait:
enabled: true
done:
enabled: true
fail:
enabled: true
open:
enabled: true


@ -0,0 +1,448 @@
import json
import re
import time
from io import BytesIO
from typing import Tuple, Dict, List, Union
import io
import os
from PIL import Image, ImageDraw
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.process_context import get_current_result_dir
import logging
logger = logging.getLogger("desktopenv.agent")
def create_pyautogui_code(agent, code: str, obs: Dict) -> Tuple[str, dict | None]:
"""
Attempts to evaluate the code into a pyautogui code snippet with grounded actions using the observation screenshot.
Args:
agent (ACI): The grounding agent to use for evaluation.
code (str): The code string to evaluate.
obs (Dict): The current observation containing the screenshot.
Returns:
exec_code (str): The pyautogui code to execute the grounded action.
coordinate (List): The coordinates of the action, a flattened list such as [x1, y1, x2, y2, x3, y3, ...], since a single action may involve more than one coordinate.
Modified by Yang.
Raises:
Exception: If there is an error in evaluating the code.
"""
agent.assign_screenshot(obs) # Necessary for grounding
response = eval(code)
if isinstance(response, Tuple):
return response
elif isinstance(response, str):
return response, None
else:
return "", None
def draw_coordinates(image_bytes: bytes, coordinates: List[Union[int, float]], save_path: str):
"""
Draw coordinates on the given image and save it to a new file.
This function receives an image as a byte stream, a list of coordinates in the format [x1, y1, x2, y2, ...],
and draws a red 'X' at each (x, y) coordinate point. The resulting image is then saved to the specified path.
Args:
- image_bytes (bytes): The raw byte data of the image (e.g., read from a PNG or JPEG file).
- coordinates (List[Union[int, float]]): A flattened list of coordinates, must contain an even number of elements. For example: [x1, y1, x2, y2].
- save_path (str): The path where the new image with markings will be saved.
"""
try:
image = Image.open(io.BytesIO(image_bytes))
image = image.convert("RGB")
except Exception as e:
return
draw = ImageDraw.Draw(image)
cross_size = 15
cross_color = "red"
cross_width = 3
for i in range(0, len(coordinates) - 1, 2):
x, y = coordinates[i], coordinates[i+1]
line1_start = (x - cross_size, y - cross_size)
line1_end = (x + cross_size, y + cross_size)
line2_start = (x + cross_size, y - cross_size)
line2_end = (x - cross_size, y + cross_size)
draw.line([line1_start, line1_end], fill=cross_color, width=cross_width)
draw.line([line2_start, line2_end], fill=cross_color, width=cross_width)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
image.save(save_path)
def parse_action_from_string(string):
'''
Parse all strings following "(next action)", including the phrase "next action" itself. If parsing is not possible, return everything.
'''
marker = "(Next Action)"
start_index = string.find(marker)
if start_index != -1:
return string[start_index:]
else:
return string
def call_llm_safe(
agent, temperature: float = 0.0, use_thinking: bool = False, **kwargs
) -> str:
try:
example_result_dir = get_current_result_dir()
except Exception:
example_result_dir = "logs/tokens"
# Retry if fails
max_retries = 3 # Set the maximum number of retries
attempt = 0
response = ""
while attempt < max_retries:
try:
response = agent.get_response(
temperature=temperature, use_thinking=use_thinking, **kwargs
)
assert response is not None, "Response from agent should not be None"
# print("Response success!")
break # If successful, break out of the loop
except Exception as e:
attempt += 1
print(f"{agent.engine} Attempt {attempt} failed: {e}")
if attempt == max_retries:
print("Max retries reached. Handling failure.")
time.sleep(1.0)
# record token cost
if isinstance(response, tuple):
response, usage = response
agent_name = agent.agent_name
with open(os.path.join(example_result_dir, "token.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps({
"agent_name": agent_name,
"completion_tokens": usage.completion_tokens,
"prompt_tokens": usage.prompt_tokens,
"total_tokens": usage.total_tokens
}))
f.write("\n")
return response if response is not None else ""
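# Illustrative note (added for exposition): each call appends one JSON line to
# <result_dir>/token.jsonl, e.g. a hypothetical entry
#   {"agent_name": "planner", "completion_tokens": 512, "prompt_tokens": 2048, "total_tokens": 2560}
# which can later be summed to estimate per-agent token cost.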
def call_func_safe(
func, **kwargs
) -> str:
# Retry if fails
max_retries = 3 # Set the maximum number of retries
attempt = 0
response = ""
while attempt < max_retries:
try:
response = func(**kwargs)
break
except Exception as e:
attempt += 1
print(f"Attempt {attempt} failed: {e}")
if attempt == max_retries:
print("Max retries reached. Handling failure.")
time.sleep(1.0)
return response if response is not None else ""
def extract_coords_from_action_dict(action_dict: Dict | None) -> List:
coords = []
coords_num = 0
if action_dict:
for k, v in action_dict["args"].items():
if (k == "x" and v) or (k == "y" and v) or (k == "x1" and v) or (k == "x2" and v) or (k == "y1" and v) or (k == "y2" and v):
coords_num += 1
if coords_num == 2:
coords.append(action_dict["args"]["x"])
coords.append(action_dict["args"]["y"])
if coords_num == 4:
coords.append(action_dict["args"]["x1"])
coords.append(action_dict["args"]["y1"])
coords.append(action_dict["args"]["x2"])
coords.append(action_dict["args"]["y2"])
return coords
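def _extract_coords_example() -> List:
    """Illustrative sketch (added for exposition, not in the original file): a grounded
    click action dict yields a flat [x, y] list; a drag carrying x1/y1/x2/y2 would yield
    four values instead."""
    click_action = {"function": "click", "args": {"x": 812, "y": 430, "button": "left"}}  # hypothetical values
    return extract_coords_from_action_dict(click_action)  # expected: [812, 430]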
def call_llm_formatted(generator, format_checkers, **kwargs):
"""
Calls the generator agent's LLM and ensures correct formatting.
Args:
generator (ACI): The generator agent to call.
obs (Dict): The current observation containing the screenshot.
format_checkers (Callable): Functions that take the response and return a tuple of (success, feedback).
**kwargs: Additional keyword arguments for the LLM call.
Returns:
response (str): The formatted response from the generator agent.
"""
max_retries = 3 # Set the maximum number of retries
attempt = 0
response = ""
if kwargs.get("messages") is None:
messages = (
generator.messages.copy()
) # Copy messages to avoid modifying the original
else:
messages = kwargs["messages"]
del kwargs["messages"] # Remove messages from kwargs to avoid passing it twice
while attempt < max_retries:
response = call_llm_safe(generator, messages=messages, **kwargs)
# Prepare feedback messages for incorrect formatting
feedback_msgs = []
for format_checker in format_checkers:
success, feedback = format_checker(response)
if not success:
feedback_msgs.append(feedback)
if not feedback_msgs:
# logger.info(f"Response formatted correctly on attempt {attempt} for {generator.engine.model}")
break
logger.error(
f"Response formatting error on attempt {attempt} for {generator.engine.model}. Response: {response} {', '.join(feedback_msgs)}"
)
messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": response}],
}
)
logger.info(f"Bad response: {response}")
delimiter = "\n- "
formatting_feedback = f"- {delimiter.join(feedback_msgs)}"
messages.append(
{
"role": "user",
"content": [
{
"type": "text",
"text": PROCEDURAL_MEMORY.FORMATTING_FEEDBACK_PROMPT.replace(
"FORMATTING_FEEDBACK", formatting_feedback
),
}
],
}
)
logger.info("Feedback:\n%s", formatting_feedback)
attempt += 1
if attempt == max_retries:
logger.error(
"Max retries reached when formatting response. Handling failure."
)
time.sleep(1.0)
return response
def split_thinking_response(full_response: str) -> Tuple[str, str]:
try:
# Extract thoughts section
thoughts = full_response.split("<thoughts>")[-1].split("</thoughts>")[0].strip()
# Extract answer section
answer = full_response.split("<answer>")[-1].split("</answer>")[0].strip()
return answer, thoughts
except Exception as e:
return full_response, ""
def parse_code_from_string(input_string):
"""Parses a string to extract each line of code enclosed in triple backticks (```)
Args:
input_string (str): The input string containing code snippets.
Returns:
str: The last code snippet found in the input string, or an empty string if no code is found.
"""
input_string = input_string.strip()
# This regular expression will match both ```code``` and ```python code```
# and capture the `code` part. It uses a non-greedy match for the content inside.
pattern = r"```(?:\w+\s+)?(.*?)```"
# print(f'[parse_code_from_string].input_string: {input_string}')
# Find all non-overlapping matches in the string
matches = re.findall(pattern, input_string, re.DOTALL)
if len(matches) == 0:
# return []
return ""
relevant_code = matches[
-1
] # We only care about the last match given it is the grounded action
# print(f'[parse_code_from_string].relevant_code: {relevant_code}')
return relevant_code
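def _parse_code_example() -> str:
    """Illustrative sketch (added for exposition, not in the original file): extracts the
    code inside the last fenced block of a planner response."""
    response = (
        "(Next Action)\nClick the search bar.\n(Grounded Action)\n"
        '```python\nagent.click(element_description="the search bar")\n```'
    )
    return parse_code_from_string(response)  # expected: the agent.click(...) line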
def extract_agent_functions(code):
"""
Extracts all agent function names from the given code.
Args:
code (str): The code string to search.
Returns:
list: A list of strings like ['agent.click', 'agent.type'].
"""
pattern = r"agent\.\w+"
return re.findall(pattern, code)
def compress_image(image_bytes: bytes = None, image: Image = None) -> bytes:
"""Compresses an image represented as bytes.
Compression involves re-encoding the image in WebP format.
Args:
image_bytes (bytes): The image data to compress.
Returns:
bytes: The compressed image data.
"""
if not image:
image = Image.open(BytesIO(image_bytes))
output = BytesIO()
image.save(output, format="WEBP")
compressed_image_bytes = output.getvalue()
return compressed_image_bytes
import math
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
min_pixels = MIN_PIXELS if not min_pixels else min_pixels
max_pixels = MAX_PIXELS if not max_pixels else max_pixels
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
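def _smart_resize_example() -> Tuple[int, int]:
    """Illustrative sketch (added for exposition, not in the original file): for a
    1920x1080 screenshot with the default factor of 28, both sides are rounded to the
    nearest multiple of 28 and the pixel count already lies inside
    [MIN_PIXELS, MAX_PIXELS], so no rescaling pass is applied."""
    return smart_resize(height=1080, width=1920)  # expected: (1092, 1932)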
def enhance_observation(image_data: bytes, coordinates: List, expansion_pixels: int = 400, draw=True) -> Tuple[bytes, int, int, int, int]:
"""
According to the given coordinates, draw markers on the screenshot and crop a "focused" area.
Returns:
Tuple[bytes, int, int, int, int]:
- new_image_data (bytes): Data of the cropped image
- crop_left (int): X-axis offset
- crop_top (int): Y-axis offset
- new_width (int): Width of the cropped image
- new_height (int): Height of the cropped image
"""
image = Image.open(io.BytesIO(image_data)).convert("RGBA")
draw_ctx = ImageDraw.Draw(image)
img_width, img_height = image.size
X_MARKER_SIZE = 40
X_MARKER_WIDTH = 5
def _draw_x(draw_context, center_x, center_y, size=X_MARKER_SIZE, color="red", width=X_MARKER_WIDTH):
half_size = size // 2
draw_context.line((center_x - half_size, center_y - half_size, center_x + half_size, center_y + half_size), fill=color, width=width)
draw_context.line((center_x - half_size, center_y + half_size, center_x + half_size, center_y - half_size), fill=color, width=width)
crop_left, crop_top, crop_right, crop_bottom = 0, 0, img_width, img_height
if len(coordinates) == 2:
x, y = coordinates[0], coordinates[1]
if draw:
_draw_x(draw_ctx, x, y)
crop_left = x - expansion_pixels
crop_top = y - expansion_pixels
crop_right = x + expansion_pixels
crop_bottom = y + expansion_pixels
elif len(coordinates) >= 4:
x1, y1 = coordinates[0], coordinates[1]
x2, y2 = coordinates[2], coordinates[3]
if draw:
_draw_x(draw_ctx, x1, y1, color="red")
_draw_x(draw_ctx, x2, y2, color="blue")
draw_ctx.line((x1, y1, x2, y2), fill="green", width=5)
box_left = min(x1, x2)
box_top = min(y1, y2)
box_right = max(x1, x2)
box_bottom = max(y1, y2)
crop_left = box_left - expansion_pixels
crop_top = box_top - expansion_pixels
crop_right = box_right + expansion_pixels
crop_bottom = box_bottom + expansion_pixels
# check boundary
crop_left = max(0, int(crop_left))
crop_top = max(0, int(crop_top))
crop_right = min(img_width, int(crop_right))
crop_bottom = min(img_height, int(crop_bottom))
crop_box = (crop_left, crop_top, crop_right, crop_bottom)
cropped_image = image.crop(crop_box)
new_width, new_height = cropped_image.size
buffered = io.BytesIO()
cropped_image.save(buffered, format="PNG")
return buffered.getvalue(), crop_left, crop_top, new_width, new_height
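def _enhance_observation_example() -> Tuple[int, int, int, int]:
    """Illustrative sketch (added for exposition, not in the original file): for a single
    point on a 1920x1080 screenshot, the crop is a square of 2 * expansion_pixels around
    it, clamped to the image bounds."""
    blank = Image.new("RGB", (1920, 1080), "white")
    buf = io.BytesIO()
    blank.save(buf, format="PNG")
    _, left, top, w, h = enhance_observation(buf.getvalue(), [960, 540], expansion_pixels=400)
    return left, top, w, h  # expected: (560, 140, 800, 800)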


@ -0,0 +1,106 @@
"""This file contains various formatting checks used to reprompt an agent for correctly formatted responses."""
from typing import List
import json
import yaml
import re
from mm_agents.os_symphony.utils.common_utils import (
extract_agent_functions,
parse_code_from_string,
split_thinking_response,
)
single_action_check = (
lambda response: len(extract_agent_functions(parse_code_from_string(response))) == 1
)
single_action_error_msg = (
"Incorrect code: There must be a single agent action in the code response."
)
SINGLE_ACTION_FORMATTER = lambda response: (
single_action_check(response),
single_action_error_msg,
)
def code_valid_check(tool_config, response):
code = parse_code_from_string(response)
print(f'[code_valid_check] parsed code is: {code}')
# check if the action is pre-defined
with open(tool_config, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
valid_methods = set(config['tools'].keys())
pattern = r"^agent\.(\w+)\(.*\)$"
match = re.match(pattern, code.strip(), re.DOTALL)
if match:
method_name = match.group(1)
print(f'[code_valid_check]: method is {method_name}')
if method_name in valid_methods:
return True
else:
return False
else:
return False
code_valid_error_msg = "Incorrect code: The agent action must be a SINGLE and VALID function and use valid parameters from the docstring list."
CODE_VALID_FORMATTER = lambda tool_config, response: (
code_valid_check(tool_config, response),
code_valid_error_msg,
)
thoughts_answer_tag_check = lambda response: split_thinking_response(response)[1] != ""
thoughts_answer_tag_error_msg = "Incorrect response: The response must contain both <thoughts>...</thoughts> and <answer>...</answer> tags."
THOUGHTS_ANSWER_TAG_FORMATTER = lambda response: (
thoughts_answer_tag_check(response),
thoughts_answer_tag_error_msg,
)
integer_answer_check = (
lambda response: split_thinking_response(response)[0].strip().isdigit()
)
integer_answer_error_msg = (
"Incorrect response: The <answer>...</answer> tag must contain a single integer."
)
INTEGER_ANSWER_FORMATTER = lambda response: (
integer_answer_check(response),
integer_answer_error_msg,
)
def json_answer_check(response: str, required_fields: List[str]) -> bool:
"""
A check function that only returns True/False.
"""
try:
answer_str = parse_code_from_string(response)
if len(answer_str) == 0:
return False
data = json.loads(answer_str)
if not isinstance(data, dict):
return False
if set(required_fields) - set(data.keys()):
return False
return True
except Exception:
return False
json_answer_error_msg = (
"Incorrect response: The (Answer) part must contain a valid JSON object that includes ALL required keys and need to be wrapped by ```json and ```"
)
JSON_ANSWER_FORMATTER = lambda response, required_fields: (
json_answer_check(response, required_fields),
json_answer_error_msg,
)
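def _json_answer_example() -> bool:
    """Illustrative sketch (added for exposition, not in the original file): the (Answer)
    block must carry a fenced JSON object containing every required key."""
    response = 'Reasoning...\n```json\n{"status": "DONE", "summary": "copied the file"}\n```'  # hypothetical keys
    return json_answer_check(response, ["status", "summary"])  # expected: True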


@ -0,0 +1,216 @@
import io
import math
import numpy as np
from typing import List, Tuple, Dict, Any, Optional
from PIL import Image, ImageDraw, ImageFont
from rapidfuzz import fuzz
import logging
from mm_agents.os_symphony.agents.memoryer_agent import StepBehavior
logger = logging.getLogger("desktopenv.loop_detection")
def _are_actions_similar(
action1: Dict[str, Any],
action2: Dict[str, Any],
image_width: int,
image_height: int,
relative_coord_threshold: float,
fuzzy_text_threshold: float,
) -> bool:
"""
[Internal Auxiliary] Determine if two actions are similar based on detailed rules.
Args:
action1: The first action.
action2: The second action.
image_width: The width of the screenshot.
image_height: The height of the screenshot.
relative_coord_threshold: A relative distance threshold for coordinate comparison.
fuzzy_text_threshold: A similarity threshold (0-100) for fuzzy text matching.
Returns:
Return True if the actions are similar, otherwise return False.
"""
# ensure same action
if action1.get("function") != action2.get("function"):
return False
func = action1.get("function")
args1 = action1.get("args", {})
args2 = action2.get("args", {})
diagonal = math.sqrt(image_width**2 + image_height**2)
abs_coord_thresh = relative_coord_threshold * diagonal
def are_coords_close(x1, y1, x2, y2):
if None in [x1, y1, x2, y2]: return False
distance = math.sqrt((x1 - x2)**2 + (y1 - y2)**2)
return distance < abs_coord_thresh
if func == "click":
return (
are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
args1.get("button") == args2.get("button") and
args1.get("clicks") == args2.get("clicks")
)
elif func == "open":
return args1.get("name") == args2.get("name")
elif func == "type":
if args1.get("x") and args1.get("y") and args2.get("x") and args2.get("y"):
return (
are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
args1.get("text") == args2.get("text")
)
else:
return args1.get("text") == args2.get("text")
elif func == "drag":
return (
are_coords_close(args1.get("x1"), args1.get("y1"), args2.get("x1"), args2.get("y1")) and
are_coords_close(args1.get("x2"), args1.get("y2"), args2.get("x2"), args2.get("y2"))
)
elif func == "set_cell_values":
return args1.get("text") == args2.get("text")
elif func == "scroll":
clicks1 = args1.get("clicks", 0)
clicks2 = args2.get("clicks", 0)
if (clicks1 == 0 and clicks2 != 0) or (clicks1 != 0 and clicks2 == 0):
same_direction = False
else:
same_direction = math.copysign(1, clicks1) == math.copysign(1, clicks2)
return (
are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
same_direction and
args1.get("shift") == args2.get("shift")
)
elif func == "key":
return args1.get("keys") == args2.get("keys")
elif func == "wait":
return True
elif func in ["call_code_agent", "call_search_agent"]:
query1 = args1.get("query", "")
query2 = args2.get("query", "")
# use Levenshtein distance to calculate fuzzy similarity
query_similarity = fuzz.token_set_ratio(query1, query2)
# print(f'query_sim: {query_similarity}')
return (
query_similarity >= fuzzy_text_threshold and
args1.get("result") == args2.get("result")
)
else:
return False
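def _click_similarity_example() -> bool:
    """Illustrative sketch (added for exposition, not in the original file): on a
    1920x1080 screen the default relative threshold of 0.02 corresponds to roughly 44 px
    along the diagonal, so two left single-clicks 25 px apart count as the same action."""
    a1 = {"function": "click", "args": {"x": 100, "y": 100, "button": "left", "clicks": 1}}
    a2 = {"function": "click", "args": {"x": 120, "y": 115, "button": "left", "clicks": 1}}
    return _are_actions_similar(a1, a2, 1920, 1080, 0.02, 85.0)  # expected: True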
def _are_steps_similar_optimized(
step1: StepBehavior,
step2: StepBehavior,
idx1: int,
idx2: int,
full_trajectory: List[StepBehavior],
phash_threshold: int,
ssim_threshold: float,
# Parameters required for the action comparison
image_width: int,
image_height: int,
relative_coord_threshold: float,
fuzzy_text_threshold: float,
) -> bool:
"""
[Internal Auxiliary] Use pre-calculated data to quickly determine whether two steps are similar.
"""
if step1.phash is None or step2.phash is None:
return False
if (step1.phash - step2.phash) > phash_threshold:
return False
later_step_idx = max(idx1, idx2)
earlier_step_idx = min(idx1, idx2)
ssim_score = full_trajectory[later_step_idx].ssim_list[earlier_step_idx]
if ssim_score < ssim_threshold:
return False
if not _are_actions_similar(
step1.action_dict, step2.action_dict,
image_width, image_height, relative_coord_threshold, fuzzy_text_threshold
):
return False
return True
def detect_loop(
full_trajectory: List[StepBehavior],
image_width: int = 1920,
image_height: int = 1080,
N: int = 3,
phash_threshold: int = 1,
ssim_threshold: float = 0.99,
relative_coord_threshold: float = 0.02,
fuzzy_text_threshold: float = 85.0,
) -> Tuple[bool, Optional[Dict[str, List[int]]]]:
"""
Efficiently detect the presence of looping patterns based on precomputed data.
Args:
full_trajectory (List[StepBehavior]): Full history including the current step.
image_width (int): Width of the screenshot.
image_height (int): Height of the screenshot.
N (int): Number of steps in the candidate loop (sequence length).
phash_threshold (int): Hamming distance threshold for pHash similarity. Recommended: 0-2.
ssim_threshold (float): SSIM similarity threshold for image comparison. Recommended: 0.95-0.99.
relative_coord_threshold (float): Relative threshold for coordinate similarity. Recommended: 0.01-0.05.
fuzzy_text_threshold (float): Fuzzy text matching similarity threshold (0-100) for agent queries.
Returns:
A tuple (is_loop_detected, loop_info):
- is_loop_detected (bool): Whether a loop is detected.
- loop_info (Dict | None): If a loop is detected, contains the indices of the two matching sequences.
"""
L = len(full_trajectory)
if not isinstance(N, int) or N <= 0 or L < 2 * N:
return False, None
max_start_index = L - 2 * N
for i in range(max_start_index, -1, -1):
is_potential_match = True
for j in range(N):
idx_prev = i + j
idx_curr = (L - N) + j
step_prev = full_trajectory[idx_prev]
step_curr = full_trajectory[idx_curr]
if not _are_steps_similar_optimized(
step_prev, step_curr, idx_prev, idx_curr, full_trajectory,
phash_threshold, ssim_threshold,
image_width, image_height, relative_coord_threshold, fuzzy_text_threshold
):
is_potential_match = False
break
if is_potential_match:
previous_sequence_indices = list(range(i, i + N))
loop_info = {
"match_sequence_indices": previous_sequence_indices
}
return True, loop_info
return False, None
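# Illustrative note (added for exposition): with N=3 and a 10-step trajectory, the
# candidate loop is the last window [7, 8, 9]; detect_loop slides an earlier 3-step
# window from i=4 down to i=0 and reports a loop as soon as every pair of aligned steps
# passes the pHash, SSIM and action-similarity checks above.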


@ -0,0 +1,30 @@
# process_context.py
# This module provides an independent context storage for each process.
from multiprocessing import current_process
# We will store process-specific contexts here.
# Since each process has its own separate memory space, when accessing this variable,
# each process accesses its own copy, without conflicting with others.
_context_storage = {}
def set_context(key, value):
"""Set a value in the context of the current process."""
_context_storage[key] = value
# print(f"[{current_process().name}] Set context: {key} = {value}") # For debugging
def get_context(key, default=None):
"""Retrieve a value from the context of the current process."""
value = _context_storage.get(key, default)
# print(f"[{current_process().name}] Get context: {key} -> {value}") # For debugging
if value is None and default is None:
raise NameError(f"'{key}' not found in the current process context. Ensure it is set at the process entry point.")
return value
# For convenience, we can create a specialized getter for result_dir
def get_current_result_dir():
"""Get the result_dir specific to the current process."""
return get_context('current_result_dir')
def set_current_result_dir(example_result_dir):
set_context("current_result_dir", example_result_dir)
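def _process_context_example() -> str:
    """Illustrative sketch (added for exposition, not in the original file): set the
    result dir once at the process entry point, then recover it anywhere in the same
    process without threading it through every call."""
    set_current_result_dir("./results/example-task")  # hypothetical path
    return get_current_result_dir()  # expected: "./results/example-task"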


@ -6,6 +6,9 @@ import os
from io import BytesIO
from typing import Dict, List, Tuple
from http import HTTPStatus
import dashscope
from dashscope import MultiModalConversation
import backoff
import openai
from PIL import Image
@ -40,7 +43,7 @@ def process_image(image_bytes):
height=height,
width=width,
factor=32,
max_pixels=16 * 16 * 4 * 1280,
max_pixels=16 * 16 * 4 * 12800,
)
image = image.resize((resized_width, resized_height))
@ -58,7 +61,7 @@ class Qwen3VLAgent:
self,
platform: str = "ubuntu",
model: str = "qwen3-vl",
max_tokens: int = 1500,
max_tokens: int = 32768,
top_p: float = 0.9,
temperature: float = 0.0,
action_space: str = "pyautogui",
@ -66,6 +69,9 @@ class Qwen3VLAgent:
history_n: int = 4,
add_thought_prefix: bool = False,
coordinate_type: str = "relative",
api_backend: str = "dashscope", # "openai" or "dashscope"
enable_thinking: bool = False, # Enable thinking mode for DashScope
thinking_budget: int = 32768, # Token budget for reasoning
):
self.platform = platform
self.model = model
@ -77,9 +83,13 @@ class Qwen3VLAgent:
self.history_n = history_n
self.add_thought_prefix = add_thought_prefix
self.coordinate_type = coordinate_type
self.api_backend = api_backend
self.enable_thinking = enable_thinking
self.thinking_budget = thinking_budget
assert action_space in ["pyautogui"], "Invalid action space"
assert observation_type in ["screenshot"], "Invalid observation type"
assert api_backend in ["openai", "dashscope"], "Invalid API backend, must be 'openai' or 'dashscope'"
self.thoughts = []
self.actions = []
@ -527,6 +537,70 @@ Previous actions:
return low_level_instruction, pyautogui_code
@staticmethod
def _to_dashscope_messages(messages):
"""
Convert messages built for OpenAI compat into DashScope MultiModalConversation format.
- "text" part -> {"text": "..."}
- "image_url" -> {"image": "<url-or-data-uri>"}
- "video_url" -> {"video": "<url-or-data-uri>"}
"""
ds_msgs = []
for m in messages:
role = m.get("role", "")
parts = m.get("content", [])
ds_content = []
for p in parts:
ptype = p.get("type")
if ptype == "text":
ds_content.append({"text": p.get("text", "")})
elif ptype == "image_url":
url = (p.get("image_url") or {}).get("url", "")
# DashScope accepts http(s), file://, or data:image/*; keep as-is
ds_content.append({"image": url})
elif ptype == "video_url":
url = (p.get("video_url") or {}).get("url", "")
ds_content.append({"video": url})
else:
# If you ever pass raw assistant strings (no parts), tolerate it
if isinstance(p, str):
ds_content.append({"text": p})
# Also tolerate plain-string content (rare)
if not ds_content and isinstance(m.get("content"), str):
ds_content = [{"text": m["content"]}]
ds_msgs.append({"role": role, "content": ds_content})
return ds_msgs
@staticmethod
def _extract_text_from_dashscope_response(resp):
"""Join all 'text' parts from the first choice, including reasoning if present."""
if hasattr(resp, "output"):
out = resp.output
else:
out = resp.get("output") if isinstance(resp, dict) else None
if not out:
return None
choices = getattr(out, "choices", None) if not isinstance(out, dict) else out.get("choices")
if not choices:
return None
msg = getattr(choices[0], "message", None) if not isinstance(choices[0], dict) else choices[0].get("message")
if not msg:
return None
content = getattr(msg, "content", None) if not isinstance(msg, dict) else msg.get("content", [])
if not content:
return None
# Extract reasoning content if present (for thinking models)
reasoning_content = getattr(msg, "reasoning_content", None) if not isinstance(msg, dict) else msg.get("reasoning_content", None)
content_text = "".join(part.get("text", "") for part in content if isinstance(part, dict) and "text" in part)
# Format with thinking tags if reasoning exists
if reasoning_content is not None:
return f"<think>\n{reasoning_content}\n</think>\n\n{content_text}"
else:
return content_text
@backoff.on_exception(
backoff.constant,
(
@ -545,25 +619,93 @@ Previous actions:
def call_llm(self, payload, model):
messages = payload["messages"]
base_url = "https://poc-dashscope.aliyuncs.com/compatible-mode/v1"
api_key = "sk-123"
if self.api_backend == "openai":
return self._call_llm_openai(messages, model)
elif self.api_backend == "dashscope":
return self._call_llm_dashscope(messages, model)
else:
raise ValueError(f"Unknown API backend: {self.api_backend}")
def _call_llm_openai(self, messages, model):
"""Call LLM using OpenAI SDK (compatible with OpenAI-compatible endpoints)."""
base_url = os.environ.get("OPENAI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
api_key = os.environ.get("OPENAI_API_KEY", "sk-123")
client = openai.OpenAI(base_url=base_url, api_key=api_key)
for _ in range(MAX_RETRY_TIMES):
logger.info("Generating content with Qwen model: %s", model)
for attempt in range(1, MAX_RETRY_TIMES + 1):
logger.info(f"[OpenAI] Generating content with model: {model} (attempt {attempt}/{MAX_RETRY_TIMES})")
try:
response = client.chat.completions.create(
model=model,
messages=messages,
max_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p,
# temperature=self.temperature,
# top_p=self.top_p,
)
return response.choices[0].message.content
except Exception as e:
logger.error(f"Error calling Qwen model: {e}")
time.sleep(5)
continue
logger.error(f"[OpenAI] Error calling model: {e}")
if attempt < MAX_RETRY_TIMES:
time.sleep(5)
continue
break
return ""
def _call_llm_dashscope(self, messages, model):
"""Call LLM using DashScope SDK."""
dashscope.base_http_api_url = os.environ.get("DASHSCOPE_BASE_URL", "https://dashscope.aliyuncs.com/api/v1")
dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY", "sk-123")
# Convert message schema
ds_messages = self._to_dashscope_messages(messages)
# Retry loop
last_err = None
for attempt in range(1, MAX_RETRY_TIMES + 1):
thinking_status = f" (thinking={self.enable_thinking})" if self.enable_thinking else ""
logger.info(f"[DashScope] Generating content with model: {model}, thinking_status: {thinking_status} (attempt {attempt}/{MAX_RETRY_TIMES})")
try:
# Build API call parameters
call_params = {
"model": model,
"messages": ds_messages,
"max_tokens": self.max_tokens,
# "temperature": self.temperature,
# "top_p": self.top_p,
"vl_high_resolution_images": True,
}
# Add thinking parameters if enabled
if self.enable_thinking:
call_params["enable_thinking"] = True
call_params["thinking_budget"] = self.thinking_budget
resp = MultiModalConversation.call(**call_params)
if getattr(resp, "status_code", None) not in (None, HTTPStatus.OK):
code = getattr(resp, "code", "")
msg = getattr(resp, "message", "")
reqid = getattr(resp, "request_id", "")
logger.warning(f"[DashScope] non-OK response (id={reqid}): {code} {msg}")
last_err = RuntimeError(f"DashScope status {resp.status_code}: {code} {msg}")
time.sleep(1.5 * attempt)
continue
text = self._extract_text_from_dashscope_response(resp)
if not text:
raise ValueError("DashScope response has no text content")
return text
except Exception as e:
last_err = e
logger.error(f"[DashScope] call failed: {e}")
if attempt < MAX_RETRY_TIMES:
time.sleep(1.5 * attempt)
continue
break
if last_err:
raise last_err
return ""
def reset(self, _logger=None):

737
mm_agents/seed_agent.py Normal file

File diff suppressed because one or more lines are too long


@ -2,13 +2,13 @@
# Do not write any secret keys or sensitive information here.
# Monitor configuration
TASK_CONFIG_PATH=../evaluation_examples/test_all.json
TASK_CONFIG_PATH=../evaluation_examples/test_50_random_proportional.json
EXAMPLES_BASE_PATH=../evaluation_examples/examples
RESULTS_BASE_PATH=../results
# ACTION_SPACE=pyautogui
# OBSERVATION_TYPE=screenshot
# MODEL_NAME=computer-use-preview
# MAX_STEPS=150
FLASK_PORT=80
RESULTS_BASE_PATH=../results_hosted_gbox_50
ACTION_SPACE=pyautogui
OBSERVATION_TYPE=screenshot
MODEL_NAME=us.anthropic.claude-sonnet-4-5-20250929-v1:0
MAX_STEPS=15
FLASK_PORT=8080
FLASK_HOST=0.0.0.0
FLASK_DEBUG=false


@ -69,4 +69,6 @@ alibabacloud_ecs20140526
alibabacloud_tea_openapi
alibabacloud_tea_util
json_minify
json_repair
json_repair
volcengine-python-sdk[ark]
ui-tars>=0.4.2.2

2
run.py

@ -218,7 +218,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
f.write("\n")
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(


@ -457,7 +457,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
f.write("\n")
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):


@ -485,7 +485,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
f.write("\n")
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):

18
run_dart_gui.sh Normal file

@ -0,0 +1,18 @@
# export HF_ENDPOINT=https://hf-mirror.com
python run_multienv_dart_gui.py \
--dart_base_url http://0.0.0.0:6006/v1 \
--provider_name docker \
--test_all_meta_path evaluation_examples/test_nogdrive.json \
--path_to_vm docker_vm_data/Ubuntu.qcow2 \
--headless \
--max_steps 30 \
--domain all \
--num_envs 2 \
--log_level INFO \
--temperature 1.0 \
--save_complete_trajectory \
--use_enhanced_runner \
--model dart-gui \
--model_type qwen25vl \
--infer_mode dart_mode \
--result_dir ./result_multi_apps_pengxiang_transformers12 | tee run_20251103_multi_apps_pengxiang_transformers12.log

542
run_multienv_agi.py Normal file

@ -0,0 +1,542 @@
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.agi_agent import AGIAgent
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# import wandb
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ #
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15)
# agent config
parser.add_argument("--max_trajectory_length", type=int, default=3)
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# example config
parser.add_argument("--domain", type=str, nargs='+', default=["all"],
help="Domain(s) to run. Use 'all' for all domains, or specify one or more domain names")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO', help="Set the logging level")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="osworld-public-evaluation", help="Client password"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
args = parser.parse_args()
return args
args = config() # Get command line arguments first
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
return all_tasks
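def _distribute_tasks_example() -> List[tuple]:
    """Illustrative sketch (added for exposition, not in the original file): the nested
    meta mapping is flattened into (domain, example_id) pairs that worker processes pull
    from the shared task queue."""
    meta = {"chrome": ["task-a", "task-b"], "gimp": ["task-c"]}  # hypothetical ids
    return distribute_tasks(meta)
    # expected: [("chrome", "task-a"), ("chrome", "task-b"), ("gimp", "task-c")]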
def process_signal_handler(signum, frame, env_idx):
"""Signal handler for child processes to gracefully shut down their environments."""
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
# Get the active_environments from the caller's frame
local_vars = frame.f_locals
active_environments = local_vars.get('active_environments', [])
# Close environment in the current process context
for env in active_environments:
if env is not None:
try:
logger.info(f"Process {env_idx + 1} closing environment...")
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
try:
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
agent = AGIAgent(
env=env,
# Contact the authors for access to a private deployment endpoint.
server_url="https://your-private-agi-endpoint",
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
client_password=args.client_password,
provider_name=args.provider_name,
screen_width=args.screen_width,
screen_height=args.screen_height
)
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
"agi-0",
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example_agi(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
global is_terminating, active_environments, processes
# Avoid duplicate handling
if is_terminating:
return
is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
# Close all registered environments in the main process
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
# Send termination signal to all child processes first
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
# Allow a short time for processes to handle their own cleanup
time.sleep(1)
# Forcefully terminate any processes that didn't exit
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = []
for i in range(num_envs):
p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
try:
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except Exception:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
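# Illustrative invocation (not part of the original source): the script name and flag
# spellings below are assumptions based on this script's config() parser, which is
# defined earlier in the file and on the sibling run_multienv_* runners.
#
#   python run_multienv_agi.py \
#       --headless \
#       --provider_name aws \
#       --observation_type screenshot \
#       --num_envs 4 \
#       --max_steps 50 \
#       --result_dir ./results
#
# Results are written under <result_dir>/<action_space>/<observation_type>/agi-0/<domain>/<example_id>/,
# which is the layout get_unfinished() and get_result() scan for result.txt files.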
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Register signal handlers for graceful termination
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
try:
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
"agi-0",
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
# Handle multiple domains
if "all" not in args.domain:
# Filter test_all_meta to only include specified domains
filtered_meta = {}
for domain in args.domain:
if domain in test_all_meta:
filtered_meta[domain] = test_all_meta[domain]
else:
logger.warning(f"Domain '{domain}' not found in test_all_meta")
test_all_meta = filtered_meta
test_file_list = get_unfinished(
args.action_space,
"agi-0",
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
"agi-0",
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
# Signal handler will take care of cleanup
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
# Also trigger cleanup for unhandled exceptions
signal_handler(signal.SIGTERM, None)
finally:
# Final cleanup in case any environments or processes remain
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info(f"Closing environment in final cleanup...")
env.close()
logger.info(f"Environment closed successfully in final cleanup")
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
# First try gentle termination
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Terminating process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
# Wait a moment for processes to terminate
time.sleep(1)
# Then force kill if needed
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Force killing process {p.name}...")
os.kill(p.pid, signal.SIGKILL)
logger.info(f"Process {p.name} force killed")
except Exception as e:
logger.error(f"Error force killing process: {e}")


@@ -13,6 +13,7 @@ import time
from typing import List
from multiprocessing import Process, Manager, current_process
import lib_run_single
from lib_results_logger import log_task_error
from desktop_env.desktop_env import DesktopEnv
from mm_agents.anthropic import AnthropicAgent
@@ -67,17 +68,27 @@ def config() -> argparse.Namespace:
)
# lm config
parser.add_argument("--model", type=str, default="claude-4-sonnet-20250514")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--model", type=str, default="")
parser.add_argument("--temperature", type=float, default=None)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--max_tokens", type=int, default=3000)
parser.add_argument("--stop_token", type=str, default=None)
# thinking mode config
parser.add_argument("--no-thinking", action="store_true",
help="Disable thinking mode (no scratchpad)")
parser.add_argument("--use-isp", action="store_true",
help="Use interleaved scratchpad (ISP) mode")
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
parser.add_argument(
"--specific_task_id", type=str, default=None,
help="Run only a specific task ID (overrides domain filtering)"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
@@ -95,6 +106,37 @@ def config() -> argparse.Namespace:
args = config() # Get command line arguments first
# Validate that model is specified to prevent accidental usage with empty model
if not args.model or args.model.strip() == "":
print("ERROR: Model must be specified. Use --model <model_name>")
print("Example: --model claude-sonnet-4-5-20250929")
sys.exit(1)
# Validate model support before proceeding
from mm_agents.anthropic.utils import validate_model_support
# Pass same temperature/top_p and thinking parameters as will be used by the agent
validation_kwargs = {}
if args.temperature is not None:
validation_kwargs['temperature'] = args.temperature
if args.top_p is not None:
validation_kwargs['top_p'] = args.top_p
validation_kwargs['no_thinking'] = args.no_thinking
validation_kwargs['use_isp'] = args.use_isp
if not validate_model_support(args.model, **validation_kwargs):
print(f"\n💥 Model '{args.model}' api sample failed")
sys.exit(1)
# Validate thinking mode options are mutually exclusive
if args.no_thinking and args.use_isp:
print("ERROR: --no-thinking and --use-isp are mutually exclusive")
print("Choose one of:")
print(" (default): Regular scratchpad mode")
print(" --no-thinking: Disable thinking/scratchpad")
print(" --use-isp: Use interleaved scratchpad (ISP)")
sys.exit(1)
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
@@ -182,7 +224,7 @@ def run_env_tasks(task_queue, args, shared_scores):
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
enable_proxy=False,
client_password=args.client_password
)
active_environments.append(env)
@@ -196,8 +238,9 @@ def run_env_tasks(task_queue, args, shared_scores):
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
provider_name=args.provider_name,
screen_width=args.screen_width,
screen_height=args.screen_height,
screen_size=(args.screen_width, args.screen_height),
no_thinking=getattr(args, 'no_thinking', False),
use_isp=getattr(args, 'use_isp', False),
)
logger.info(f"Process {current_process().name} started.")
while True:
@@ -239,6 +282,14 @@ def run_env_tasks(task_queue, args, shared_scores):
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
# Log error to results.json
try:
example = {"id": example_id} # Create minimal example dict for error logging
log_task_error(example, str(e), example_result_dir, args)
except Exception as log_e:
logger.error(f"Failed to log error to results.json: {log_e}")
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
@@ -479,7 +530,28 @@ if __name__ == "__main__":
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
# Filter for specific task ID if provided
if args.specific_task_id:
logger.info(f"Filtering for specific task ID: {args.specific_task_id}")
filtered_meta = {}
task_found = False
for domain, task_ids in test_all_meta.items():
for task_id in task_ids:
if task_id == args.specific_task_id:
filtered_meta[domain] = [task_id]
task_found = True
logger.info(f"Found task {args.specific_task_id} in domain: {domain}")
break
if task_found:
break
if not task_found:
logger.error(f"Task ID {args.specific_task_id} not found in test file!")
sys.exit(1)
test_all_meta = filtered_meta
elif args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(

run_multienv_dart_gui.py Normal file

@@ -0,0 +1,916 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List
from multiprocessing import Process, Manager, Queue
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.dart_gui_agent import DartAgent
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ #
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark - Dart Version"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=5.0)
parser.add_argument("--max_steps", type=int, default=15)
# evaluation config
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config - Dart specific configurations
parser.add_argument("--model", type=str, default="dart-uitars", help="Model name for Dart")
parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen25vl", "qwen2vl"])
parser.add_argument("--infer_mode", type=str, default="dart_mode", choices=["dart_mode", "qwen2vl_user"])
parser.add_argument("--prompt_style", type=str, default="dart_style")
parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content")
parser.add_argument("--language", type=str, default="English")
parser.add_argument("--max_pixels", type=float, default=16384*28*28)
parser.add_argument("--min_pixels", type=float, default=100*28*28)
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--top_p", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=-1)
parser.add_argument("--history_n", type=int, default=5)
parser.add_argument("--max_tokens", type=int, default=500)
parser.add_argument("--stop_token", type=str, default=None)
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.")
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO', help="Set the logging level")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="password", help="Client password"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
# Dart specific parameters
parser.add_argument("--dart_api_key", type=str, default="", help="Dart API key")
parser.add_argument("--dart_base_url", type=str, default="", help="Dart base URL")
parser.add_argument("--max_images", type=int, default=5, help="Maximum number of images in prompt history")
parser.add_argument("--max_texts", type=int, default=35, help="Maximum number of text responses in prompt history")
# Enhanced trajectory saving
parser.add_argument("--save_complete_trajectory", action="store_true", help="Save complete trajectory with images and detailed information")
parser.add_argument("--use_enhanced_runner", action="store_true", help="Use enhanced Dart runner with complete trajectory saving")
args = parser.parse_args()
return args
args = config() # Get command line arguments first
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "dart-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "dart-debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
return all_tasks
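# Illustrative example (values made up) of what distribute_tasks produces:
#   {"chrome": ["a1", "b2"], "gimp": ["c3"]}
#   -> [("chrome", "a1"), ("chrome", "b2"), ("gimp", "c3")]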
def process_signal_handler(signum, frame, env_idx):
"""Signal handler for child processes to gracefully shut down their environments."""
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
# Get the active_environments from the caller's frame
local_vars = frame.f_locals
active_environments = local_vars.get('active_environments', [])
# Close environment in the current process context
for env in active_environments:
if env is not None:
try:
logger.info(f"Process {env_idx + 1} closing environment...")
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def save_complete_trajectory_with_images(example_result_dir: str, task_info: dict, reward: float,
messages: list, all_images: list = None):
"""
Save the complete trajectory, including the paths of any saved images.
Args:
example_result_dir: directory where results are written
task_info: task metadata (domain, example_id, instruction)
reward: final reward score
messages: full conversation messages
all_images: list of all image payloads (optional)
"""
import datetime
# Build the complete trajectory payload
complete_trajectory = {
"task_info": {
"domain": task_info.get("domain", "unknown"),
"example_id": task_info.get("example_id", "unknown"),
"instruction": task_info.get("instruction", ""),
"timestamp": datetime.datetime.now().isoformat()
},
"evaluation": {
"reward": reward,
"success": reward > 0
},
"trajectory": {
"messages": [],
"image_paths": [],
"step_count": 0
}
}
# Walk through the messages and collect image paths
image_counter = 0
step_counter = 0
for msg_idx, message in enumerate(messages):
processed_message = {
"step": step_counter,
"role": message.get("role", "unknown"),
"content": message.get("content", []),
"timestamp": message.get("timestamp", ""),
"image_files": []
}
# Check the message content for images
if isinstance(message.get("content"), list):
for content_item in message["content"]:
if content_item.get("type") == "image_url":
# If matching image data is available, save the image to a file
if all_images and image_counter < len(all_images):
image_filename = f"step_{step_counter}_image_{image_counter}.png"
image_path = os.path.join(example_result_dir, image_filename)
try:
# Save the image
if hasattr(all_images[image_counter], 'save'):
# PIL Image object
all_images[image_counter].save(image_path)
elif isinstance(all_images[image_counter], bytes):
# Raw bytes
with open(image_path, 'wb') as f:
f.write(all_images[image_counter])
else:
logger.warning(f"Unknown image format for image {image_counter}")
continue
processed_message["image_files"].append(image_filename)
complete_trajectory["trajectory"]["image_paths"].append(image_path)
logger.info(f"Saved image: {image_filename}")
except Exception as e:
logger.error(f"Failed to save image {image_counter}: {e}")
image_counter += 1
# Update the content entry's image reference to point at the local path
if processed_message["image_files"]:
content_item["local_path"] = processed_message["image_files"][-1]
complete_trajectory["trajectory"]["messages"].append(processed_message)
# Each assistant reply counts as one step
if message.get("role") == "assistant":
step_counter += 1
complete_trajectory["trajectory"]["step_count"] = step_counter
# Write the complete trajectory to a JSON file
trajectory_file = os.path.join(example_result_dir, "complete_trajectory.json")
try:
with open(trajectory_file, 'w', encoding='utf-8') as f:
json.dump(complete_trajectory, f, indent=2, ensure_ascii=False)
logger.info(f"Complete trajectory saved to: {trajectory_file}")
# Also write a condensed summary for quick inspection
summary_file = os.path.join(example_result_dir, "trajectory_summary.json")
summary = {
"task_id": task_info.get("example_id", "unknown"),
"domain": task_info.get("domain", "unknown"),
"instruction": task_info.get("instruction", ""),
"reward": reward,
"success": reward > 0,
"total_steps": step_counter,
"total_images": len(complete_trajectory["trajectory"]["image_paths"]),
"image_files": [os.path.basename(path) for path in complete_trajectory["trajectory"]["image_paths"]],
"timestamp": complete_trajectory["task_info"]["timestamp"]
}
with open(summary_file, 'w', encoding='utf-8') as f:
json.dump(summary, f, indent=2, ensure_ascii=False)
logger.info(f"Trajectory summary saved to: {summary_file}")
except Exception as e:
logger.error(f"Failed to save complete trajectory: {e}")
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
try:
# Initialize proxy configuration if enabled
# if hasattr(args, 'proxy_host') and args.proxy_host and args.proxy_port:
# from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool
# proxy_pool = get_global_proxy_pool()
# proxy_pool.add_proxy(
# host=args.proxy_host,
# port=args.proxy_port,
# protocol=args.proxy_protocol
# )
# logger.info(f"Added proxy: {args.proxy_host}:{args.proxy_port} ({args.proxy_protocol})")
# elif hasattr(args, 'proxy_config') and args.proxy_config and os.path.exists(args.proxy_config):
# from desktop_env.providers.aws.proxy_pool import init_proxy_pool
# init_proxy_pool(args.proxy_config)
# logger.info(f"Initialized proxy pool from {args.proxy_config}")
# Configure environment based on provider
if args.provider_name == "aws":
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
)
else:
# For non-AWS providers (docker, virtualbox, etc.)
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
)
active_environments.append(env)
args.max_trajectory_length = args.max_steps
# Dart specific runtime configuration
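# "dart_mode" exposes the sampling/history knobs from the CLI below, whereas
# "qwen2vl_user" pins the prompt style and sampling parameters to fixed values.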
if args.infer_mode == "dart_mode":
runtime_conf: dict = {
"infer_mode": args.infer_mode,
"prompt_style": args.prompt_style,
"input_swap": args.input_swap,
"language": args.language,
"history_n": args.history_n,
"max_pixels": args.max_pixels,
"min_pixels": args.min_pixels,
"temperature": args.temperature,
"top_k": args.top_k,
"top_p": args.top_p,
"max_tokens": args.max_tokens,
"max_images": args.max_images,
"max_texts": args.max_texts,
"dart_api_key": args.dart_api_key,
"dart_base_url": args.dart_base_url
}
elif args.infer_mode == "qwen2vl_user":
runtime_conf: dict = {
"infer_mode": "qwen2vl_user",
"prompt_style": "qwen2vl_user",
"input_swap": args.input_swap,
"language": args.language,
"history_n": 5,
"max_pixels": 2116800,
"min_pixels": 3136,
"temperature": 0.0,
"top_k": -1,
"top_p": 0.9,
"max_tokens": 1000
}
else:
raise ValueError(f"Unknown infer_mode: {args.infer_mode}")
agent = DartAgent(
model=args.model,
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
model_type=args.model_type,
runtime_conf=runtime_conf
)
logger.info(f"Process {current_process().name} started with Dart configuration.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
# Create a temporary list to capture the score
temp_scores = []
# Choose which run function to use based on the flags
if args.use_enhanced_runner or args.save_complete_trajectory:
# Use the Jiuzhang-specific run function, which supports saving the complete trajectory
logger.info(f"Using enhanced Dart runner for {domain}/{example_id}")
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
temp_scores,
)
else:
# Use the standard run function
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
temp_scores,
)
# Add domain info to the score
if temp_scores:
shared_scores.append({
'domain': domain,
'example_id': example_id,
'score': temp_scores[-1]
})
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
global is_terminating, active_environments, processes
# Avoid duplicate handling
if is_terminating:
return
is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
# Close all registered environments in the main process
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
# Send termination signal to all child processes first
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
# Allow a short time for processes to handle their own cleanup
time.sleep(1)
# Forcefully terminate any processes that didn't exit
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = []
for i in range(num_envs):
p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"DartEnvProcess-{i+1}"
)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started Dart process {p.name} with PID {p.pid}")
try:
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"DartEnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores)
# Detailed statistics reporting
if scores:
# Extract numeric scores for overall statistics
numeric_scores = []
domain_stats = {}
for score_entry in scores:
if isinstance(score_entry, dict):
domain = score_entry.get('domain', 'unknown')
example_id = score_entry.get('example_id', 'unknown')
score = score_entry.get('score', 0)
else:
# Handle legacy numeric scores
domain = 'unknown'
example_id = 'unknown'
score = score_entry
numeric_scores.append(score)
# Domain statistics
if domain not in domain_stats:
domain_stats[domain] = {'total': 0, 'success': 0, 'scores': []}
domain_stats[domain]['total'] += 1
domain_stats[domain]['scores'].append(score)
if score > 0:
domain_stats[domain]['success'] += 1
# Overall statistics
total_tasks = len(numeric_scores)
successful_tasks = sum(1 for score in numeric_scores if score > 0)
average_score = sum(numeric_scores) / total_tasks
success_rate = (successful_tasks / total_tasks) * 100
logger.info("=" * 60)
logger.info("📊 DART EVALUATION RESULTS SUMMARY")
logger.info("=" * 60)
logger.info(f"📈 Overall Statistics:")
logger.info(f" • Total tasks executed: {total_tasks}")
logger.info(f" • Successful tasks (score > 0): {successful_tasks}")
logger.info(f" • Success rate: {success_rate:.1f}%")
logger.info(f" • Average score: {average_score:.3f}")
# Domain-specific statistics
if domain_stats and len(domain_stats) > 1: # Only show domain breakdown if multiple domains
logger.info(f"\n🏷️ Domain-specific Results:")
for domain, stats in sorted(domain_stats.items()):
domain_success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
domain_avg_score = sum(stats['scores']) / len(stats['scores']) if stats['scores'] else 0
logger.info(f"{domain}:")
logger.info(f" - Tasks: {stats['total']}")
logger.info(f" - Successful: {stats['success']}")
logger.info(f" - Success rate: {domain_success_rate:.1f}%")
logger.info(f" - Average score: {domain_avg_score:.3f}")
# Score distribution
score_ranges = {
'Perfect (1.0)': sum(1 for s in numeric_scores if s == 1.0),
'High (0.8-0.99)': sum(1 for s in numeric_scores if 0.8 <= s < 1.0),
'Medium (0.5-0.79)': sum(1 for s in numeric_scores if 0.5 <= s < 0.8),
'Low (0.1-0.49)': sum(1 for s in numeric_scores if 0.1 <= s < 0.5),
'Failed (0.0)': sum(1 for s in numeric_scores if s == 0.0)
}
logger.info(f"\n📊 Score Distribution:")
for range_name, count in score_ranges.items():
if count > 0:
percentage = (count / total_tasks) * 100
logger.info(f"{range_name}: {count} tasks ({percentage:.1f}%)")
logger.info("=" * 60)
else:
logger.warning("⚠️ No scores collected during evaluation!")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except Exception:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
def clear_cache_directory():
"""清空cache目录中的所有内容"""
cache_dir = "cache"
if os.path.exists(cache_dir):
logger.info(f"Clearing cache directory: {cache_dir}")
try:
import shutil
# Delete the whole cache directory
shutil.rmtree(cache_dir)
# Recreate an empty cache directory
os.makedirs(cache_dir, exist_ok=True)
logger.info("Cache directory cleared successfully")
except Exception as e:
logger.error(f"Failed to clear cache directory: {e}")
else:
logger.info("Cache directory does not exist, creating it")
os.makedirs(cache_dir, exist_ok=True)
def cleanup_docker_containers():
"""清理Docker容器保留monitor容器"""
logger.info("Cleaning up Docker containers...")
try:
import subprocess
# Collect all container IDs, excluding monitor-monitor-1
cmd = 'docker ps --format "{{.ID}} {{.Names}}" | grep -v "monitor-monitor-1" | awk \'{print $1}\''
result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=30)
if result.returncode == 0 and result.stdout.strip():
container_ids = result.stdout.strip().split('\n')
container_ids = [cid for cid in container_ids if cid.strip()]
if container_ids:
logger.info(f"Found {len(container_ids)} containers to remove: {container_ids}")
# Force-remove the containers
for container_id in container_ids:
try:
rm_result = subprocess.run(
f"docker rm -f {container_id}",
shell=True,
capture_output=True,
text=True,
timeout=10
)
if rm_result.returncode == 0:
logger.info(f"Successfully removed container: {container_id}")
else:
logger.warning(f"Failed to remove container {container_id}: {rm_result.stderr}")
except subprocess.TimeoutExpired:
logger.warning(f"Timeout removing container: {container_id}")
except Exception as e:
logger.error(f"Error removing container {container_id}: {e}")
logger.info("Docker container cleanup completed")
else:
logger.info("No containers found to remove")
else:
logger.info("No containers found or error getting container list")
except subprocess.TimeoutExpired:
logger.error("Timeout during Docker container cleanup")
except Exception as e:
logger.error(f"Failed to cleanup Docker containers: {e}")
if __name__ == "__main__":
####### Dart Version - Complete evaluation runner #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Register signal handlers for graceful termination
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
try:
args = config()
# Clean up Docker containers: remove containers left over from the previous run
# (keep them if you are running your own containers alongside this script)
cleanup_docker_containers()
# Empty the cache directory to clear files downloaded during the previous run
clear_cache_directory()
logger.info("Starting Dart evaluation runner...")
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
# Signal handler will take care of cleanup
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
# Also trigger cleanup for unhandled exceptions
signal_handler(signal.SIGTERM, None)
finally:
# Final cleanup in case any environments or processes remain
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info(f"Closing environment in final cleanup...")
env.close()
logger.info(f"Environment closed successfully in final cleanup")
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
# First try gentle termination
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Terminating process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
# Wait a moment for processes to terminate
time.sleep(1)
# Then force kill if needed
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Force killing process {p.name}...")
os.kill(p.pid, signal.SIGKILL)
logger.info(f"Process {p.name} force killed")
except Exception as e:
logger.error(f"Error force killing process: {e}")

run_multienv_evocua.py Normal file

@@ -0,0 +1,603 @@
"""
Script to run EvoCUA native agent model on OSWorld tasks.
export AWS_ACCESS_KEY_ID="xx"
export AWS_SECRET_ACCESS_KEY="xx"
export AWS_REGION="xx"
export AWS_SECURITY_GROUP_ID="xx"
export AWS_SUBNET_ID="xx"
export OPENAI_API_KEY="xxxx"
export OPENAI_BASE_URL="xxxx"
Example Usage (S2):
python3 run_multienv_evocua.py \
--headless \
--provider_name aws \
--observation_type screenshot \
--model EvoCUA-S2 \
--result_dir ./evocua_s2 \
--test_all_meta_path evaluation_examples/test_nogdrive.json \
--max_steps 50 \
--num_envs 30 \
--temperature 0.01 \
--max_history_turns 4 \
--coordinate_type relative \
--resize_factor 32 \
--prompt_style S2
Example Usage (S1):
python3 run_multienv_evocua.py \
--headless \
--provider_name aws \
--observation_type screenshot \
--model EvoCUA-S1 \
--result_dir ./evocua_s1 \
--test_all_meta_path evaluation_examples/test_nogdrive.json \
--max_steps 50 \
--num_envs 30 \
--max_history_turns 3 \
--coordinate_type qwen25 \
--max_tokens 10240 \
--resize_factor 28 \
--prompt_style S1
"""
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List
from multiprocessing import Process, Manager, Queue
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.evocua.evocua_agent import EvoCUAAgent
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# Thread-local storage for task context (works per-process in multiprocessing)
import threading
_task_context = threading.local()
def get_task_context():
"""Get current task context from thread-local storage."""
return getattr(_task_context, 'context', {'domain': None, 'example_id': None})
def set_task_context(domain: str, example_id: str):
"""Set current task context in thread-local storage."""
_task_context.context = {'domain': domain, 'example_id': example_id}
def clear_task_context():
"""Clear current task context."""
if hasattr(_task_context, 'context'):
delattr(_task_context, 'context')
class TaskContextFilter(logging.Filter):
"""Filter to add domain and example_id to log records."""
def filter(self, record):
ctx = get_task_context()
domain = ctx.get('domain')
example_id = ctx.get('example_id')
if domain and example_id:
record.domain = domain
record.example_id = example_id
# Add prefix to message
if hasattr(record, 'msg') and isinstance(record.msg, str):
if not record.msg.startswith(f"[{domain}/{example_id}]"):
record.msg = f"[{domain}/{example_id}] {record.msg}"
else:
record.domain = domain or "N/A"
record.example_id = example_id or "N/A"
return True
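# Illustrative effect of TaskContextFilter (example values): after set_task_context("chrome", "a1b2")
# runs in a worker, a plain logger.info("Starting step 3") is emitted by every handler carrying
# this filter as "[chrome/a1b2] Starting step 3".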
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation with EvoCUAAgent"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=5.0)
parser.add_argument("--max_steps", type=int, default=50)
# evaluation config
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config
parser.add_argument("--model", type=str, default="evocua", help="Model name.")
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=32768)
parser.add_argument("--stop_token", type=str, default=None)
parser.add_argument("--prompt_style", type=str, default="S2", choices=["S1", "S2"], help="Prompt style: 'S1' (structured reasoning) or 'S2' (tool calling)")
parser.add_argument("--history_type", type=str, default="action_history", help="[S1] History type")
parser.add_argument("--coordinate_type", type=str, default="relative", help="Coordinate type: relative, absolute, qwen25")
parser.add_argument("--password", type=str, default="osworld-public-evaluation", help="VM Password")
# Unified History Parameter
parser.add_argument("--max_history_turns", type=int, default=3, help="Number of history turns to include")
parser.add_argument("--resize_factor", type=int, default=32, help="Image resize factor (S1: 28, S2: 32)")
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO', help="Set the logging level")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
args = parser.parse_args()
return args
args = config()
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
# Add task context filter to all handlers
task_filter = TaskContextFilter()
file_handler.addFilter(task_filter)
debug_handler.addFilter(task_filter)
stdout_handler.addFilter(task_filter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
return all_tasks
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
try:
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
# Determine snapshot based on provider
snapshot_name = "init_state"
if args.provider_name == "aws":
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION].get((1920, 1080)))
snapshot_name = ami_id
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=snapshot_name,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
set_task_context(domain, example_id)
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
# Initialize EvoCUAAgent
agent = EvoCUAAgent(
model=args.model,
max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,
action_space=args.action_space,
observation_type=args.observation_type,
max_steps=args.max_steps,
prompt_style=args.prompt_style,
max_history_turns=args.max_history_turns,
screen_size=(args.screen_width, args.screen_height),
coordinate_type=args.coordinate_type,
password=args.password,
resize_factor=args.resize_factor,
)
try:
lib_run_single.run_single_example_evocua(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
clear_task_context()
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
global is_terminating, active_environments, processes
if is_terminating:
return
is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
time.sleep(1)
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = []
for i in range(num_envs):
p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
try:
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except Exception:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Register signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
args = config()
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
signal_handler(signal.SIGTERM, None)
finally:
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info("Closing environment in final cleanup...")
env.close()
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
for p in processes:
if p is not None and p.is_alive():
try:
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
time.sleep(1)
for p in processes:
if p is not None and p.is_alive():
try:
os.kill(p.pid, signal.SIGKILL)
except Exception as e:
logger.error(f"Error force killing process: {e}")

525
run_multienv_hosted_gbox.py Normal file
View File

@ -0,0 +1,525 @@
"""Run OSWorld evaluation using hosted GBOX service"""
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List
from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.hosted_gbox_agent import HostedGboxAgent
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ #
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run OSWorld evaluation with hosted GBOX service"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15)
# agent config
parser.add_argument("--max_trajectory_length", type=int, default=3)
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# Hosted GBOX service config
parser.add_argument(
"--gbox_service_url",
type=str,
default=os.getenv("GBOX_SERVICE_URL", "http://44.201.221.203:8000"),
help="URL of hosted GBOX service"
)
parser.add_argument(
"--gbox_service_api_key",
type=str,
default=os.getenv("GBOX_SERVICE_API_KEY"),
help="API key for hosted GBOX service"
)
parser.add_argument(
"--model",
type=str,
default="us.anthropic.claude-sonnet-4-5-20250929-v1:0",
help="Claude model to use (default: Bedrock Sonnet 4.5)"
)
parser.add_argument("--max_tokens", type=int, default=1500)
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results_hosted_gbox")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO', help="Set the logging level")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--provider_name", type=str, default="aws", help="Cloud provider name"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
parser.add_argument(
"--client_password",
type=str,
default=os.getenv("CLIENT_PASSWORD", "osworld-public-evaluation"),
help="Client password (default: osworld-public-evaluation)"
)
args = parser.parse_args()
return args
# }}} Logger Configs #
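# Example invocation (a sketch; the service URL and API key below are placeholders,
# and every flag corresponds to an argparse option defined in config() above):
#
#   python run_multienv_hosted_gbox.py \
#       --headless \
#       --observation_type screenshot \
#       --model us.anthropic.claude-sonnet-4-5-20250929-v1:0 \
#       --gbox_service_url http://<gbox-host>:8000 \
#       --gbox_service_api_key <your-api-key> \
#       --num_envs 4 \
#       --max_steps 15 \
#       --result_dir ./results_hosted_gbox \
#       --test_all_meta_path evaluation_examples/test_all.json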
def setup_logger(env_idx: int = None, result_dir: str = "./results_gbox", level: str = 'INFO') -> logging.Logger:
"""Set up a logger for the current process.
Args:
env_idx: Environment index for naming (None for main process)
result_dir: Directory to store logs
level: Logging level
Returns:
Configured logger instance
"""
# Set log level
numeric_level = getattr(logging, level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError(f'Invalid log level: {level}')
# Create logger
if env_idx is not None:
logger_name = f"osworld-worker-{env_idx}"
else:
logger_name = "osworld-main"
logger = logging.getLogger(logger_name)
logger.setLevel(numeric_level)
# Remove existing handlers
logger.handlers.clear()
# Create formatters and handlers
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(numeric_level)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# File handler
os.makedirs(result_dir, exist_ok=True)
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
if env_idx is not None:
log_file = os.path.join(result_dir, f"worker_{env_idx}_{timestamp}.log")
else:
log_file = os.path.join(result_dir, f"main_{timestamp}.log")
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(numeric_level)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
return logger
logger = logging.getLogger("osworld-main")
def check_completed_tasks(result_dir: str, test_all_meta: dict) -> List[str]:
"""Check which tasks have already been completed.
Args:
result_dir: Directory containing results
test_all_meta: Dictionary of domain -> list of task IDs
Returns:
List of completed task IDs (format: "domain/task_id")
"""
completed = []
for domain, examples in test_all_meta.items():
for example_id in examples:
result_path = os.path.join(
result_dir,
"pyautogui",
"screenshot",
"claude-sonnet-4-5", # Model name from args
domain,
example_id,
"result.txt"
)
if os.path.exists(result_path):
completed.append(f"{domain}/{example_id}")
logger.info(f"Task {domain}/{example_id} already completed (result found)")
return completed
def report_current_results(target_dir: str) -> List[float]:
"""Report current results from completed tasks.
Args:
target_dir: Directory containing results
Returns:
List of scores (0.0 or 1.0)
"""
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
try:
with open(os.path.join(example_path, "result.txt"), "r") as f:
all_result.append(float(f.read()))
except Exception as e:
logger.warning(f"Failed to read result for {domain}/{example_id}: {e}")
all_result.append(0.0)
if not all_result:
logger.info("New experiment, no results yet.")
return None
else:
success_rate = sum(all_result) / len(all_result) * 100
logger.info(f"Current Success Rate: {success_rate:.2f}% ({len(all_result)} tasks)")
return all_result
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
return all_tasks
def process_signal_handler(signum, frame, env_idx):
"""Signal handler for child processes to gracefully shut down their environments."""
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
# Get the active_environments from the caller's frame
local_vars = frame.f_locals
active_environments = local_vars.get('active_environments', [])
# Close environment in the current process context
for env in active_environments:
if env is not None:
try:
logger.info(f"Process {env_idx + 1} closing environment...")
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(task_queue, args: argparse.Namespace, shared_scores: list):
"""Worker process that runs tasks from the queue using hosted GBOX service."""
active_environments = []
env = None
try:
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
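# IMAGE_ID_MAP is assumed to map region -> {(width, height): ami_id}, e.g.
# {"us-east-1": {(1920, 1080): "ami-..."}}; when the requested resolution has no
# dedicated AMI, the 1920x1080 image for that region is used as a fallback.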
# Create environment
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
# Get VM IP address - MCP server will handle public IP lookup if needed
vm_ip = env.vm_ip
logger.info(f"VM IP: {vm_ip}")
# Create hosted GBOX agent
agent = HostedGboxAgent(
server_url=args.gbox_service_url,
api_key=args.gbox_service_api_key,
vm_ip=vm_ip,
platform="ubuntu",
model=args.model,
max_steps=args.max_steps,
)
# Process tasks from queue
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Domain]: {domain}")
logger.info(f"[Example ID]: {example_id}")
logger.info(f"[Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
)
)
f.write("\n")
except Exception as e:
logger.error(f"Error processing task: {e}", exc_info=True)
except KeyboardInterrupt:
logger.info("Worker received interrupt signal")
except Exception as e:
logger.error(f"Worker error: {e}", exc_info=True)
finally:
# Cleanup
if env is not None:
try:
logger.info("Closing environment...")
env.close()
logger.info("Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
def main_signal_handler(signum, frame):
"""Signal handler for main process to gracefully shut down all child processes."""
global is_terminating
if is_terminating:
logger.info("Already terminating, please wait...")
return
is_terminating = True
logger.info(f"Main process received signal {signum}. Shutting down all workers...")
# Terminate all child processes
for idx, proc in enumerate(processes):
if proc.is_alive():
logger.info(f"Terminating worker process {idx + 1}...")
proc.terminate()
# Wait for processes to finish with timeout
timeout = 30
start_time = time.time()
for idx, proc in enumerate(processes):
remaining_time = max(0, timeout - (time.time() - start_time))
proc.join(timeout=remaining_time)
if proc.is_alive():
logger.warning(f"Worker {idx + 1} did not terminate gracefully, forcing...")
proc.kill()
proc.join()
logger.info("All workers terminated. Exiting.")
sys.exit(0)
if __name__ == "__main__":
args = config()
# Setup main logger
logger = setup_logger(env_idx=None, result_dir=args.result_dir, level=args.log_level)
# Validate hosted service configuration
if not args.gbox_service_url:
logger.error("GBOX_SERVICE_URL not set (use --gbox_service_url or GBOX_SERVICE_URL env var)")
sys.exit(1)
if not args.gbox_service_api_key:
logger.error("GBOX_SERVICE_API_KEY not set (use --gbox_service_api_key or GBOX_SERVICE_API_KEY env var)")
sys.exit(1)
logger.info(f"Using hosted GBOX service at: {args.gbox_service_url}")
logger.info(f"Model: {args.model}")
logger.info(f"Max steps: {args.max_steps}")
logger.info(f"Number of parallel environments: {args.num_envs}")
# Setup signal handlers
signal.signal(signal.SIGINT, main_signal_handler)
signal.signal(signal.SIGTERM, main_signal_handler)
# Load test configuration
logger.info(f"Loading test configuration from: {args.test_all_meta_path}")
with open(args.test_all_meta_path, "r") as f:
test_all_meta = json.load(f)
# Filter by domain if specified
if args.domain != "all":
if args.domain in test_all_meta:
test_all_meta = {args.domain: test_all_meta[args.domain]}
logger.info(f"Filtering to domain: {args.domain}")
else:
logger.error(f"Domain '{args.domain}' not found in test configuration")
sys.exit(1)
# Check for completed tasks
completed_tasks = check_completed_tasks(args.result_dir, test_all_meta)
logger.info(f"Found {len(completed_tasks)} completed tasks")
# Distribute tasks
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks to run: {len(all_tasks)}")
# Filter out completed tasks
all_tasks = [task for task in all_tasks if f"{task[0]}/{task[1]}" not in completed_tasks]
logger.info(f"Tasks remaining after filtering completed: {len(all_tasks)}")
if not all_tasks:
logger.info("No tasks to run. All tasks already completed.")
# Report current results
target_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model if getattr(args, 'model_dir_name', None) is None else args.model_dir_name
)
if os.path.exists(target_dir):
report_current_results(target_dir)
sys.exit(0)
# Create shared task queue
manager = Manager()
task_queue = manager.Queue()
shared_scores = manager.list()
# Populate queue
for task in all_tasks:
task_queue.put(task)
# Start worker processes
logger.info(f"Starting {args.num_envs} worker processes...")
for env_idx in range(args.num_envs):
proc = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores)
)
proc.start()
processes.append(proc)
logger.info(f"Started worker process {env_idx + 1} (PID: {proc.pid})")
# Wait for all processes to complete
try:
for idx, proc in enumerate(processes):
proc.join()
logger.info(f"Worker process {idx + 1} completed")
except KeyboardInterrupt:
logger.info("Received interrupt, shutting down...")
main_signal_handler(signal.SIGINT, None)
# Report final results
logger.info("=" * 50)
logger.info("EVALUATION COMPLETE")
logger.info("=" * 50)
target_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model
)
if os.path.exists(target_dir):
final_results = report_current_results(target_dir)
if final_results:
success_rate = sum(final_results) / len(final_results) * 100
logger.info(f"Final Success Rate: {success_rate:.2f}% ({len(final_results)} tasks)")
logger.info("Exiting...")

View File

@ -3,29 +3,34 @@
You should first host the OpenCUA model on your local machine or a server.
Command for OpenCUA-72B:
```
python run_multienv_opencua.py \
--headless \
--observation_type screenshot \
--model OpenCUA-72B \
--result_dir ./results\
--test_all_meta_path evaluation_examples/test_nogdrive.json \
--max_steps 100 \
--num_envs 30 \
--coordinate_type qwen25
```
Command for OpenCUA-7B and OpenCUA-32B:
```
python run_multienv_opencua.py \
--headless \
--observation_type screenshot \
--model OpenCUA-32B \
--result_dir ./results --test_all_meta_path evaluation_examples/test_all_no_gdrive.json \
--max_steps 100 \
--num_envs 30 \
--coordinate_type qwen25
```
Command for OpenCUA-Qwen2-7B and OpenCUA-A3B:
```
python run_multienv_opencua.py \
--headless \
--observation_type screenshot \
--model OpenCUA-A3B \
--result_dir ./results \
--result_dir ./results\
--test_all_meta_path evaluation_examples/test_nogdrive.json \
--max_steps 100 \
--num_envs 10 \
--coordinate_type relative
--num_envs 30 \
--coordinate_type qwen25 \
--use_old_sys_prompt
```
"""
@ -44,7 +49,7 @@ from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.opencua_agent import OpenCUAAgent
from mm_agents.opencua import OpenCUAAgent
# Global variables for signal handling
active_environments = []
@ -76,8 +81,8 @@ def config() -> argparse.Namespace:
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
parser.add_argument("--max_steps", type=int, default=15)
parser.add_argument("--sleep_after_execution", type=float, default=5.0)
parser.add_argument("--max_steps", type=int, default=100)
# evaluation config
parser.add_argument(
@ -85,7 +90,7 @@ def config() -> argparse.Namespace:
)
# lm config
parser.add_argument("--model", type=str, default="opencua")
parser.add_argument("--model", type=str, default=None)
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=2048)
@ -94,13 +99,14 @@ def config() -> argparse.Namespace:
# OpenCUAagent config
parser.add_argument("--cot_level", type=str, default="l2", help="CoT version: l1, l2, l3. Default is l2 includes 'thought' and 'action'")
parser.add_argument("--history_type", type=str, default="action_history", help="Use action to represent history steps", choices=["action_history", "thought_history", "observation_history"])
parser.add_argument("--coordinate_type", type=str, default="relative", help="Type of coordinate: Qwen2-VL or Kimi-VL based models use 'relative'; Qwen2.5-VL based models use 'qwen25'", choices=["relative", "qwen25"])
parser.add_argument("--coordinate_type", type=str, default="qwen25", help="Type of coordinate: Qwen2-VL or Kimi-VL based models use 'relative'; Qwen2.5-VL based models use 'qwen25'", choices=["relative", "qwen25"])
parser.add_argument("--max_image_history_length", type=int, default=3, help="The max number of images in the history.")
parser.add_argument("--use_old_sys_prompt", action="store_true", help="Use the old system prompt for OpenCUA-7B and OpenCUA-32B")
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
"--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
)
# logging related
@ -124,6 +130,9 @@ def config() -> argparse.Namespace:
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
parser.add_argument(
"--password", type=str, default="osworld-public-evaluation", help="The password for the computer if needed"
)
args = parser.parse_args()
return args
@ -253,6 +262,9 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
screen_size=(args.screen_width, args.screen_height),
coordinate_type=args.coordinate_type,
max_image_history_length=args.max_image_history_length,
max_steps=args.max_steps,
use_old_sys_prompt=args.use_old_sys_prompt,
password=args.password,
)
try:
lib_run_single.run_single_example_opencua(

907
run_multienv_os_symphony.py Normal file
View File

@ -0,0 +1,907 @@
"""
OS-Symphony Official Evaluation Script
This script serves as the official evaluation entry point for OS-Symphony.
It handles the setup of the desktop environment, agent initialization, and
execution of evaluation tasks.
For detailed evaluation metrics, configuration options, and usage instructions,
please refer to the official repository:
https://github.com/OS-Copilot/OS-Symphony
"""
import argparse
import copy
import datetime
import json
import logging
import os
import subprocess
import sys
import signal
import time
from multiprocessing import Process, Manager, current_process, Queue
from mm_agents.os_symphony.agents.os_symphony import OSSymphony
from mm_agents.os_symphony.agents.os_aci import OSACI
import lib_run_single
# A modified desktop_env that adds a new 'start' method
from desktop_env.desktop_env_os_symphony import DesktopEnv as OSWorldDesktopEnv
from dotenv import load_dotenv
load_dotenv()
# only for WAA
def prepare_worker_vm_paths(base_golden_path: str, worker_idx: int):
# remove the '/' at the end
base_golden_path = base_golden_path.rstrip(os.sep)
# get parent directory (like /nvme/yangbowen/vm_stroage/waa)
parent_dir = os.path.dirname(base_golden_path)
# define the path of this worker
worker_storage_path = os.path.join(parent_dir, f"storage_{worker_idx}")
worker_backup_path = os.path.join(parent_dir, f"storage_{worker_idx}_backup")
return worker_storage_path, worker_backup_path
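# For example, with base_golden_path "/data/waa/golden_vm" and worker_idx 0 this yields
# "/data/waa/storage_0" and "/data/waa/storage_0_backup" (paths here are illustrative).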
# only for WAA
def initialize_worker_files(golden_path: str, worker_backup_path: str, worker_storage_path: str):
"""
Initialize worker storage. If the backup does not exist yet, replicate it from the golden path.
"""
if not os.path.exists(golden_path):
raise FileNotFoundError(f"Golden VM path not found: {golden_path}")
if not os.path.exists(worker_backup_path):
logger.info(f"Initializing backup for worker from {golden_path} to {worker_backup_path} ...")
try:
os.makedirs(os.path.dirname(worker_backup_path), exist_ok=True)
if os.path.isdir(golden_path):
subprocess.check_call(['cp', '-r', '--sparse=always', golden_path, worker_backup_path])
else:
subprocess.check_call(['cp', '--sparse=always', golden_path, worker_backup_path])
logger.info(f"Backup initialization complete for {worker_backup_path}")
except subprocess.CalledProcessError as e:
logger.error(f"Failed to copy golden image to backup using cp: {e}")
raise e
else:
logger.info(f"Worker backup already exists at {worker_backup_path}, skipping copy.")
if not os.path.exists(worker_storage_path):
os.makedirs(worker_storage_path, exist_ok=True)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
# Set up Stdout handler
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.INFO)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(stdout_handler)
# Set up File Handler
# file_handler = logging.FileHandler(filename="log.txt")
# file_handler.setLevel(logging.ERROR)
# file_handler.setFormatter(formatter)
# file_handler.addFilter(logging.Filter("desktopenv"))
# logger.addHandler(file_handler)
# Logger Configs
logger = logging.getLogger("desktopenv.experiment")
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
def distribute_tasks(test_all_meta: dict) -> list:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
return all_tasks
def process_signal_handler(signum, frame, env_idx):
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
local_vars = frame.f_locals
active_environments = local_vars.get("active_environments", [])
for env in active_environments:
if env is not None:
try:
logger.info(f"Process {env_idx + 1} closing environment...")
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(
task_queue: Queue,
args: argparse.Namespace,
shared_scores: list,
engine_params_for_orchestrator,
engine_params_for_grounder,
engine_params_for_coder,
engine_params_for_memoryer,
engine_params_for_searcher,
worker_id: int,
):
active_environments = []
env = None
search_env = None
try:
# Use IMAGE_ID_MAP for AWS provider to get snapshot_name
snapshot_name = None
region = getattr(args, "region", "us-east-1")
platform = 'linux'
screen_size = (args.screen_width, args.screen_height)
if "osworld" in args.benchmark:
if args.provider_name == "aws":
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
env = OSWorldDesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=region,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
)
elif args.provider_name == "docker":
env = OSWorldDesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=region,
snapshot_name=snapshot_name,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=getattr(args, "client_password", "")
)
else:
raise Exception("Don't support other providers!")
env.start()
if args.provider_name == "aws":
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
search_env = OSWorldDesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=region,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
)
elif args.provider_name == "docker":
search_env = OSWorldDesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=region,
snapshot_name=snapshot_name,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=getattr(args, "client_password", "")
)
else:
raise Exception("Don't support other providers!")
engine_params_for_ocr = copy.deepcopy(engine_params_for_orchestrator)
engine_params_for_ocr["agent_name"] = "ocr"
os_aci = OSACI(
env=env,
search_env=search_env,
platform=platform,
client_password=args.client_password,
engine_params_for_ocr=engine_params_for_ocr,
engine_params_for_grounder=engine_params_for_grounder,
engine_params_for_coder=engine_params_for_coder,
engine_params_for_searcher=engine_params_for_searcher,
screen_width=args.screen_width,
screen_height=args.screen_height,
)
agent = OSSymphony(
engine_params_for_orchestrator,
engine_params_for_memoryer,
os_aci,
platform=platform,
client_password=args.client_password,
max_trajectory_length=args.max_trajectory_length,
enable_reflection=args.enable_reflection,
)
active_environments.append(env)
active_environments.append(search_env)
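# Two sibling environments are created per worker: `env` runs the actual task, while
# `search_env` is handed to OSACI as a dedicated browsing sandbox for the searcher
# agent, presumably so web searches do not disturb the state of the task VM
# (rationale inferred from how OSACI wires search_env; see the OS-Symphony repository).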
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
if args.enable_rewrite_instruction and "rewritten_instruction" in example:
instruction = example["rewritten_instruction"]
else:
instruction = example["instruction"]
example_result_dir = os.path.join(
args.result_dir,
domain,
example_id
)
os.makedirs(example_result_dir, exist_ok=True)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {instruction}")
try:
lib_run_single.run_single_example_os_symphony(
agent,
env,
example,
args.max_steps,
instruction,
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(
f"Exception in {current_process().name} {domain}/{example_id}: {e}"
)
logger.error(traceback.format_exc())
with open(os.path.join(os.path.dirname(example_result_dir), "error.jsonl"), "a") as f:
f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
if search_env:
search_env.close()
logger.info(f"{current_process().name} searcher environment closed successfully")
except Exception as e:
logger.error(
f"{current_process().name} error during environment cleanup: {e}"
)
# graceful-exit signal handler
def signal_handler(signum, frame):
global is_terminating, active_environments, processes
if is_terminating:
return
is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
time.sleep(1)
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--provider_name",
type=str,
default="vmware",
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)",
)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument(
"--num_envs",
type=int,
default=1,
help="Number of environments to run in parallel",
)
parser.add_argument("--screen_width", type=int, default=1920, help="Main environment's width")
parser.add_argument("--screen_height", type=int, default=1080, help="Main environment's height")
parser.add_argument("--sleep_after_execution", type=float, default=1.0)
parser.add_argument("--max_steps", type=int, default=15)
# Benchmark
parser.add_argument("--benchmark", type=str, default="osworld", help="osworld / waa / macos")
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/osworld/test_all.json"
)
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM for OSWorld."
)
parser.add_argument(
"--client_password", type=str, default="password", help="Client password for OSWorld. Aws is 'osworld-public-evaluation', other is 'password'"
)
parser.add_argument(
"--proxy", type=str, default="http://10.1.8.5:23128", help="Important! Proxy setting, format should be http://<ip>:<port>, if no-use, set it empty"
)
# agent config
parser.add_argument("--max_trajectory_length", type=int, default=8)
parser.add_argument("--enable_reflection", action="store_true", default=False)
parser.add_argument("--enable_rewrite_instruction", action="store_true", default=False)
parser.add_argument(
"--tool_config",
type=str,
help="The path of tool config yaml"
)
# generator-agent config
parser.add_argument("--orchestrator_provider", type=str, default="openai")
parser.add_argument("--orchestrator_model", type=str, default="gpt-5")
parser.add_argument(
"--orchestrator_url",
type=str,
default="",
help="The URL of the main orchestrator model API.",
)
parser.add_argument(
"--orchestrator_api_key",
type=str,
default="",
help="The API key of the main orchestrator model.",
)
parser.add_argument(
"--orchestrator_temperature",
type=float,
default=None,
help="Temperature to fix the orchestrator model at (e.g. o3 can only be run with 1.0)",
)
parser.add_argument("--orchestrator_keep_first_image", action="store_true", default=False, help="Whether keep the first image(first state) in the orchestrator agent")
# code-agent config
parser.add_argument("--coder_provider", type=str, default="openai")
parser.add_argument("--coder_model", type=str, default="gpt-4o")
parser.add_argument(
"--coder_url",
type=str,
default="",
help="The URL of the coder model API.",
)
parser.add_argument(
"--coder_api_key",
type=str,
default="",
help="The API key of the coder model.",
)
parser.add_argument(
"--coder_temperature",
type=float,
default=None,
help="Temperature to fix the coder model at (e.g. o3 can only be run with 1.0)",
)
parser.add_argument(
"--coder_budget",
type=int,
default=20,
help="Max inner loop steps of coder agent",
)
# reflection-memory agent config
parser.add_argument("--memoryer_provider", type=str, default="openai")
parser.add_argument("--memoryer_model", type=str, default="gpt-4o")
parser.add_argument(
"--memoryer_url",
type=str,
default="",
help="The URL of the memoryer model API.",
)
parser.add_argument(
"--memoryer_api_key",
type=str,
default="",
help="The API key of the memoryer model.",
)
parser.add_argument(
"--memoryer_temperature",
type=float,
default=None,
help="Temperature to fix the memoryer model at (e.g. o3 can only be run with 1.0)",
)
parser.add_argument(
"--memoryer_max_images",
type=int,
default=9,
help="Max images of memoryer model"
)
# search model config
parser.add_argument("--searcher_provider", type=str, default="openai")
parser.add_argument("--searcher_model", type=str, default="gpt-4o")
parser.add_argument(
"--searcher_url",
type=str,
default="",
help="The URL of the searcher model API.",
)
parser.add_argument(
"--searcher_api_key",
type=str,
default="",
help="The API key of the searcher model.",
)
parser.add_argument(
"--searcher_temperature",
type=float,
default=None,
help="Temperature to fix searcher model at (e.g. o3 can only be run with 1.0)",
)
parser.add_argument(
"--searcher_type",
type=str,
default="vlm",
help="Type of search agent, vlm/llm(all in search action), default is vlm",
)
parser.add_argument(
"--searcher_engine",
type=str,
default="google",
help="Type of search engine, google / duckduckgo",
)
parser.add_argument(
"--searcher_budget",
type=int,
default=20,
help="Max inner loop steps of search agent",
)
parser.add_argument(
"--searcher_screen_width",
type=int,
default=1920,
help="Search enviroment's width",
)
parser.add_argument(
"--searcher_screen_height",
type=int,
default=1080,
help="Search enviroment's height",
)
parser.add_argument(
"--searcher_path_to_vm",
type=str,
default="/nvme/yangbowen/vm_stroage/osworld/Ubuntu.qcow2",
help="Searcher Env VM's path (OSWorld'VM Path)",
)
# grounding model config, temperature is 0 with hardcode
parser.add_argument(
"--grounder_provider",
type=str,
required=True,
help="The provider for the grounder model",
)
parser.add_argument(
"--grounder_url", type=str, required=True, help="The URL of the grounder model"
)
parser.add_argument(
"--grounder_api_key",
type=str,
default="",
help="The API key of the grounder model.",
)
parser.add_argument(
"--grounder_model",
type=str,
required=True,
help="The model name for the grounder model",
)
parser.add_argument(
"--grounding_width",
type=int,
required=True,
help="Width of screenshot image after processor rescaling",
)
parser.add_argument(
"--grounding_height",
type=int,
required=True,
help="Height of screenshot image after processor rescaling",
)
parser.add_argument(
"--grounding_smart_resize",
action="store_true", default=False,
help="UI-TARS-1.5 and ScaleCUA needs smart resize, if this set, grounding_width and grounding_height is no use.",
)
parser.add_argument(
"--grounder_zoom_in_time",
type=int,
default=1,
help="Zoom-in times for grounder agent, aiming to enhance grounding ability.",
)
parser.add_argument(
"--exp_name",
type=str,
default="",
help="Experiment Name",
)
args = parser.parse_args()
return args
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
engine_params_for_orchestrator = {
"engine_type": args.orchestrator_provider,
"model": args.orchestrator_model,
"base_url": getattr(args, "orchestrator_url", ""),
"api_key": getattr(args, "orchestrator_api_key", ""),
"temperature": getattr(args, "orchestrator_temperature", None),
"tool_config": args.tool_config,
"keep_first_image": args.orchestrator_keep_first_image,
"agent_name": "orchestrator"
}
engine_params_for_grounder = {
"engine_type": args.grounder_provider,
"model": args.grounder_model,
"base_url": getattr(args, "grounder_url", ""),
"api_key": getattr(args, "grounder_api_key", ""),
"grounding_width": args.grounding_width,
"grounding_height": args.grounding_height,
"grounding_smart_resize": args.grounding_smart_resize,
"grounder_zoom_in_time": args.grounder_zoom_in_time,
"agent_name": "grounder"
}
engine_params_for_coder = {
"engine_type": args.coder_provider,
"model": args.coder_model,
"base_url": getattr(args, "coder_url", ""),
"api_key": getattr(args, "coder_api_key", ""),
"temperature": getattr(args, "coder_temperature", None),
"budget": args.coder_budget,
"agent_name": "coder"
}
engine_params_for_memoryer = {
"engine_type": args.memoryer_provider,
"model": args.memoryer_model,
"base_url": getattr(args, "memoryer_url", ""),
"api_key": getattr(args, "memoryer_api_key", ""),
"temperature": getattr(args, "memoryer_temperature", None),
"max_images": args.memoryer_max_images,
"agent_name": "memoryer"
}
engine_params_for_searcher = {
"engine_type": args.searcher_provider,
"model": args.searcher_model,
"base_url": getattr(args, "searcher_url", ""),
"api_key": getattr(args, "searcher_api_key", ""),
"temperature": getattr(args, "searcher_temperature", None),
"budget": args.searcher_budget,
"type": args.searcher_type,
"engine": args.searcher_engine,
"agent_name": "searcher"
}
# --- Initialize Worker Path ---
num_envs = args.num_envs
# only for waa
if args.benchmark == "waa":
logger.info(f"[WindowsAgentArena] Initializing storage for {num_envs} workers from golden image: {args.path_to_vm}")
for i in range(num_envs):
s_path, b_path = prepare_worker_vm_paths(args.path_to_vm, i)
initialize_worker_files(args.path_to_vm, b_path, s_path)
with Manager() as manager:
shared_scores = manager.list()
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
processes = []
for worker_id in range(num_envs):
p = Process(
target=run_env_tasks,
args=(
task_queue,
args,
shared_scores,
engine_params_for_orchestrator,
engine_params_for_grounder,
engine_params_for_coder,
engine_params_for_memoryer,
engine_params_for_searcher,
worker_id
),
name=f"EnvProcess-{worker_id+1}",
)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
try:
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(
task_queue,
args,
shared_scores,
engine_params_for_orchestrator,
engine_params_for_grounder,
engine_params_for_coder,
engine_params_for_memoryer,
engine_params_for_searcher,
idx
),
name=f"EnvProcess-Restart-{idx+1}",
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(
f"Restarted process {new_p.name} with PID {new_p.pid}"
)
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info(
"Main process received KeyboardInterrupt. Initiating graceful shutdown..."
)
raise
except Exception as e:
logger.error(
f"Unexpected error while waiting for processes: {e}", exc_info=True
)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
target_dir, total_file_json
):
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(target_dir, total_file_json: dict):
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
# list for all tasks
all_result = []
for domain, example_id_list in total_file_json.items():
for example_id in example_id_list:
example_path = os.path.join(target_dir, domain, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# read the recorded score for this example
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except Exception:
all_result.append(0.0)
else:
all_result.append(0.0)
else:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
if args.exp_name != "":
args.result_dir = os.path.join(
args.result_dir,
args.exp_name
)
else:
args.result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model
)
path_to_args = os.path.join(
args.result_dir,
"args.json"
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
logger.info(f"====================\nExperiment on {args.benchmark} is started\n====================")
test_file_list = get_unfinished(
target_dir=args.result_dir,
total_file_json=test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
target_dir=args.result_dir,
total_file_json=test_all_meta
)
test(
args,
test_file_list
)
logger.info(f"====================\nExperiment on {args.benchmark} is ended\n====================")
logger.info(f"====================\nExperiment {args.exp_name} is totally ended!\n====================")

View File

@ -57,13 +57,13 @@ def config() -> argparse.Namespace:
parser.add_argument("--model", type=str, default="qwen3-vl")
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--max_tokens", type=int, default=32768)
parser.add_argument("--stop_token", type=str, default=None)
parser.add_argument(
"--coord",
type=str,
choices=["absolute", "relative"],
default="absolute",
default="relative",
help="Coordinate system for agent outputs (absolute or relative)",
)
parser.add_argument(
@ -99,7 +99,7 @@ def config() -> argparse.Namespace:
"--provider_name",
type=str,
default="docker",
choices=["aws", "virtualbox", "vmware", "docker", "azure"],
choices=["aws", "virtualbox", "vmware", "docker", "azure", "aliyun"],
help="Provider name",
)
parser.add_argument(

540
run_multienv_seedagent.py Normal file
View File

@ -0,0 +1,540 @@
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List, Dict
from multiprocessing import Process, Manager, Queue  # Queue is referenced in the run_env_tasks annotation below
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.seed_agent import SeedAgent
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ #
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=5.0)
parser.add_argument("--max_steps", type=int, default=100)
# evaluation config
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config
parser.add_argument("--model", type=str, default="doubao-1-5-thinking-vision-pro-250428")
parser.add_argument("--model_type", type=str, default="doubao", choices=["doubao", "qwen25"])
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.7)
parser.add_argument("--max_tokens", type=int, default=4096)
parser.add_argument("--use_thinking", action="store_true", default=False)
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
parser.add_argument("--history_n", type=int, default=5, help="The max number of images in the history.")
parser.add_argument("--language", type=str, default="Chinese", help="Language for the agent.")
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO', help="Set the logging level")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
args = parser.parse_args()
return args
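# Example invocation (a sketch; model credentials are expected via environment
# variables or a .env file, and the values below are placeholders mapped to the
# argparse options defined above):
#
#   python run_multienv_seedagent.py \
#       --headless \
#       --model doubao-1-5-thinking-vision-pro-250428 \
#       --model_type doubao \
#       --use_thinking \
#       --num_envs 4 \
#       --provider_name aws \
#       --client_password <client-password> \
#       --result_dir ./results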
args = config() # Get command line arguments first
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
os.makedirs("logs", exist_ok=True)  # ensure the log directory exists before attaching file handlers
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
return all_tasks
def process_signal_handler(signum, frame, env_idx):
"""Signal handler for child processes to gracefully shut down their environments."""
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
# Get the active_environments from the caller's frame
local_vars = frame.f_locals
active_environments = local_vars.get('active_environments', [])
# Close environment in the current process context
for env in active_environments:
if env is not None:
try:
logger.info(f"Process {env_idx + 1} closing environment...")
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
try:
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
agent = SeedAgent(
model=args.model,
model_type=args.model_type,
max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,
max_trajectory_length=args.max_trajectory_length,
history_n=args.history_n,
use_thinking=args.use_thinking,
)
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
global is_terminating, active_environments, processes
# Avoid duplicate handling
if is_terminating:
return
is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
# Close all registered environments in the main process
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
# Send termination signal to all child processes first
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
# Allow a short time for processes to handle their own cleanup
time.sleep(1)
# Forcefully terminate any processes that didn't exit
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = []
for i in range(num_envs):
p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
try:
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# read the recorded score for this example
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except Exception:  # unreadable or malformed result.txt counts as 0.0
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
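
# Entry point: install SIGINT/SIGTERM handlers, persist the CLI args next to the results,
# filter out already-finished examples, log what is left, and hand the remainder to test().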
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Register signal handlers for graceful termination
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
try:
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
# Signal handler will take care of cleanup
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
# Also trigger cleanup for unhandled exceptions
signal_handler(signal.SIGTERM, None)
finally:
# Final cleanup in case any environments or processes remain
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info(f"Closing environment in final cleanup...")
env.close()
logger.info(f"Environment closed successfully in final cleanup")
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
# First try gentle termination
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Terminating process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
# Wait a moment for processes to terminate
time.sleep(1)
# Then force kill if needed
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Force killing process {p.name}...")
os.kill(p.pid, signal.SIGKILL)
logger.info(f"Process {p.name} force killed")
except Exception as e:
logger.error(f"Error force killing process: {e}")

run_os_symphony.sh Normal file

@@ -0,0 +1,58 @@
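# Launch script for the os_symphony multi-env evaluation; fill in the "xxx" placeholders
# (experiment name, AWS credentials, model endpoints and API keys) before running.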
EXP_NAME="xxx"
export AWS_SECRET_ACCESS_KEY="xxx"
export AWS_ACCESS_KEY_ID="xxx"
export AWS_REGION="us-east-1"
export AWS_SUBNET_ID="xxx"
export AWS_SECURITY_GROUP_ID="xxx"
# >> logs/${EXP_NAME}.log 2>&1
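# Run the evaluation on AWS: 7 parallel environments, up to 50 steps per task,
# results written under ./results.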
python run_multienv_os_symphony.py \
--provider_name "aws" \
--region "us-east-1" \
--client_password "osworld-public-evaluation" \
--headless \
--num_envs 7 \
--max_steps 50 \
--benchmark osworld \
--domain "all" \
--test_all_meta_path evaluation_examples/test_nogdrive.json \
--result_dir "results" \
--tool_config mm_agents/os_symphony/tool/all_tool_config.yaml \
--orchestrator_provider "openai" \
--orchestrator_model "gpt-5" \
--orchestrator_url "xxx" \
--orchestrator_api_key "xxx" \
--orchestrator_temperature 0.1 \
--orchestrator_keep_first_image \
--max_trajectory_length 8 \
--grounder_provider "vllm" \
--grounder_model "UI-TARS-1.5-7B" \
--grounder_api_key "none" \
--grounder_url "xxx" \
--grounding_smart_resize \
--grounding_width 1920 \
--grounding_height 1080 \
--coder_provider "openai" \
--coder_model "gpt-5" \
--coder_url "xxx" \
--coder_api_key "xxx" \
--coder_temperature 0.1 \
--coder_budget 20 \
--memoryer_provider "openai" \
--memoryer_model "gpt-5" \
--memoryer_url "xxx" \
--memoryer_api_key "xxx" \
--memoryer_temperature 0.1 \
--memoryer_max_images 8 \
--searcher_provider "openai" \
--searcher_model "gpt-5" \
--searcher_url "xxx" \
--searcher_api_key "xxx" \
--searcher_temperature 0.1 \
--searcher_type "vlm" \
--searcher_engine "google" \
--searcher_budget 20 \
--searcher_screen_width 1920 \
--searcher_screen_height 1080 \
--sleep_after_execution 3 \
--exp_name ${EXP_NAME} \
--enable_reflection >> logs/${EXP_NAME}.log 2>&1

setup.py

@@ -23,7 +23,7 @@ class InstallPlaywrightCommand(install):
setup(
name="desktop_env",
version="1.0.0",
version="1.0.1",
author="Tianbao Xie, Danyang Zhang, Jixuan Chen, Xiaochuan Li, Siheng Zhao, Ruisheng Cao, Toh Jing Hua, etc.",
author_email="tianbaoxiexxx@gmail.com",
description="The package provides a desktop environment for setting and evaluating desktop automation tasks.",
@@ -38,7 +38,7 @@ setup(
],
python_requires='>=3.10',
install_requires=[
"numpy~=1.24.4",
"numpy>=1.26,<3",
"Pillow~=11.0.0",
"fabric",
"gymnasium~=0.28.1",
@@ -53,7 +53,7 @@ setup(
"pyautogui~=0.9.54",
"psutil~=5.9.6",
"tqdm~=4.65.0",
"pandas~=2.2.3",
"pandas>=2.2,<2.3",
"flask~=3.0.0",
"requests-toolbelt~=1.0.0",
"ag2~=0.9.7",