Compare commits
13 Commits
feat/seeda...main
| Author | SHA1 | Date |
|---|---|---|
|  | 5ef8bdfa35 |  |
|  | 439e178a2e |  |
|  | 951e1928c8 |  |
|  | 02a35be067 |  |
|  | 662826f57e |  |
|  | 410ec63a89 |  |
|  | 031696e83c |  |
|  | f593f35b1c |  |
|  | ac31778ee3 |  |
|  | 60caa52fc4 |  |
|  | 41477a9c40 |  |
|  | 78433ecfcf |  |
|  | 9540454b0a |  |
@@ -0,0 +1,499 @@
from __future__ import annotations

import logging
import os
import time
import re
from typing import Callable, Any, Optional, Tuple
from typing import List, Dict, Union

import gymnasium as gym

from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import metrics, getters
from desktop_env.providers import create_vm_manager_and_provider

logger = logging.getLogger("desktopenv.env")

Metric = Callable[[Any, Any], float]
Getter = Callable[[gym.Env, Dict[str, Any]], Any]

MAX_RETRIES = 5  # Maximum retries for environment setup


def _fix_pyautogui_less_than_bug(command: str) -> str:
    """
    Fix the PyAutoGUI '<' character bug by converting it to hotkey("shift", ",") calls.

    This fixes the known PyAutoGUI issue where typing '<' produces '>' instead.
    References:
    - https://github.com/asweigart/pyautogui/issues/198
    - https://github.com/xlang-ai/OSWorld/issues/257

    Args:
        command (str): The original pyautogui command

    Returns:
        str: The fixed command with '<' characters handled properly
    """
    # Pattern to match press('<') or press('\u003c') calls
    press_pattern = r'pyautogui\.press\(["\'](?:<|\\u003c)["\']\)'

    # Handle press('<') calls
    def replace_press_less_than(match):
        return 'pyautogui.hotkey("shift", ",")'

    # First handle press('<') calls
    command = re.sub(press_pattern, replace_press_less_than, command)

    # Pattern to match typewrite calls with quoted strings
    typewrite_pattern = r'pyautogui\.typewrite\((["\'])(.*?)\1\)'

    # Then handle typewrite calls
    def process_typewrite_match(match):
        quote_char = match.group(1)
        content = match.group(2)

        # Preprocess: try to decode Unicode escapes like \u003c to an actual '<'.
        # This handles cases where '<' is represented as escaped Unicode.
        try:
            decoded_content = content.encode('utf-8').decode('unicode_escape')
            content = decoded_content
        except UnicodeDecodeError:
            # Graceful degradation: if decoding fails, proceed with the
            # original content to avoid breaking existing logic.
            pass

        # Check if the content contains '<'
        if '<' not in content:
            return match.group(0)

        # Split by '<' and rebuild
        parts = content.split('<')
        result_parts = []

        for i, part in enumerate(parts):
            if i == 0:
                # First part
                if part:
                    result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")
            else:
                # Add a hotkey for '<' and then typewrite for the rest
                result_parts.append('pyautogui.hotkey("shift", ",")')
                if part:
                    result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")

        return '; '.join(result_parts)

    command = re.sub(typewrite_pattern, process_typewrite_match, command)

    return command
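
# A minimal illustration of the rewrite this helper performs (hypothetical
# input, kept in a comment so it stays out of the module's runtime path):
#
#   _fix_pyautogui_less_than_bug("pyautogui.typewrite('a<b')")
#   # -> pyautogui.typewrite('a'); pyautogui.hotkey("shift", ","); pyautogui.typewrite('b')
#
# The '<' keystroke becomes an explicit Shift+',' chord, which avoids the
# swapped '<'/'>' output described in the issues linked above.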


class DesktopEnv(gym.Env):
    """
    DesktopEnv with an OpenAI Gym interface. It provides a desktop environment for setting up and evaluating desktop automation tasks.
    """

    def __init__(
            self,
            provider_name: str = "vmware",
            region: str = None,
            path_to_vm: str = None,
            snapshot_name: str = "init_state",
            action_space: str = "pyautogui",
            cache_dir: str = "cache",
            screen_size: Tuple[int, int] = (int(os.environ.get("SCREEN_WIDTH", 1920)), int(os.environ.get("SCREEN_HEIGHT", 1080))),
            headless: bool = False,
            require_a11y_tree: bool = True,
            require_terminal: bool = False,
            os_type: str = "Ubuntu",
            enable_proxy: bool = False,
            client_password: str = "",
    ):
        """
        Args:
            provider_name (str): virtualization provider name, defaults to "vmware"
            region (str): the region in which to allocate machines; used by cloud providers, defaults to "us-east-1"
            path_to_vm (str): path to the .vmx file
            snapshot_name (str): snapshot name to revert to, defaults to "init_state"
            action_space (str): "computer_13" | "pyautogui"
            cache_dir (str): cache directory for task-related files such as
                reference files for evaluation
            screen_size (Tuple[int, int]): screen size of the VM
            headless (bool): whether to run the VM in headless mode
            require_a11y_tree (bool): whether to require the accessibility tree
            require_terminal (bool): whether to require terminal output
            os_type (str): operating system type, defaults to "Ubuntu"
            enable_proxy (bool): whether to enable proxy support, defaults to False
        """
        # Initialize the VM manager and virtualization provider
        self.region = region
        self.provider_name = provider_name
        self.enable_proxy = enable_proxy  # Store the proxy enablement setting
        if client_password == "":
            if self.provider_name == "aws":
                self.client_password = "osworld-public-evaluation"
            else:
                self.client_password = "password"
        else:
            self.client_password = client_password

        self.screen_width = screen_size[0]
        self.screen_height = screen_size[1]

        # Defaults
        self.server_port = 5000
        self.chromium_port = 9222
        self.vnc_port = 8006
        self.vlc_port = 8080

        # Initialize with the default (no proxy) provider
        self.current_use_proxy = False
        self.manager, self.provider = None, None
        self.os_type = os_type
        self.path_to_vm = path_to_vm
        # Track whether the environment has been used (step/setup) to optimize snapshot reverts:
        # docker, aws, gcp, azure start unused, as the emulator boots from a clean state;
        # vmware, virtualbox start used, as the emulator boots from a dirty state
        if self.provider_name in {"docker", "aws", "gcp", "azure", "aliyun", "volcengine"}:
            self.is_environment_used = False
        elif self.provider_name in {"vmware", "virtualbox"}:
            self.is_environment_used = True
        else:
            raise ValueError(f"Invalid provider name: {self.provider_name}")

        self.snapshot_name = snapshot_name
        self.cache_dir_base: str = cache_dir
        self.headless = headless
        self.require_a11y_tree = require_a11y_tree
        self.require_terminal = require_terminal

        # mode: human or machine
        self.instruction = None
        assert action_space in ["computer_13", "pyautogui", "claude_computer_use", "autoglm_computer_use"]
        self.action_space = action_space  # todo: refactor it to the ActType

        # episodic state, such as counters, is updated or reset
        # when calling self.reset()
        self._traj_no: int = -1
        self._step_no: int = 0
        self.action_history: List[Dict[str, Any]] = []

    def start(self):
        # Initialize the emulator and controller
        if not self.manager and not self.provider:
            logger.info("Initializing...")
            self.manager, self.provider = create_vm_manager_and_provider(self.provider_name, self.region, use_proxy=False)

        if self.path_to_vm:
            self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(self.path_to_vm))) \
                if self.provider_name in {"vmware", "virtualbox"} else self.path_to_vm
        else:
            self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=self.region, screen_size=(self.screen_width, self.screen_height))

        self._start_emulator()

    def _start_emulator(self):
        try:
            # Power on the virtual machine
            self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)

            # Get the IP of the virtual machine and set up the controllers
            vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
            self.vm_ip = vm_ip_ports[0]
            # Get the ports from the virtual machine (for the Docker provider only)
            if len(vm_ip_ports) > 1:
                self.server_port = int(vm_ip_ports[1])
                self.chromium_port = int(vm_ip_ports[2])
                self.vnc_port = int(vm_ip_ports[3])
                self.vlc_port = int(vm_ip_ports[4])
            self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
            self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base, client_password=self.client_password, screen_width=self.screen_width, screen_height=self.screen_height)

        except Exception as e:
            try:
                self.provider.stop_emulator(self.path_to_vm)
            except Exception as stop_err:
                logger.warning(f"Cleanup after interrupt failed: {stop_err}")
            raise

    def _revert_to_snapshot(self):
        # Revert to a certain snapshot of the virtual machine, and refresh its path and IP,
        # since these can change when backed by cloud services
        path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
        if path_to_vm and not path_to_vm == self.path_to_vm:
            # path_to_vm has to be a new path

            self.manager.delete_vm(self.path_to_vm, self.region)
            self.manager.add_vm(path_to_vm, self.region)
            self.manager.occupy_vm(path_to_vm, os.getpid(), self.region)
            self.path_to_vm = path_to_vm

    def _save_state(self, snapshot_name=None):
        # Save the current virtual machine state under a certain snapshot name
        self.provider.save_state(self.path_to_vm, snapshot_name)

    def close(self):
        # Close (release) the virtual machine
        self.provider.stop_emulator(self.path_to_vm)

    def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:

        # Reset to a certain task in OSWorld
        logger.info("Resetting environment...")
        logger.info("Switching task...")
        logger.info("Setting counters...")
        self._traj_no += 1
        self._step_no = 0
        self.action_history.clear()

        for attempt in range(MAX_RETRIES):
            # Only revert to the snapshot if the environment has been used (step/setup).
            # This optimization is especially important for cloud providers like AWS,
            # where unnecessary snapshot operations are costly and time-consuming.

            if task_config is not None:
                # Only consider the task's proxy requirement if proxy is enabled at the system level
                task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
                if not self.enable_proxy and task_config.get("proxy", False):
                    logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")

                if task_use_proxy != self.current_use_proxy:
                    # keep this: get_info_from_website depends on it
                    self.current_use_proxy = task_use_proxy

            if self.is_environment_used:
                logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
                self._revert_to_snapshot()
                logger.info("Starting emulator...")
                self._start_emulator()
                logger.info("Emulator started.")
                # Reset the usage flag after reverting
                self.is_environment_used = False
            else:
                logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))

            if task_config is not None:
                if task_config.get("proxy", False) and self.enable_proxy:
                    # If the task uses a proxy and proxy is enabled, set up the proxy configuration
                    self.setup_controller._proxy_setup(self.client_password)
                self._set_task_info(task_config)
                self.setup_controller.reset_cache_dir(self.cache_dir)
                logger.info("Setting up environment...")
                success = self.setup_controller.setup(self.config, task_config.get("proxy", False) and self.enable_proxy)
                if success:
                    # Mark the environment as used when setup is successfully executed
                    if self.config:  # Only mark as used if there were actual setup operations
                        self.is_environment_used = True
                    break
                else:
                    logger.error(
                        "Environment setup failed, retrying (%d/%d)...",
                        attempt + 1,
                        MAX_RETRIES,
                    )
                    time.sleep(5)
            else:
                break

        logger.info("Environment setup complete.")

        observation = self._get_obs()
        return observation

    def _get_obs(self):
        # We provide the screenshot, accessibility tree (optional), terminal output (optional), and instruction.
        # This can be customized and scaled.
        return {
            "screenshot": self.controller.get_screenshot(),
            "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
            "terminal": self.controller.get_terminal_output() if self.require_terminal else None,
            "instruction": self.instruction
        }

    @property
    def vm_platform(self):
        return self.controller.get_vm_platform()

    @property
    def vm_screen_size(self):
        return self.controller.get_vm_screen_size()

    def _set_task_info(self, task_config: Dict[str, Any]):
        """Set task info (proxy logic is handled in the reset method)"""
        self.task_id: str = task_config["id"]
        self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
        os.makedirs(self.cache_dir, exist_ok=True)
        self.instruction = task_config["instruction"]
        self.config = task_config["config"] if "config" in task_config else []

        self._set_evaluator_info(task_config)

    def _set_evaluator_info(self, task_config: Dict[str, Any]):
        """Set evaluator information from the task config"""
        if "evaluator" not in task_config:
            return
        # evaluator dict:
        # func -> metric function string, or list of metric function strings
        # conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
        # result -> result getter config, or list of result getter configs
        # expected (optional) -> expected getter config, or list of expected getter configs
        # options (optional) -> metric options, or list of metric options
        # If func is a list of strings, then result, expected (if present), and options (if present)
        # should also be lists of the same length. Even if one of the metrics does not need an
        # expected or options field, it should be included in the list as None.
        self.evaluator = task_config["evaluator"]
        self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] \
            if isinstance(self.evaluator["func"], list) \
            else getattr(metrics, self.evaluator["func"])
        self.metric_conj: str = self.evaluator.get("conj", "and")  # take the conjunction of multiple metrics
        if "result" in self.evaluator and len(self.evaluator["result"]) > 0:
            self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
                                          self.evaluator["result"]] \
                if isinstance(self.evaluator["result"], list) \
                else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
        else:
            self.result_getter = [None] * len(self.metric) \
                if isinstance(self.metric, list) \
                else None

        if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0:
            self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
                                            self.evaluator["expected"]] \
                if isinstance(self.evaluator["expected"], list) \
                else getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
        else:
            self.expected_getter = [None] * len(self.metric) \
                if isinstance(self.metric, list) \
                else None
        self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in
                                                                            self.evaluator["options"]] \
            if isinstance(self.evaluator.get("options", {}), list) \
            else self.evaluator["options"] \
            if "options" in self.evaluator \
            else [{}] * len(self.metric) \
            if isinstance(self.metric, list) \
            else {}

        assert (not isinstance(self.evaluator["func"], list)
                or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(
                    self.metric_options)))
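
    # A hypothetical "evaluator" block matching the schema documented above
    # (the metric and getter names here are illustrative, not taken from this repo):
    #
    #   "evaluator": {
    #       "func": ["check_text", "compare_files"],
    #       "conj": "and",
    #       "result": [{"type": "vm_command_line"}, {"type": "vm_file"}],
    #       "expected": [{"type": "rule"}, None],
    #       "options": [{}, None]
    #   }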

    def step(self, action, pause=2):
        self._step_no += 1
        self.action_history.append(action)

        # Mark the environment as used when step is called
        self.is_environment_used = True

        reward = 0  # todo: Define reward calculation for each example
        done = False  # todo: Define episode termination condition for each example
        info = {}
        logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
        # Handle the special actions
        if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
            if action == 'WAIT' or (type(action) == dict and action.get('action_type') == 'WAIT'):
                time.sleep(pause)
            elif action == 'FAIL' or (type(action) == dict and action.get('action_type') == 'FAIL'):
                done = True
                info = {"fail": True}
            elif action == 'DONE' or (type(action) == dict and action.get('action_type') == 'DONE'):
                done = True
                info = {"done": True}

        if self.action_space == "computer_13":
            # the set of all possible actions defined in the action representation
            self.controller.execute_action(action)
        elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
            if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']):
                self.controller.execute_action(action)
            else:
                # the set of all possible python commands inside `pyautogui`
                if type(action) == str:
                    # Fix the PyAutoGUI '<' character bug before execution
                    fixed_command = _fix_pyautogui_less_than_bug(action)
                    self.controller.execute_python_command(fixed_command)
                elif type(action) == dict:
                    # Fix the PyAutoGUI '<' character bug before execution
                    fixed_command = _fix_pyautogui_less_than_bug(action['command'])
                    self.controller.execute_python_command(fixed_command)

        time.sleep(pause)
        observation = self._get_obs()

        return observation, reward, done, info

    def evaluate(self):
        """
        Evaluate whether the task is successfully completed.
        """

        postconfig = self.evaluator.get("postconfig", [])
        self.setup_controller.setup(postconfig, self.enable_proxy)
        # Mark the environment as used if there were postconfig setup operations
        if postconfig:
            self.is_environment_used = True

        if self.evaluator['func'] == "infeasible":
            if len(self.action_history) > 0:
                last_action = self.action_history[-1]
                if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
                    return 1
            return 0
        else:
            if len(self.action_history) > 0:
                last_action = self.action_history[-1]
                if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
                    return 0

        if type(self.metric) == list:
            # Multiple metrics to evaluate whether the task is successfully completed
            results = []
            assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same"
            if "expected" in self.evaluator:
                assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same"
            for idx, metric in enumerate(self.metric):
                try:
                    config = self.evaluator["result"][idx]
                    result_state = self.result_getter[idx](self, config)
                except FileNotFoundError:
                    logger.error("File not found!")
                    if self.metric_conj == 'and':
                        return 0

                if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
                    expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx])
                    metric: int = metric(result_state, expected_state, **self.metric_options[idx])
                else:
                    metric: int = metric(result_state, **self.metric_options[idx])

                if self.metric_conj == 'and' and float(metric) == 0.0:
                    return 0
                elif self.metric_conj == 'or' and float(metric) == 1.0:
                    return 1
                else:
                    results.append(metric)

            return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
        else:
            # A single metric to evaluate whether the task is successfully completed
            try:
                result_state = self.result_getter(self, self.evaluator["result"])
            except FileNotFoundError:
                logger.error("File not found!")
                return 0

            if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
                expected_state = self.expected_getter(self, self.evaluator["expected"])
                metric: float = self.metric(result_state, expected_state, **self.metric_options)
            else:
                metric: float = self.metric(result_state, **self.metric_options)

            return metric

    def render(self, mode='rgb_array'):
        if mode == 'rgb_array':
            return self.controller.get_screenshot()
        else:
            raise ValueError('Unsupported render mode: {}'.format(mode))
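
# A minimal usage sketch, assuming a local VMware image and a task config dict
# loaded from a task JSON (both are placeholders, not defined in this module):
#
#   env = DesktopEnv(provider_name="vmware", path_to_vm="/path/to/machine.vmx")
#   env.start()
#   obs = env.reset(task_config=task_config)
#   obs, reward, done, info = env.step("pyautogui.click(100, 200)")
#   score = env.evaluate()
#   env.close()
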
@@ -461,3 +461,185 @@ def run_single_example_uipath(agent, env, example, max_steps, instruction, args,
    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
        f.write(f"{result}\n")
    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))


from mm_agents.os_symphony.utils.common_utils import draw_coordinates
from mm_agents.os_symphony.utils.process_context import set_current_result_dir


logger = logging.getLogger("desktopenv.experiment")

def run_single_example_os_symphony(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
    set_current_result_dir(example_result_dir)

    agent.reset(result_dir=example_result_dir)
    env.reset(task_config=example)
    time.sleep(30)  # Wait for the environment to be ready
    obs = env._get_obs()  # Get the initial observation
    done = False
    step_idx = 0
    # env.controller.start_recording()
    start_time = time.time()

    while not done and step_idx < max_steps:
        response, actions = agent.predict(
            instruction,
            obs,
            step_idx == max_steps - 1
        )
        for action in actions:
            # Save screenshot and trajectory information
            if "reflection" in response and response["reflection"].get("is_milestone"):
                img_name = f"step_{step_idx + 1}_milestone.png"
            else:
                img_name = f"step_{step_idx + 1}.png"

            with open(os.path.join(example_result_dir, img_name),
                      "wb") as _f:
                _f.write(obs['screenshot'])
            if "coordinates" in response and response["coordinates"]:
                draw_coordinates(
                    image_bytes=obs['screenshot'],
                    coordinates=response["coordinates"],
                    save_path=os.path.join(example_result_dir, img_name[:-4] + "_draw.png")
                )

            logger.info("Step %d: %s", step_idx + 1, action)
            obs, reward, done, info = env.step(action, args.sleep_after_execution)
            logger.info("Done: %s", done)

            with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
                f.write(json.dumps({
                    "instruction": instruction,
                    "step_num": step_idx + 1,
                    "action": action,
                    "response": response,
                    "done": done,
                    "info": info,
                    "screenshot_file": img_name
                }))
                f.write("\n")
            with open(os.path.join(example_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
                json.dump({
                    "step_num": step_idx + 1,
                    "action": action,
                    "response": response,
                    "done": done,
                    "info": info,
                    "screenshot_file": img_name
                }, f, indent=4, ensure_ascii=False)
            if done:
                logger.info("The episode is done.")
                time.sleep(60)
                break
        step_idx += 1
    end_time = time.time()
    result = float(env.evaluate())
    logger.info("Result: %.2f", result)
    scores.append(result)
    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
        f.write(f"{result}\n")

    with open(os.path.join(example_result_dir, "time.txt"), "w", encoding="utf-8") as f:
        f.write(f"{end_time-start_time:.2f}\n")


def run_single_example_evocua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
    """
    Unified run function for EvoCUAAgent (supporting both S1 and S2 modes).
    """
    runtime_logger = setup_logger(example, example_result_dir)

    # Reset the environment
    env.reset(task_config=example)

    # Reset the agent.
    # Handle agent reset signature differences, if any.
    try:
        agent.reset(runtime_logger, vm_ip=env.vm_ip)
    except Exception:
        try:
            agent.reset(runtime_logger)
        except Exception:
            agent.reset()

    time.sleep(60)  # Wait for the environment to be ready
    obs = env._get_obs()  # Get the initial observation
    done = False
    step_idx = 0

    env.controller.start_recording()
    while not done and step_idx < max_steps:
        # EvoCUAAgent.predict has a unified signature: it returns (response, actions)
        # and handles both modes internally.
        predict_res = agent.predict(instruction, obs)

        # Check the return signature
        if len(predict_res) == 3:
            # Compatibility with the original S1 signature, if the agent was updated to match
            response, actions, info_dict = predict_res
        else:
            response, actions = predict_res
            info_dict = {}

        logger.info(f"Step {step_idx + 1} Actions: {actions}")

        # Break if there are no actions (fail-safe)
        if not actions or (len(actions) == 1 and (actions[0] == "" or "error" in actions[0].lower())):
            # Allow "FAIL" or "DONE" to pass through the execution loop if the agent outputs them as actions
            if not (actions and actions[0] in ["FAIL", "DONE"]):
                logger.warning("No valid actions returned. Breaking loop.")
                break

        for action in actions:
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
            logger.info("Executing action: %s", action)

            # Execute
            obs, reward, done, info = env.step(action, args.sleep_after_execution)

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)

            # Save the screenshot
            screenshot_file = f"step_{step_idx + 1}_{action_timestamp}.png"
            with open(os.path.join(example_result_dir, screenshot_file), "wb") as _f:
                _f.write(obs['screenshot'])

            # Log the trajectory
            log_entry = {
                "step_num": step_idx + 1,
                "action_timestamp": action_timestamp,
                "action": action,
                "response": response,
                "reward": reward,
                "done": done,
                "info": info,
                "screenshot_file": screenshot_file
            }
            # Add natural-language info if available (S1 style)
            if info_dict:
                log_entry["natural_language_action"] = info_dict.get("action")

            with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
                f.write(json.dumps(log_entry, ensure_ascii=False))
                f.write("\n")

            if done:
                logger.info("The episode is done.")
                break

        step_idx += 1

    time.sleep(20)  # Wait for the environment to settle
    result = env.evaluate()
    logger.info("Result: %.2f", result)
    scores.append(result)

    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
        f.write(f"{result}\n")

    log_task_completion(example, result, example_result_dir, args)

    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
@@ -0,0 +1,84 @@
import json
import logging
import os
import time
from wrapt_timeout_decorator import *
from mm_agents.os_symphony.utils.common_utils import draw_coordinates
from mm_agents.os_symphony.utils.process_context import set_current_result_dir


logger = logging.getLogger("desktopenv.experiment")

def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
    set_current_result_dir(example_result_dir)

    agent.reset(result_dir=example_result_dir)
    env.reset(task_config=example)
    time.sleep(30)  # Wait for the environment to be ready
    obs = env._get_obs()  # Get the initial observation
    done = False
    step_idx = 0
    # env.controller.start_recording()
    start_time = time.time()

    while not done and step_idx < max_steps:
        response, actions = agent.predict(
            instruction,
            obs,
            step_idx == max_steps - 1
        )
        for action in actions:
            # Save screenshot and trajectory information
            if "reflection" in response and response["reflection"].get("is_milestone"):
                img_name = f"step_{step_idx + 1}_milestone.png"
            else:
                img_name = f"step_{step_idx + 1}.png"

            with open(os.path.join(example_result_dir, img_name),
                      "wb") as _f:
                _f.write(obs['screenshot'])
            if "coordinates" in response and response["coordinates"]:
                draw_coordinates(
                    image_bytes=obs['screenshot'],
                    coordinates=response["coordinates"],
                    save_path=os.path.join(example_result_dir, img_name[:-4] + "_draw.png")
                )

            logger.info("Step %d: %s", step_idx + 1, action)
            obs, reward, done, info = env.step(action, args.sleep_after_execution)
            logger.info("Done: %s", done)

            with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
                f.write(json.dumps({
                    "instruction": instruction,
                    "step_num": step_idx + 1,
                    "action": action,
                    "response": response,
                    "done": done,
                    "info": info,
                    "screenshot_file": img_name
                }))
                f.write("\n")
            with open(os.path.join(example_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
                json.dump({
                    "step_num": step_idx + 1,
                    "action": action,
                    "response": response,
                    "done": done,
                    "info": info,
                    "screenshot_file": img_name
                }, f, indent=4, ensure_ascii=False)
            if done:
                logger.info("The episode is done.")
                time.sleep(60)
                break
        step_idx += 1
    end_time = time.time()
    result = float(env.evaluate())
    logger.info("Result: %.2f", result)
    scores.append(result)
    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
        f.write(f"{result}\n")

    with open(os.path.join(example_result_dir, "time.txt"), "w", encoding="utf-8") as f:
        f.write(f"{end_time-start_time:.2f}\n")
@@ -1134,10 +1134,12 @@ class PromptAgent:
 
         return actions
 
-    def reset(self, _logger=None):
+    def reset(self, _logger=None, vm_ip=None, **kwargs):
         global logger
         logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")
 
+        self.vm_ip = vm_ip
+
         self.thoughts = []
         self.actions = []
         self.observations = []
@@ -0,0 +1,653 @@
import os
import re
import json
import logging
import backoff
import openai
from typing import Dict, List, Tuple, Optional

from io import BytesIO
from PIL import Image

from mm_agents.evocua.utils import (
    process_image,
    encode_image,
    rewrite_pyautogui_text_inputs,
    project_coordinate_to_absolute_scale,
    log_messages
)

from mm_agents.evocua.prompts import (
    S1_SYSTEM_PROMPT,
    S1_INSTRUTION_TEMPLATE,
    S1_STEP_TEMPLATE,
    S1_ACTION_HISTORY_TEMPLATE,
    S2_ACTION_DESCRIPTION,
    S2_DESCRIPTION_PROMPT_TEMPLATE,
    S2_SYSTEM_PROMPT,
    build_s2_tools_def
)

logger = logging.getLogger("desktopenv.evocua")

class EvoCUAAgent:
    """
    EvoCUA: a native GUI agent model for desktop automation.
    """

    def __init__(
        self,
        model: str = "EvoCUA-S2",
        max_tokens: int = 32768,
        top_p: float = 0.9,
        temperature: float = 0.0,
        action_space: str = "pyautogui",
        observation_type: str = "screenshot",
        max_steps: int = 50,
        prompt_style: str = "S2",  # "S1" or "S2"
        max_history_turns: int = 4,
        screen_size: Tuple[int, int] = (1920, 1080),
        coordinate_type: str = "relative",
        password: str = "osworld-public-evaluation",
        resize_factor: int = 32,
        **kwargs
    ):
        self.model = model
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_steps = max_steps

        self.prompt_style = prompt_style
        assert self.prompt_style in ["S1", "S2"], f"Invalid prompt_style: {self.prompt_style}"

        self.max_history_turns = max_history_turns

        self.screen_size = screen_size
        self.coordinate_type = coordinate_type
        self.password = password
        self.resize_factor = resize_factor

        # Action space assertions
        assert self.action_space == "pyautogui", f"Invalid action space: {self.action_space}"
        assert self.observation_type == "screenshot", f"Invalid observation type: {self.observation_type}"

        # State
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.responses = []
        self.screenshots = []  # Stores encoded strings
        self.cots = []  # For S1-style history

    def reset(self, _logger=None, vm_ip=None):
        global logger
        if _logger:
            logger = _logger

        self.thoughts = []
        self.actions = []
        self.observations = []
        self.responses = []
        self.screenshots = []
        self.cots = []

    def predict(self, instruction: str, obs: Dict) -> Tuple[str, List[str]]:
        """
        Main prediction loop.
        """

        logger.info(f"========================== {self.model} ===================================")
        logger.info(f"Instruction: \n{instruction}")

        screenshot_bytes = obs["screenshot"]

        try:
            original_img = Image.open(BytesIO(screenshot_bytes))
            original_width, original_height = original_img.size
        except Exception as e:
            logger.warning(f"Failed to read screenshot size, falling back to screen_size: {e}")
            original_width, original_height = self.screen_size

        if self.prompt_style == "S1":
            raw_b64 = encode_image(screenshot_bytes)
            self.screenshots.append(raw_b64)
            return self._predict_s1(instruction, obs, raw_b64)
        else:
            processed_b64, p_width, p_height = process_image(screenshot_bytes, factor=self.resize_factor)
            self.screenshots.append(processed_b64)
            return self._predict_s2(
                instruction,
                obs,
                processed_b64,
                p_width,
                p_height,
                original_width,
                original_height,
            )


    def _predict_s2(self, instruction, obs, processed_b64, p_width, p_height, original_width, original_height):
        current_step = len(self.actions)
        current_history_n = self.max_history_turns

        response = None

        if self.coordinate_type == "absolute":
            resolution_info = f"* The screen's resolution is {p_width}x{p_height}."
        else:
            resolution_info = "* The screen's resolution is 1000x1000."

        description_prompt = S2_DESCRIPTION_PROMPT_TEMPLATE.format(resolution_info=resolution_info)

        tools_def = build_s2_tools_def(description_prompt)

        system_prompt = S2_SYSTEM_PROMPT.format(tools_xml=json.dumps(tools_def))

        # Retry loop for context length
        while True:
            messages = self._build_s2_messages(
                instruction,
                processed_b64,
                current_step,
                current_history_n,
                system_prompt
            )

            try:
                response = self.call_llm({
                    "model": self.model,
                    "messages": messages,
                    "max_tokens": self.max_tokens,
                    "top_p": self.top_p,
                    "temperature": self.temperature,
                })
                break
            except Exception as e:
                # Handle "context too large" errors by shrinking the history window
                if self._should_giveup_on_context_error(e) and current_history_n > 0:
                    current_history_n -= 1
                    logger.warning(f"Context too large, retrying with history_n={current_history_n}")
                else:
                    logger.error(f"Error in predict: {e}")
                    break

        self.responses.append(response)

        low_level_instruction, pyautogui_code = self._parse_response_s2(
            response, p_width, p_height, original_width, original_height
        )

        # Newly added: force termination once the step limit is reached
        current_step = len(self.actions) + 1
        first_action = pyautogui_code[0] if pyautogui_code else ""
        if current_step >= self.max_steps and str(first_action).upper() not in ("DONE", "FAIL"):
            logger.warning(f"Reached maximum steps {self.max_steps}. Forcing termination with FAIL.")
            low_level_instruction = "Fail the task because reaching the maximum step limit."
            pyautogui_code = ["FAIL"]

        logger.info(f"Low level instruction: {low_level_instruction}")
        logger.info(f"Pyautogui code: {pyautogui_code}")

        self.actions.append(low_level_instruction)
        return response, pyautogui_code

    def _build_s2_messages(self, instruction, current_img, step, history_n, system_prompt):
        messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]

        previous_actions = []
        history_start_idx = max(0, step - history_n)
        for i in range(history_start_idx):
            if i < len(self.actions):
                previous_actions.append(f"Step {i+1}: {self.actions[i]}")
        previous_actions_str = "\n".join(previous_actions) if previous_actions else "None"

        # Add history
        history_len = min(history_n, len(self.responses))
        if history_len > 0:
            hist_responses = self.responses[-history_len:]
            hist_imgs = self.screenshots[-history_len-1:-1]

            for i in range(history_len):
                if i < len(hist_imgs):
                    screenshot_b64 = hist_imgs[i]
                    if i == 0:
                        # First history item: inject the instruction + previous-actions context
                        img_url = f"data:image/png;base64,{screenshot_b64}"
                        instruction_prompt = f"""
Please generate the next move according to the UI screenshot, instruction and previous actions.

Instruction: {instruction}

Previous actions:
{previous_actions_str}"""
                        messages.append({
                            "role": "user",
                            "content": [
                                {"type": "image_url", "image_url": {"url": img_url}},
                                {"type": "text", "text": instruction_prompt}
                            ]
                        })
                    else:
                        img_url = f"data:image/png;base64,{screenshot_b64}"
                        messages.append({
                            "role": "user",
                            "content": [
                                {"type": "image_url", "image_url": {"url": img_url}},
                            ]
                        })

                messages.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": hist_responses[i]}]
                })

        # Current turn.
        # We re-use the previous_actions_str logic for the case where history_len == 0.

        if history_len == 0:
            # First-turn logic: include the instruction + previous actions
            instruction_prompt = f"""
Please generate the next move according to the UI screenshot, instruction and previous actions.

Instruction: {instruction}

Previous actions:
{previous_actions_str}"""
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{current_img}"}},
                    {"type": "text", "text": instruction_prompt}
                ]
            })
        else:
            # Subsequent-turns logic (context is already in the first history message): image only
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{current_img}"}}
                ]
            })

        return messages
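
    # Shape of the assembled conversation for history_n = 2 after three steps
    # (sketch only; image payloads abbreviated):
    #
    #   [system]    S2 system prompt with the tools definition
    #   [user]      screenshot t-2 + instruction / previous-actions text
    #   [assistant] response t-2
    #   [user]      screenshot t-1 (image only)
    #   [assistant] response t-1
    #   [user]      current screenshot (image only)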

    def _parse_response_s2(
        self,
        response: str,
        processed_width: int = None,
        processed_height: int = None,
        original_width: Optional[int] = None,
        original_height: Optional[int] = None,
    ) -> Tuple[str, List[str]]:
        """
        Parse the LLM response and convert it to a low-level action and pyautogui code.
        """
        # Prefer the real screenshot resolution (passed from predict); fall back to the configured screen_size.
        if not (original_width and original_height):
            original_width, original_height = self.screen_size
        low_level_instruction = ""
        pyautogui_code: List[str] = []

        if response is None or not response.strip():
            return low_level_instruction, pyautogui_code

        def adjust_coordinates(x: float, y: float) -> Tuple[int, int]:
            if not (original_width and original_height):
                return int(x), int(y)
            if self.coordinate_type == "absolute":
                # Scale from processed pixels to the original resolution
                if processed_width and processed_height:
                    x_scale = original_width / processed_width
                    y_scale = original_height / processed_height
                    return int(x * x_scale), int(y * y_scale)
                return int(x), int(y)
            # Relative: scale from the 0..999 grid
            x_scale = original_width / 999
            y_scale = original_height / 999
            return int(x * x_scale), int(y * y_scale)

        def process_tool_call(json_str: str) -> None:
            try:
                tool_call = json.loads(json_str)
                if tool_call.get("name") == "computer_use":
                    args = tool_call["arguments"]
                    action = args["action"]

                    def _clean_keys(raw_keys):
                        keys = raw_keys if isinstance(raw_keys, list) else [raw_keys]
                        cleaned_keys = []
                        for key in keys:
                            if isinstance(key, str):
                                if key.startswith("keys=["):
                                    key = key[6:]
                                if key.endswith("]"):
                                    key = key[:-1]
                                if key.startswith("['") or key.startswith('["'):
                                    key = key[2:] if len(key) > 2 else key
                                if key.endswith("']") or key.endswith('"]'):
                                    key = key[:-2] if len(key) > 2 else key
                                key = key.strip()
                                cleaned_keys.append(key)
                            else:
                                cleaned_keys.append(key)
                        return cleaned_keys

                    if action == "left_click" or action == "click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.click({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.click()")

                    elif action == "right_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.rightClick({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.rightClick()")

                    elif action == "middle_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.middleClick({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.middleClick()")

                    elif action == "double_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.doubleClick({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.doubleClick()")

                    elif action == "triple_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.tripleClick({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.tripleClick()")

                    elif action == "type":
                        text = args.get("text", "")

                        try:
                            text = text.encode('latin-1', 'backslashreplace').decode('unicode_escape')
                        except Exception as e:
                            logger.error(f"Failed to unescape text: {e}")

                        logger.info(f"Pyautogui code[before rewrite]: {text}")

                        result = ""
                        for char in text:
                            if char == '\n':
                                result += "pyautogui.press('enter')\n"
                            elif char == "'":
                                result += 'pyautogui.press("\'")\n'
                            elif char == '\\':
                                result += "pyautogui.press('\\\\')\n"
                            elif char == '"':
                                result += "pyautogui.press('\"')\n"
                            else:
                                result += f"pyautogui.press('{char}')\n"

                        pyautogui_code.append(result)
                        logger.info(f"Pyautogui code[after rewrite]: {pyautogui_code}")

                    elif action == "key":
                        keys = _clean_keys(args.get("keys", []))

                        keys_str = ", ".join([f"'{key}'" for key in keys])
                        if len(keys) > 1:
                            pyautogui_code.append(f"pyautogui.hotkey({keys_str})")
                        else:
                            pyautogui_code.append(f"pyautogui.press({keys_str})")

                    elif action == "key_down":
                        keys = _clean_keys(args.get("keys", []))
                        for k in keys:
                            pyautogui_code.append(f"pyautogui.keyDown('{k}')")

                    elif action == "key_up":
                        keys = _clean_keys(args.get("keys", []))
                        for k in reversed(keys):
                            pyautogui_code.append(f"pyautogui.keyUp('{k}')")

                    elif action == "scroll":
                        pixels = args.get("pixels", 0)
                        pyautogui_code.append(f"pyautogui.scroll({pixels})")

                    elif action == "wait":
                        pyautogui_code.append("WAIT")

                    elif action == "terminate":
                        # Termination should respect status:
                        # - success -> DONE
                        # - failure -> FAIL
                        # Backward compatible: a missing status defaults to success.
                        status = args.get("status", "success")
                        if str(status).lower() == "failure":
                            pyautogui_code.append("FAIL")
                        else:
                            pyautogui_code.append("DONE")

                    elif action == "mouse_move":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(
                                f"pyautogui.moveTo({adj_x}, {adj_y})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.moveTo(0, 0)")

                    elif action == "left_click_drag":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            duration = args.get("duration", 0.5)
                            pyautogui_code.append(
                                f"pyautogui.dragTo({adj_x}, {adj_y}, duration={duration})"
                            )
                        else:
                            pyautogui_code.append("pyautogui.dragTo(0, 0)")
            except (json.JSONDecodeError, KeyError) as e:
                logger.error(f"Failed to parse tool call: {e}")

        lines = response.split("\n")
        inside_tool_call = False
        current_tool_call: List[str] = []

        for line in lines:
            line = line.strip()
            if not line:
                continue

            if line.lower().startswith("action:"):
                if not low_level_instruction:
                    low_level_instruction = line.split("Action:")[-1].strip()
                continue

            if line.startswith("<tool_call>"):
                inside_tool_call = True
                continue
            elif line.startswith("</tool_call>"):
                if current_tool_call:
                    process_tool_call("\n".join(current_tool_call))
                    current_tool_call = []
                inside_tool_call = False
                continue

            if inside_tool_call:
                current_tool_call.append(line)
                continue

            if line.startswith("{") and line.endswith("}"):
                try:
                    json_obj = json.loads(line)
                    if "name" in json_obj and "arguments" in json_obj:
                        process_tool_call(line)
                except json.JSONDecodeError:
                    pass

        if current_tool_call:
            process_tool_call("\n".join(current_tool_call))

        if not low_level_instruction and len(pyautogui_code) > 0:
            first_action = pyautogui_code[0]
            if "." in first_action:
                action_type = first_action.split(".", 1)[1].split("(", 1)[0]
            else:
                action_type = first_action.lower()
            low_level_instruction = f"Performing {action_type} action"

        return low_level_instruction, pyautogui_code
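
    # A worked example of the parser above, using a hypothetical model response
    # (coordinate_type == "relative", 1920x1080 screen):
    #
    #   response = ('Action: Click the OK button\n'
    #               '<tool_call>\n'
    #               '{"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [500, 500]}}\n'
    #               '</tool_call>')
    #   _parse_response_s2(response)
    #   # -> ("Click the OK button", ["pyautogui.click(960, 540)"])
    #
    # The grid point (500, 500) on the 0..999 scale is projected to absolute
    # pixels via original_width / 999 and original_height / 999.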


    def _predict_s1(self, instruction, obs, processed_b64):
        messages = [{"role": "system", "content": S1_SYSTEM_PROMPT.format(password=self.password)}]

        # Reconstruct the history for S1 mode
        history_step_texts = []

        for i in range(len(self.actions)):
            cot = self.cots[i] if i < len(self.cots) else {}

            # Step content string
            step_content = S1_STEP_TEMPLATE.format(step_num=i+1) + S1_ACTION_HISTORY_TEMPLATE.format(action=cot.get('action', ''))

            if i > len(self.actions) - self.max_history_turns:
                # Recent history: add user (image) and assistant (text) messages
                if i < len(self.screenshots) - 1:  # A screenshot exists for this step
                    img = self.screenshots[i]
                    messages.append({
                        "role": "user",
                        "content": [
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}
                        ]
                    })
                messages.append({"role": "assistant", "content": step_content})
            else:
                # Old history: collect text
                history_step_texts.append(step_content)
                # If this is the last step before the recent window, flush the collected texts
                if i == len(self.actions) - self.max_history_turns:
                    messages.append({
                        "role": "assistant",
                        "content": "\n".join(history_step_texts)
                    })

        # Current turn
        messages.append({
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{processed_b64}"}},
                {"type": "text", "text": S1_INSTRUTION_TEMPLATE.format(instruction=instruction)}
            ]
        })

        response = self.call_llm({
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens
        })

        low_level, codes, cot_data = self._parse_response_s1(response)

        self.observations.append(obs)
        self.cots.append(cot_data)
        self.actions.append(low_level)
        self.responses.append(response)

        return response, codes


    def _parse_response_s1(self, response):
        sections = {}
        # Simple regex parsing
        for key, pattern in [
            ('observation', r'#{1,2}\s*Observation\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)'),
            ('thought', r'#{1,2}\s*Thought\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)'),
            ('action', r'#{1,2}\s*Action\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)')
        ]:
            m = re.search(pattern, response, re.DOTALL | re.MULTILINE)
            if m:
                sections[key] = m.group(1).strip()

        code_blocks = re.findall(r'```(?:code|python)?\s*(.*?)\s*```', response, re.DOTALL | re.IGNORECASE)
        code = code_blocks[-1].strip() if code_blocks else "FAIL"

        sections['code'] = code

        # Post-process the code
        if "computer.terminate" in code:
            final_code = ["DONE"] if "success" in code.lower() else ["FAIL"]
        elif "computer.wait" in code:
            final_code = ["WAIT"]
        else:
            # Project coordinates
            code = project_coordinate_to_absolute_scale(
                code,
                self.screen_size[0],
                self.screen_size[1],
                self.coordinate_type,
                self.resize_factor
            )
            logger.info(f"[rewrite before]: {code}")
            final_code = [rewrite_pyautogui_text_inputs(code)]
            logger.info(f"[rewrite after]: {final_code}")

        return sections.get('action', 'Acting'), final_code, sections


    @staticmethod
    def _should_giveup_on_context_error(e):
        """For context-length related errors, give up retrying immediately and let the outer layer handle them."""
        error_str = str(e)
        return "Too Large" in error_str or "context_length_exceeded" in error_str or "413" in error_str

    @backoff.on_exception(backoff.constant, Exception, interval=30, max_tries=10, giveup=_should_giveup_on_context_error.__func__)
    def call_llm(self, payload):
        """Unified OpenAI-compatible API call"""
        # Read the endpoint from environment variables
        base_url = os.environ.get("OPENAI_BASE_URL", "url-xxx")
        api_key = os.environ.get("OPENAI_API_KEY", "sk-xxx")

        client = openai.OpenAI(base_url=base_url, api_key=api_key)

        messages = payload["messages"]
        log_messages(messages, "LLM Request")

        params = {
            "model": payload["model"],
            "messages": messages,
            "max_tokens": payload["max_tokens"],
            "temperature": self.temperature,
            "top_p": self.top_p
        }

        try:
            resp = client.chat.completions.create(**params)
            content = resp.choices[0].message.content
            logger.info(f"LLM Response:\n{content}")
            return content
        except Exception as e:
            logger.error(f"LLM Call failed: {e}")
            raise e
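
# A minimal usage sketch, assuming OPENAI_BASE_URL / OPENAI_API_KEY point at an
# OpenAI-compatible endpoint that serves the model (screenshot_bytes is a
# placeholder for a raw PNG screenshot):
#
#   agent = EvoCUAAgent(model="EvoCUA-S2", prompt_style="S2",
#                       screen_size=(1920, 1080), coordinate_type="relative")
#   agent.reset()
#   response, actions = agent.predict("Open the file manager", {"screenshot": screenshot_bytes})
#   # actions is a list of pyautogui snippets, or one of "WAIT" / "DONE" / "FAIL"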
@@ -0,0 +1,148 @@
S1_SYSTEM_PROMPT = """You are a GUI agent. You are given a task, a screenshot of the screen and your previous interactions with the computer. You need to perform a series of actions to complete the task. The password of the computer is "{password}", use it when you need sudo rights. You need to **wait** explicitly for installation, website loading, or running commands to finish. Don't terminate the task unless you are sure the task is finished. If you find that you can't finish the task, or the task is not finished exactly as the instruction indicates (you have made progress but not finished the task completely), or the task is impossible to complete, you must report **failure**.

For each step, provide your response in this format:
# Step: {{step number}}
## Thought:
{{thought}}
## Action:
{{action}}
## Code:
{{code}}

For the Thought section, you should include the following parts:
- Reflection on the task when there is a previous action:
  - Consider the correctness of the previous action and its outcomes
  - If the previous action was correct, describe the change in the state of the computer and the reason
  - If the previous action was incorrect, reflect on what went wrong and why
- Step by Step Progress Assessment:
  - Add necessary information according to the history screenshots, former actions and current screenshot.
  - Analyze what parts of the task have already been completed and how they contribute to the overall goal.
  - Make a plan on how to complete the task based on the history and current screenshot.
- Next Action Prediction:
  - Propose the most probable next action and state the reason
  - For Text Input Actions:
    - Note the current cursor position
    - Consolidate repetitive actions (specify count for multiple keypresses)
    - Describe the expected final text outcome
- Use first-person perspective in reasoning

For the action section, you should provide clear, concise, and actionable instructions in one sentence.
- If the action involves interacting with a specific target:
  - Describe the target explicitly (if multiple elements share that name, you should distinguish the target) without using coordinates
  - Specify element names when possible (use original language if non-English)
  - Describe features (shape, color, position) if the name is unavailable
- If the action involves keyboard actions like 'press', 'write', 'hotkey':
  - Consolidate repetitive keypresses with a count
  - Specify the expected text outcome for typing actions

For the code section, you should output the corresponding code for the action. The code should be either PyAutoGUI code or one of the following functions wrapped in the code block:
- {{"name": "computer.wait", "description": "Make the computer wait for 20 seconds for installation, running code, etc.", "parameters": {{"type": "object", "properties": {{}}, "required": []}}}}
- {{"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {{"type": "object", "properties": {{"status": {{"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}}, "answer": {{"type": "string", "description": "The answer of the task"}}}}, "required": ["status"]}}}}
Examples for the code section:
```python
pyautogui.click(x=123, y=456)
```
```code
computer.terminate(status="success")
```
```code
computer.terminate(status="success", answer='''text''')
```"""


# S1 prompt templates for generating trajectories
S1_STEP_TEMPLATE = "# Step {step_num}:\n"
S1_INSTRUTION_TEMPLATE = "# Task Instruction:\n{instruction}\n\nPlease generate the next move according to the screenshot, task instruction and previous steps (if provided).\n"

S1_ACTION_HISTORY_TEMPLATE = "## Action:\n{action}\n"


# S2 Prompts
S2_ACTION_DESCRIPTION = """
* `key`: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.
* `key_down`: Press and HOLD the specified key(s) down in order (no release). Use this for stateful holds like holding Shift while clicking.
* `key_up`: Release the specified key(s) in reverse order.
* `type`: Type a string of text on the keyboard.
* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen.
* `left_click`: Click the left mouse button at a specified (x, y) pixel coordinate on the screen.
* `left_click_drag`: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.
* `right_click`: Click the right mouse button at a specified (x, y) pixel coordinate on the screen.
* `middle_click`: Click the middle mouse button at a specified (x, y) pixel coordinate on the screen.
* `double_click`: Double-click the left mouse button at a specified (x, y) pixel coordinate on the screen.
* `triple_click`: Triple-click the left mouse button at a specified (x, y) pixel coordinate on the screen.
* `scroll`: Performs a scroll of the mouse scroll wheel.
* `hscroll`: Performs a horizontal scroll (mapped to regular scroll).
* `wait`: Wait specified seconds for the change to happen.
* `terminate`: Terminate the current task and report its completion status.
* `answer`: Answer a question.
"""

S2_DESCRIPTION_PROMPT_TEMPLATE = """Use a mouse and keyboard to interact with a computer, and take screenshots.
* This is an interface to a desktop GUI. You must click on desktop icons to start applications.
* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try waiting and taking another screenshot.
{resolution_info}
* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.
* If you tried clicking on a program or link but it failed to load even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.
* Make sure to click any buttons, links, icons, etc. with the cursor tip in the center of the element. Don't click boxes on their edges unless asked."""

S2_SYSTEM_PROMPT = """# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tools_xml}
</tools>

For each function call, return a JSON object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>

# Response format

Response format for every step:
1) Action: a short imperative describing what to do in the UI.
2) A single <tool_call>...</tool_call> block containing only the JSON: {{"name": <function-name>, "arguments": <args-json-object>}}.

Rules:
- Output exactly in the order: Action, <tool_call>.
- Be brief: one sentence for Action.
- Do not output anything else outside those parts.
- If finishing, use action=terminate in the tool call."""


def build_s2_tools_def(description_prompt):
    return {
        "type": "function",
        "function": {
            "name_for_human": "computer_use",
            "name": "computer_use",
            "description": description_prompt,
            "parameters": {
                "properties": {
                    "action": {
                        "description": S2_ACTION_DESCRIPTION,
                        "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag",
                                 "right_click", "middle_click", "double_click", "triple_click", "scroll",
                                 "wait", "terminate", "key_down", "key_up"],
                        "type": "string"
                    },
                    "keys": {"description": "Required only by `action=key`.", "type": "array"},
                    "text": {"description": "Required only by `action=type`.", "type": "string"},
                    "coordinate": {"description": "The x,y coordinates for mouse actions.", "type": "array"},
                    "pixels": {"description": "The amount of scrolling.", "type": "number"},
                    "time": {"description": "The seconds to wait.", "type": "number"},
                    "status": {
                        "description": "The status of the task.",
                        "type": "string",
                        "enum": ["success", "failure"]
                    }
                },
                "required": ["action"],
                "type": "object"
            },
            "args_format": "Format the arguments as a JSON object."
        }
    }
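
# Usage sketch (illustrative): one plausible way to render the S2 system
# prompt; the resolution string and the json.dumps serialization are
# assumptions, not repo conventions.
if __name__ == "__main__":
    import json

    description = S2_DESCRIPTION_PROMPT_TEMPLATE.format(
        resolution_info="* The screen resolution is 1920x1080."  # assumed example value
    )
    tools_xml = json.dumps(build_s2_tools_def(description))
    print(S2_SYSTEM_PROMPT.format(tools_xml=tools_xml)[:400])
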
@@ -0,0 +1,302 @@
import base64
import re
import ast
import logging
from io import BytesIO
import json
from PIL import Image

from mm_agents.utils.qwen_vl_utils import smart_resize

logger = logging.getLogger("desktopenv.evocua.utils")


def encode_image(image_content):
    """Encode image bytes to a base64 string."""
    return base64.b64encode(image_content).decode("utf-8")


def process_image(image_bytes, factor=32):
    """
    Process an image for VL models.
    factor: 32 for S2 mode, 28 (default) for S1 mode.
    """
    image = Image.open(BytesIO(image_bytes))
    width, height = image.size

    resized_height, resized_width = smart_resize(
        height=height,
        width=width,
        factor=factor,
        max_pixels=16 * 16 * 4 * 12800,  # Large buffer
    )

    image = image.resize((resized_width, resized_height))

    buffer = BytesIO()
    image.save(buffer, format="PNG")
    processed_bytes = buffer.getvalue()

    return base64.b64encode(processed_bytes).decode("utf-8"), resized_width, resized_height


def _fallback_rewrite_pyautogui_text_inputs(code: str) -> str:
    """
    Regex-based fallback to handle malformed pyautogui.write/typewrite calls.
    """
    logger.info(f"SyntaxError detected in code, using regex fallback. Original code: {code}")

    def _replacer(match):
        call_content = match.group(0)
        m = re.search(r'pyautogui\.(?:write|typewrite)\s*\(', call_content)
        if not m:
            return call_content

        args_part = call_content[m.end():].strip()
        args_part = re.sub(r'^(?:message|text)\s*=\s*', '', args_part)

        text_content = ""
        if args_part.startswith(("'''", '"""')):
            quote_type = args_part[:3]
            content = args_part[3:]
            end_idx = content.rfind(quote_type)
            if end_idx != -1:
                text_content = content[:end_idx]
            else:
                text_content = content[:-1] if content.endswith(')') else content
        elif args_part.startswith(("'", '"')):
            quote_type = args_part[0]
            content = args_part[1:]
            if content.endswith(quote_type + ")"):
                text_content = content[:-2]
            elif content.endswith(")"):
                if len(content) > 1 and content[-2] == quote_type:
                    text_content = content[:-2]
                else:
                    text_content = content[:-1]
            elif content.endswith(quote_type):
                text_content = content[:-1]
            else:
                text_content = content
        else:
            text_content = args_part[:-1] if args_part.endswith(')') else args_part

        new_cmds = []
        for char in text_content:
            p = "enter" if char == "\n" else char
            p_esc = p.replace("'", "\\'")
            new_cmds.append(f"pyautogui.press('{p_esc}')")

        return "; ".join(new_cmds)

    pattern = r"pyautogui\.(?:write|typewrite)\s*\(.*?(?=\s*;|\s*$|\n)"
    new_code = re.sub(pattern, _replacer, code)

    if new_code == code and ("pyautogui.write" in code or "pyautogui.typewrite" in code):
        new_code = re.sub(r"pyautogui\.(?:write|typewrite)\s*\(.*", _replacer, code)

    return new_code


def rewrite_pyautogui_text_inputs(code: str) -> str:
    """
    Expand pyautogui.write/typewrite string literals into per-character presses.
    """
    try:
        tree = ast.parse(code)

        class _TextCallRewriter(ast.NodeTransformer):
            def _extract_text(self, call: ast.Call):
                if not (
                    isinstance(call.func, ast.Attribute)
                    and isinstance(call.func.value, ast.Name)
                    and call.func.value.id == "pyautogui"
                    and call.func.attr in ("write", "typewrite")
                ):
                    return None

                message_node = call.args[0] if call.args else None
                if message_node is None:
                    for kw in call.keywords:
                        if kw.arg in ("message", "text"):
                            message_node = kw.value
                            break

                if isinstance(message_node, ast.Constant) and isinstance(message_node.value, str):
                    return message_node.value
                return None

            def visit_Expr(self, node):
                self.generic_visit(node)
                if isinstance(node.value, ast.Call):
                    text = self._extract_text(node.value)
                    if text is not None:
                        new_nodes = []
                        for char in text:
                            press_value = "enter" if char == "\n" else char
                            press_call = ast.Expr(
                                value=ast.Call(
                                    func=ast.Attribute(
                                        value=ast.Name(id="pyautogui", ctx=ast.Load()),
                                        attr="press",
                                        ctx=ast.Load(),
                                    ),
                                    args=[ast.Constant(value=press_value)],
                                    keywords=[],
                                )
                            )
                            new_nodes.append(press_call)
                        return new_nodes if new_nodes else node
                return node

        tree = _TextCallRewriter().visit(tree)
        tree = ast.fix_missing_locations(tree)
        new_code = ast.unparse(tree)
        return new_code

    except Exception:
        # SyntaxError included; degrade gracefully to the regex fallback.
        return _fallback_rewrite_pyautogui_text_inputs(code)
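
# Usage sketch (illustrative): the AST pass keeps non-text calls intact and
# expands write/typewrite string literals into per-character presses, e.g.
#
#   rewrite_pyautogui_text_inputs("pyautogui.write('hi\\n')")
#   # -> pyautogui.press('h')
#   #    pyautogui.press('i')
#   #    pyautogui.press('enter')
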
def project_coordinate_to_absolute_scale(pyautogui_code_relative_coordinates, screen_width, screen_height, coordinate_type="relative", resize_factor=28):
    """
    Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.
    """
    def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
        if coordinate_type == "qwen25":
            height, width = smart_resize(
                height=screen_height,
                width=screen_width,
                factor=resize_factor,
                min_pixels=3136,
                max_pixels=12845056
            )
            if 0 <= x <= 1 and 0 <= y <= 1:
                # Already normalized, treat like "relative"
                return int(round(x * width)), int(round(y * height))
            return int(x / width * screen_width), int(y / height * screen_height)
        else:
            raise ValueError(f"Invalid coordinate type: {coordinate_type}. Expected 'qwen25'")

    pattern = r'(pyautogui\.\w+\([^\)]*\))'
    matches = re.findall(pattern, pyautogui_code_relative_coordinates)

    new_code = pyautogui_code_relative_coordinates

    for full_call in matches:
        func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
        func_match = re.match(func_name_pattern, full_call, re.DOTALL)
        if not func_match:
            continue

        func_name = func_match.group(1)
        args_str = func_match.group(2)

        try:
            parsed = ast.parse(f"func({args_str})").body[0].value
            parsed_args = parsed.args
            parsed_keywords = parsed.keywords
        except SyntaxError:
            continue

        function_parameters = {
            'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
            'rightClick': ['x', 'y', 'duration', 'tween', 'pause'],
            'middleClick': ['x', 'y', 'duration', 'tween', 'pause'],
            'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
            'tripleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
            'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
            'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
        }

        func_base_name = func_name.split('.')[-1]

        param_names = function_parameters.get(func_base_name, [])

        args = {}
        for idx, arg in enumerate(parsed_args):
            if idx < len(param_names):
                param_name = param_names[idx]
                try:
                    arg_value = ast.literal_eval(arg)
                    args[param_name] = arg_value
                except (ValueError, SyntaxError):
                    pass

        try:
            for kw in parsed_keywords:
                param_name = kw.arg
                arg_value = ast.literal_eval(kw.value)
                args[param_name] = arg_value
        except Exception as e:
            logger.error(f"Error parsing keyword arguments: {e}")
            continue

        updated = False
        if 'x' in args and 'y' in args:
            try:
                x_rel = float(args['x'])
                y_rel = float(args['y'])
                # Projection applies unconditionally for the supported coordinate type
                x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)

                # Apply coordinate transformation
                args['x'] = x_abs
                args['y'] = y_abs
                updated = True
            except (ValueError, TypeError):
                pass

        if updated:
            reconstructed_args = []
            for idx, param_name in enumerate(param_names):
                if param_name in args:
                    arg_value = args[param_name]
                    if isinstance(arg_value, str):
                        arg_repr = f"'{arg_value}'"
                    else:
                        arg_repr = str(arg_value)
                    reconstructed_args.append(arg_repr)
                else:
                    break

            used_params = set(param_names[:len(reconstructed_args)])
            for kw in parsed_keywords:
                if kw.arg not in used_params:
                    arg_value = args[kw.arg]
                    if isinstance(arg_value, str):
                        arg_repr = f"{kw.arg}='{arg_value}'"
                    else:
                        arg_repr = f"{kw.arg}={arg_value}"
                    reconstructed_args.append(arg_repr)

            new_args_str = ', '.join(reconstructed_args)
            new_full_call = f"{func_name}({new_args_str})"
            new_code = new_code.replace(full_call, new_full_call)

    return new_code
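
# Worked example (illustrative): with coordinate_type="qwen25" the model emits
# pixel coordinates in the smart-resized space (width', height'); the helper
# maps them back to the physical screen, e.g. on a 1920x1080 screen
#
#   project_coordinate_to_absolute_scale("pyautogui.click(x=500, y=300)",
#                                        1920, 1080, coordinate_type="qwen25")
#   # -> "pyautogui.click(x=int(500 / width' * 1920), y=int(300 / height' * 1080))"
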
def log_messages(messages, prefix="LLM Messages"):
    """Log messages with truncated base64 images."""
    try:
        log_msgs = []
        for msg in messages:
            msg_copy = msg.copy()
            content = msg.get("content")
            if isinstance(content, list):
                new_content = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "image_url":
                        item_copy = item.copy()
                        url = item_copy.get("image_url", {}).get("url", "")
                        if len(url) > 100:
                            item_copy["image_url"] = {"url": url[:30] + "...[base64_truncated]..." + url[-10:]}
                        new_content.append(item_copy)
                    else:
                        new_content.append(item)
                msg_copy["content"] = new_content
            log_msgs.append(msg_copy)
        logger.info(f"{prefix}:\n{json.dumps(log_msgs, indent=2, ensure_ascii=False)}")
    except Exception as e:
        logger.warning(f"Failed to log messages: {e}")
@@ -0,0 +1,350 @@
import logging
from typing import Dict, List, Tuple, Optional

from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import call_llm_safe, parse_code_from_string
from mm_agents.os_symphony.core.mllm import LMMAgent

logger = logging.getLogger("desktopenv.coder_agent")


def extract_code_block(action: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract code and determine its type from an action string."""
    if "```python" in action:
        code_type = "python"
        code = action.split("```python")[1].split("```")[0].strip()
    elif "```bash" in action:
        code_type = "bash"
        code = action.split("```bash")[1].split("```")[0].strip()
    elif "```" in action:
        code_type = None
        code = action.split("```")[1].split("```")[0].strip()
    else:
        code_type = None
        code = None

    logger.debug(
        f"Extracted code block: type={code_type}, length={len(code) if code else 0}"
    )
    return code_type, code
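
# Usage sketch (illustrative): fenced blocks map to a (code_type, code) pair.
#
#   extract_code_block("Thoughts...\n```bash\necho hello\n```")
#   # -> ("bash", "echo hello")
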
def execute_code(code_type: str, code: str, env_controller) -> Dict:
    """Execute code based on its type."""
    # Log the full code being executed (untruncated)
    logger.info(f"CODING_AGENT_CODE_EXECUTION - Type: {code_type}\nCode:\n{code}")

    try:
        if code_type == "bash":
            result = env_controller.run_bash_script(code, timeout=30)
        elif code_type == "python":
            result = env_controller.run_python_script(code)
        else:
            result = {"status": "error", "error": f"Unknown code type: {code_type}"}

        return result

    except Exception as e:
        logger.error(f"Error executing {code_type} code: {e}")
        return {"status": "error", "error": str(e)}


def format_result(result: Dict, step_count: int) -> str:
    """Format an execution result into a context string."""
    if not result:
        logger.warning(f"Step {step_count + 1}: No result returned from execution")
        return f"""
Step {step_count + 1} Error:
Error: No result returned from execution
"""

    status = result.get("status", "unknown")
    return_code = result.get("returncode", result.get("return_code", -1))

    # Handle different response structures for bash vs python
    if "returncode" in result:
        # Bash script response
        output = result.get("output", "")  # Contains both stdout and stderr merged
        error = result.get("error", "")  # Always empty for bash
    else:
        # Python script response
        output = result.get("output", "")  # stdout only
        error = result.get("error", "")  # stderr only

    logger.debug(f"Step {step_count + 1}: Status={status}, Return Code={return_code}")

    # Format with better structure for multi-line outputs
    result_text = f"Step {step_count + 1} Result:\n"
    result_text += f"Status: {status}\n"
    result_text += f"Return Code: {return_code}\n"

    if output:
        result_text += f"Output:\n{output}\n"

    if error:
        result_text += f"Error:\n{error}\n"

    return result_text


class CoderAgent:
    """A dedicated agent for executing code with a budget of steps."""

    def __init__(self, engine_params: Dict, client_password: str, platform: str = "linux"):
        """Initialize the CoderAgent."""
        if not engine_params:
            raise ValueError("engine_params cannot be None or empty")

        self.engine_params = engine_params
        self.budget = engine_params.get("budget", 20)
        self.temperature = engine_params.get("temperature", 0.1)
        self.agent = None
        self.platform = platform
        self.client_password = client_password

        logger.info(f"CoderAgent initialized with budget={self.budget} and platform={self.platform}")
        self.reset()

    def reset(self):
        """Reset the coder agent state."""
        logger.debug("Resetting CoderAgent state")
        self.agent = LMMAgent(
            engine_params=self.engine_params,
            system_prompt=PROCEDURAL_MEMORY.construct_coder_procedural_memory(platform=self.platform, client_password=self.client_password)
        )

    def execute(self, task_instruction: str, screenshot: str, env_controller) -> Dict:
        """Execute code for the given task with a budget of steps."""
        if env_controller is None:
            raise ValueError("env_controller is required for code execution")

        print("\n🚀 STARTING CODE EXECUTION")
        print("=" * 60)
        print(f"Task: {task_instruction}")
        print(f"Budget: {self.budget} steps")
        print("=" * 60)

        logger.info(f"Starting code execution for task: {task_instruction}")
        logger.info(f"Budget: {self.budget} steps")

        self.reset()

        # Add initial task instruction and screenshot context as a user message
        context = (
            f"Task: {task_instruction}\n\nCurrent screenshot is provided for context."
        )
        self.agent.add_message(context, image_content=screenshot, role="user")

        step_count = 0
        completion_reason = None
        execution_history = []
        execution_result_history = []
        while step_count < self.budget:
            logger.info(f"Step {step_count + 1}/{self.budget}")

            # Get assistant response (thoughts and code)
            response = call_llm_safe(self.agent, temperature=self.temperature)

            # Print to terminal for immediate visibility
            # print(f"\n🤖 CODING AGENT RESPONSE - Step {step_count + 1}/{self.budget}")
            # print("=" * 60)
            # print(response)
            # print("=" * 60)

            # Log the latest message from the coding agent (untruncated)
            logger.info(
                f"CODING_AGENT_LATEST_MESSAGE - Step {step_count + 1}:\n{response}"
            )

            # Check if response is None or empty
            if not response or response.strip() == "":
                error_msg = f"Step {step_count + 1}: LLM returned empty response"
                logger.error(error_msg)
                raise RuntimeError(error_msg)

            # Parse the response to extract the action
            action = parse_code_from_string(response)
            thoughts = response

            execution_history.append(
                {"step": step_count + 1, "action": action, "thoughts": thoughts}
            )

            # Check for completion signals
            action_upper = action.upper().strip()
            if action_upper == "DONE":
                print(f"\n✅ TASK COMPLETED - Step {step_count + 1}")
                print("=" * 60)
                print("Agent signaled task completion")
                print("=" * 60)
                logger.info(f"Step {step_count + 1}: Task completed successfully")
                completion_reason = "DONE"
                break
            elif action_upper == "FAIL":
                print(f"\n❌ TASK FAILED - Step {step_count + 1}")
                print("=" * 60)
                print("Agent signaled task failure")
                print("=" * 60)
                logger.info(f"Step {step_count + 1}: Task failed by agent request")
                completion_reason = "FAIL"
                break
            elif action_upper == 'INFEASIBLE':
                print(f"\n❌ TASK INFEASIBLE - Step {step_count + 1}")
                print("=" * 60)
                print("Agent signaled task infeasible")
                print("=" * 60)
                logger.info(f"Step {step_count + 1}: Task marked infeasible by agent request")
                completion_reason = "INFEASIBLE"
                break

            # Extract and execute code
            code_type, code = extract_code_block(response.split("(Answer)")[-1])

            if code:
                result = execute_code(code_type, code, env_controller)
                execution_result_history.append(
                    {"step": step_count + 1, "result": result}
                )
                # Prepare formatted output and error for logging
                output = result.get("output", "")
                error = result.get("error", "")
                message = result.get("message", "")
                status = result.get("status", "")

                # Print execution result to terminal for immediate visibility
                print(f"\n⚡ CODE EXECUTION RESULT - Step {step_count + 1}")
                print("-" * 50)
                print(f"Status: {status}")
                if output:
                    print(f"Output:\n{output}")
                if error:
                    print(f"Error:\n{error}")
                if message and not output and not error:
                    print(f"Message:\n{message}")
                print("-" * 50)

                log_lines = [
                    f"CODING_AGENT_EXECUTION_RESULT - Step {step_count + 1}:",
                    f"Status: {status}" if status else None,
                ]

                if output:
                    log_lines.append(
                        "Output:\n" + ("-" * 40) + f"\n{output}\n" + ("-" * 40)
                    )
                if error:
                    log_lines.append(
                        "Error:\n" + ("!" * 40) + f"\n{error}\n" + ("!" * 40)
                    )
                if message and not output and not error:
                    log_lines.append(
                        "Message:\n" + ("-" * 40) + f"\n{message}\n" + ("-" * 40)
                    )

                # Remove None entries and join
                formatted_log = "\n".join([line for line in log_lines if line])
                logger.info(formatted_log)
            else:
                print(f"\n⚠️ NO CODE BLOCK FOUND - Step {step_count + 1}")
                print("-" * 50)
                print("Action did not contain executable code")
                print("-" * 50)

                logger.warning(f"Step {step_count + 1}: No code block found in action")
                result = {"status": "skipped", "message": "No code block found"}
                logger.info(
                    f"CODING_AGENT_EXECUTION_RESULT - Step {step_count + 1}:\n"
                    f"Status: skipped\n"
                    f"Message:\n{'-' * 40}\n{result['message']}\n{'-' * 40}"
                )
            # Add the assistant's thoughts and code to the message history
            self.agent.add_message(response, role="assistant")

            # Process the result and add the formatted environment results as a user message
            result_context = format_result(result, step_count)
            self.agent.add_message(result_context, role="user")

            step_count += 1

        # Handle budget exhaustion
        if completion_reason is None:
            print(f"\n⏰ BUDGET EXHAUSTED - {step_count} steps completed")
            print("=" * 60)
            print(f"Maximum budget of {self.budget} steps reached")
            print("=" * 60)
            logger.info(f"Budget exhausted after {step_count} steps")
            completion_reason = f"BUDGET_EXHAUSTED_AFTER_{step_count}_STEPS"

        # Generate final summary
        logger.info("Generating execution summary")
        summary = self._generate_summary(execution_history, task_instruction)

        result = {
            "task_instruction": task_instruction,
            "completion_reason": completion_reason,
            "summary": summary,
            "execution_history": execution_history,
            "execution_result_history": execution_result_history,
            "steps_executed": step_count,
            "budget": self.budget
        }

        logger.info(f"Code execution completed: steps={step_count}")
        return result

    def _generate_summary(
        self, execution_history: List[Dict], task_instruction: str
    ) -> str:
        """Generate a summary of the code execution session."""
        if not execution_history:
            logger.info("No execution history to summarize")
            return "No actions were executed."

        logger.info(f"Generating summary for {len(execution_history)} steps")

        # Build detailed execution context for the summary agent
        execution_context = f"Task: {task_instruction}\n\nExecution Steps:\n"

        for step in execution_history:
            step_num = step["step"]
            thoughts = step.get("thoughts", "")
            action = step.get("action", "")

            execution_context += f"\nStep {step_num}:\n"
            if thoughts:
                execution_context += f"Thoughts: {thoughts}\n"
            execution_context += f"Code: {action}\n"

        # Create summary prompt with the same context as the coding agent
        summary_prompt = f"""
{execution_context}

Please provide a concise summary of the code execution session. Focus on:

1. The code logic implemented at each step
2. The outputs and results produced by each code execution
3. The progression of the solution approach

Do not make judgments about success or failure. Simply describe what was attempted and what resulted.

Keep the summary under 150 words and use clear, factual language.
"""

        # Generate the summary using an LLM with a dedicated summary system prompt
        try:
            summary_agent = LMMAgent(
                engine_params=self.engine_params,
                system_prompt=PROCEDURAL_MEMORY.CODE_SUMMARY_AGENT_PROMPT,
            )
            summary_agent.add_message(summary_prompt, role="user")
            summary = call_llm_safe(summary_agent, temperature=self.temperature)

            if not summary or summary.strip() == "":
                summary = "Summary generation failed - no response from LLM"
                logger.warning("Summary generation failed - empty response from LLM")

        except Exception as e:
            summary = f"Summary generation failed: {str(e)}"
            logger.error(f"Error generating summary: {e}")

        return summary
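
# Usage sketch (illustrative): the engine params and mock controller below are
# assumptions for demonstration, not repo fixtures; executing for real needs a
# reachable LLM endpoint behind LMMAgent.
if __name__ == "__main__":
    class _MockController:
        def run_bash_script(self, code, timeout=30):
            return {"status": "success", "returncode": 0, "output": "ok", "error": ""}

        def run_python_script(self, code):
            return {"status": "success", "output": "ok", "error": ""}

    agent = CoderAgent(
        engine_params={"engine_type": "openai", "model": "gpt-4o", "budget": 3},  # assumed values
        client_password="password",
        platform="linux",
    )
    # agent.execute("list files in the home directory", screenshot=b"", env_controller=_MockController())
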
@@ -0,0 +1,109 @@
import re
from typing import Any, Dict, List

from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.utils.common_utils import call_llm_safe, smart_resize
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
import logging

logger = logging.getLogger("desktopenv.agent")


class GrounderAgent:
    """
    Class designed for interacting with the GUI, serving the Grounding Agent and VLMSearcher.
    """
    def __init__(self, engine_params: Dict, screen_width: int, screen_height: int):
        self.engine_params_for_grounder = engine_params  # grounder_params
        system_prompt, self.user_message = PROCEDURAL_MEMORY.construct_grounder_procedural_memory(model_name=engine_params["model"])
        self.grounding_model = LMMAgent(engine_params, system_prompt=system_prompt)
        # Width and height for the Grounding Agent
        self.width = engine_params['grounding_width']
        self.height = engine_params['grounding_height']
        print(f"[Grounder]: initialized width is {self.width}, height is {self.height}")
        # Width and height of the actual screen
        self.screen_width = screen_width
        self.screen_height = screen_height

    # Given the state and the worker's referring expression, use the grounding model to generate (x, y)
    def generate_coords(self, ref_expr: str, obs: Dict, detail=False, expansion_pixels=400, **kwargs) -> List:
        cur_screenshot = obs["screenshot"]

        # store global offset
        global_offset_x = 0
        global_offset_y = 0

        # final coordinates for output
        final_global_x = 0
        final_global_y = 0

        cur_width, cur_height = self.screen_width, self.screen_height

        print("[Grounder] start to ground!")
        self.grounding_model.reset()

        # Configure the context
        prompt = self.user_message.replace("REF_EXPR", ref_expr)

        # Consistent with the system prompt presented in the GTA-1 paper
        if 'gta' in self.engine_params_for_grounder['model']:
            self.grounding_model.add_system_prompt("You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.")

        self.grounding_model.add_message(
            text_content=prompt, image_content=cur_screenshot, put_text_last=True, role="user"
        )

        # Generate and parse coordinates
        response = call_llm_safe(self.grounding_model, temperature=0.05, **kwargs)
        print(f"[Grounder] prompt: {prompt}\nmodel: {self.engine_params_for_grounder['model']}, \nresponse: {response}")

        # 1. Highest priority: (x1="...", y1="...", x="...", y="...")
        numericals = re.findall(r'(?:x1|y1|x|y)=["\']?(\d+)["\']?', response)
        # 2. Second highest priority: forms like <points>653 42</points> or [653, 42]
        if len(numericals) < 2:
            clean_response = re.sub(r'[xXyY]\d', '', response)
            numericals = re.findall(r'\d+', clean_response)
        assert len(numericals) >= 2

        print(f"[Grounder] the parsed coordinates: {numericals}")

        local_x, local_y = self._resize_coordinates([int(numericals[0]), int(numericals[1])], width=cur_width, height=cur_height)

        # current global coordinates = local coordinates + global offset
        final_global_x = local_x + global_offset_x
        final_global_y = local_y + global_offset_y

        if detail:
            return [cur_screenshot, global_offset_x, global_offset_y]
        else:
            return [final_global_x, final_global_y]

    def dynamic_set_width_height(self, width: int, height: int):
        self.width = width
        self.height = height

    # Resize from the grounding model's dimensions into the OSWorld dimensions (1920 x 1080)
    def _resize_coordinates(self, coordinates: List[int], width: int, height: int) -> List[int]:
        """
        width, height: for the current observation
        grounding_width, grounding_height: width and height for the grounding model (e.g. 1000x1000 or 1280x800)
        """
        grounding_width = self.engine_params_for_grounder["grounding_width"]
        grounding_height = self.engine_params_for_grounder["grounding_height"]
        grounding_smart_resize = self.engine_params_for_grounder["grounding_smart_resize"]

        if not grounding_smart_resize:
            return [
                round(coordinates[0] * width / grounding_width),
                round(coordinates[1] * height / grounding_height),
            ]
        else:
            smart_height, smart_width = smart_resize(height, width)
            return [
                round(coordinates[0] * width / smart_width),
                round(coordinates[1] * height / smart_height)
            ]
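
# Worked example (illustrative): with grounding_smart_resize disabled and a
# 1000x1000 grounding space, a model output of (500, 500) on a 1920x1080
# screen maps to (round(500 * 1920 / 1000), round(500 * 1080 / 1000)) = (960, 540).
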
@@ -0,0 +1,428 @@
import logging
import json
from typing import List, Dict, Any, Optional, Tuple
from mm_agents.os_symphony.utils.common_utils import (
    call_llm_formatted,
    enhance_observation,
    parse_code_from_string
)
from functools import partial
from mm_agents.os_symphony.utils.formatters import JSON_ANSWER_FORMATTER
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
import imagehash
import io
import os
from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as ssim

logger = logging.getLogger("desktopenv.agent")


class StepBehavior:
    """
    Narrative Step Behavior.
    Description of each step; consists of the generative agent (main agent)'s output, a screenshot (if this step is a milestone), and a textual description.
    The textual description shows how the agent thought and acted, and how the state changed.
    """
    def __init__(self, is_milestone: bool, gen_output: str, summary: str, obs: Dict, action_dict: Dict):
        self.is_milestone = is_milestone
        self.gen_output = gen_output
        self.obs = obs
        self.summary = summary
        self.action_dict = action_dict
        # Cached signals for optimizing the time complexity of loop detection
        # --- 1. pHash ---
        self.phash = None
        # --- 2. SSIM ---
        self.ssim_list = []

    def _update_phash_ssim(self, history: List):
        # Calculate the ssim_list of the current obs
        # Update pHash
        cur_img = Image.open(io.BytesIO(self.obs["screenshot"]))
        cur_img_gray = cur_img.convert('L')
        cur_img_np = np.array(cur_img_gray)
        self.phash = imagehash.phash(cur_img)
        # Update ssim_list
        for hs in history:
            compare_img = Image.open(io.BytesIO(hs.obs["screenshot"]))
            compare_img_gray = compare_img.convert('L')
            compare_img_np = np.array(compare_img_gray)
            # data_range: dynamic range of the current image
            self.ssim_list.append(ssim(cur_img_np, compare_img_np, data_range=cur_img_np.max() - cur_img_np.min()))
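
# Sketch (illustrative): one way the cached signals could be compared when
# flagging repeats; the thresholds here are assumptions, and the real logic
# lives in mm_agents.os_symphony.utils.loop_detection.detect_loop.
def _looks_like_repeat(cur: StepBehavior, prev: StepBehavior, prev_idx: int,
                       phash_threshold: int = 4, ssim_threshold: float = 0.98) -> bool:
    # imagehash supports subtraction: the result is the Hamming distance.
    near_identical_hash = (cur.phash - prev.phash) <= phash_threshold
    # cur.ssim_list[prev_idx] is cur's SSIM against the prev_idx-th history step.
    high_ssim = prev_idx < len(cur.ssim_list) and cur.ssim_list[prev_idx] >= ssim_threshold
    return near_identical_hash or high_ssim
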
class ReflectionMemoryAgent:
    """
    Reflection Memory Agent (RMA).
    Responsible for maintaining long-term memory, extracting narratives from trajectories,
    providing reflections to the Main Agent, and validating task completion status.
    """
    def __init__(self, engine_params: Dict):
        """
        Initialize the RMA.

        Args:
            - engine_params:
                {
                    "engine_type": args.provider,
                    "model": args.model,
                    "base_url": args.model_url,
                    "api_key": args.model_api_key,
                    "temperature": getattr(args, "model_temperature", None),
                }
            - max_images: maximum number of images to use in the reflection process
        """

        self.engine_params = engine_params

        self.max_images = engine_params.get('max_images', 8)

        self.memoryer_level = engine_params.get('memoryer_level', 3)

        self.reset()

        logger.info(f"ReflectionMemoryAgent initialized with:\n {self.engine_params}")

    def reset(self):
        """Reset the RMA state."""
        logger.debug("Resetting RMA state")

        self.instruction = None

        self.trajectory: List[StepBehavior] = []

        self.knowledge_base: List[str] = []

        self.last_code_step_idx = -1

        '''
        Control the number of images; we use at most max_images images.
        The update logic: the 0-th screenshot is always retained. If the total number of screenshots is less than max_images, all are kept; otherwise, starting from index 1, milestone screenshots are managed via FIFO.
        '''
        self.active_img_idx = []

        self.reflection_agent = LMMAgent(
            engine_params=self.engine_params,
            system_prompt=PROCEDURAL_MEMORY.REFLECTION_SYSTEM_PROMPT,
        )
        self.behavior_agent = LMMAgent(
            engine_params=self.engine_params,
            system_prompt=PROCEDURAL_MEMORY.SUMMARIZE_STEP_SYSTEM_PROMPT
        )

    def add_instruction(self, instruction):
        """
        [Interface] Main -> RMA
        The Main Agent sets the instruction on the RMA.
        """
        self.instruction = instruction

    def _update_trajectory(self, step_behavior):
        self.trajectory.append(step_behavior)
        if len(self.active_img_idx) >= self.max_images:
            if step_behavior.is_milestone:
                self.active_img_idx.append(len(self.trajectory) - 1)  # over max_images: only milestone images
                del self.active_img_idx[1]  # FIFO starts from index 1
        else:
            self.active_img_idx.append(len(self.trajectory) - 1)  # fewer than max_images: keep all images

        assert len(self.active_img_idx) <= self.max_images, "[RMA] Errors in updating StepBehavior!!"
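
    # Worked example (illustrative) of the FIFO window above, with max_images = 3:
    #   active_img_idx [0, 1, 2] + milestone step 3 -> append 3, evict index 1 -> [0, 2, 3]
    # Index 0 (the initial screenshot) is pinned; non-milestone steps are not added
    # once the window is full.
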
    def _summarize_step_behavior(
        self,
        generator_output: str,
        cur_obs: Dict,
        enhanced_obs: bytes | None,
        is_milestone: bool,
        mode: str = "gui",
        code_exec_summary: str = "",
        action_dict: Dict = {}
    ) -> Tuple[StepBehavior, bool]:
        """
        [Interface] Main -> RMA
        The Main Agent (MA) calls this method to "feed" the information of the just-completed step to the RMA.
        The RMA will internally process and store this step.
        """

        if mode == "search":
            is_success = "successful"
            # summary is fixed
            step_behavior = StepBehavior(
                False,
                generator_output,
                "Search Agent was called last step, and a tutorial has been generated.",
                cur_obs,
                action_dict
            )
        elif mode == "code":
            self.last_code_step_idx = len(self.trajectory)

            is_success = "successful"
            # the summary returned by the code agent
            step_behavior = StepBehavior(
                False,
                generator_output,
                f"Code Agent was called last step, and the summary of its trajectory is: \n---\n{code_exec_summary}\n---",
                cur_obs,
                action_dict
            )
        else:  # common GUI operation; use the LLM to summarize
            prev_obs = self.trajectory[-1].obs

            text_content = f"""Computer Use Agent's Output: \n{generator_output}"""

            self.behavior_agent.reset()  # history messages are not needed

            updated_sys_prompt = (
                self.behavior_agent.system_prompt + "\n" + text_content
            )
            self.behavior_agent.add_system_prompt(updated_sys_prompt)

            self.behavior_agent.add_message(
                text_content="This is the observation before executing the action (attached below).",
                image_content=prev_obs['screenshot'],
                role="user",
                put_text_last=False
            )
            self.behavior_agent.add_message(
                text_content="This is the zoomed-in view, which may help you identify the operational region (attached below).",
                image_content=enhanced_obs,
                role="user",
                put_text_last=False
            )
            self.behavior_agent.add_message(
                text_content="This is the observation after executing the action (attached below).",
                image_content=cur_obs['screenshot'],
                role="user",
                put_text_last=False
            )

            required_fields = ["summary", "evaluation"]
            format_checkers = [
                partial(JSON_ANSWER_FORMATTER, required_fields)
            ]

            full_response = call_llm_formatted(
                self.behavior_agent,
                format_checkers,
                temperature=self.engine_params.get("temperature", 0.1),
            )

            response = parse_code_from_string(full_response)

            try:
                data = json.loads(response)
                behavior_summary = data['summary']
                is_success = data["evaluation"]
            except Exception as e:
                print("[RMA] Errors in generating step summary: ", e)
                logger.info("Response is not a JSON object or misses required keys!")
                behavior_summary = response
                is_success = "successful"

            step_behavior = StepBehavior(is_milestone, generator_output, behavior_summary, cur_obs, action_dict)

        return step_behavior, is_success == "successful"

    def get_reflection(
        self,
        cur_obs: Dict,
        generator_output: str,
        coordinates: List,
        mode: str = "gui",
        code_exec_summary: str = "",
        action_dict: Dict = {}
    ) -> Dict:
        """
        [Interface] RMA -> Main
        The Main Agent (MA) calls this method to get the RMA's reflection before deciding the next action.

        Args:
            - cur_obs (Dict): The Main Agent's current observation (o_k).
            - generator_output (str): The thoughts, screen analysis and action of the Main Agent.
            - coordinates (List): coordinates in the last operation step of the Main Agent.
            - mode (str): one of [gui, code, search]. Indicates which agent the main agent called in the last step.
            - code_exec_summary: execution summary for the code agent.
            - action_dict: extracted action from the generator output.

        Returns:
            - reflection_info (Dict): all the info related to the reflection
        """
        if self.memoryer_level == 0:
            return {
                "reflection": None,
                "reflection_thoughts": None,
                "existing_knowledge": None,
                "is_milestone": False,
                "new_knowledge": None,
                "step_summary": None,
                "hint": {
                    "gui_operation_error": False,
                    "lack_of_tutorial": False,
                    "code_error": False,
                    "loop_detection": None,
                }
            }

        reflection = None
        reflection_thought = None
        if len(self.trajectory) == 0:
            step_behavior = StepBehavior(
                True,
                "The initial screen is provided. No action has been taken yet.",
                "The initial screen is provided. No action has been taken yet.",
                cur_obs,
                action_dict
            )
            step_behavior._update_phash_ssim(self.trajectory)
            self._update_trajectory(step_behavior)
            reflection_info = {
                "reflection": reflection,
                "reflection_thoughts": reflection_thought,
                "existing_knowledge": "\n".join(self.knowledge_base),
                "is_milestone": True,
                "new_knowledge": "",
                "step_summary": "",
                "hint": {
                    "gui_operation_error": False,
                    "lack_of_tutorial": False,
                    "code_error": False,
                    "loop_detection": None,
                }
            }
        else:
            ### Step Summary
            prev_obs = self.trajectory[-1].obs
            enhanced_obs = None
            if coordinates:
                enhanced_obs, _, _, _, _ = enhance_observation(
                    prev_obs["screenshot"],
                    coordinates,
                    draw=True
                )

            # generate step behavior
            step_behavior, last_gui_check = self._summarize_step_behavior(
                generator_output,
                cur_obs,
                enhanced_obs,
                False,
                mode,
                code_exec_summary,
                action_dict
            )
            step_behavior._update_phash_ssim(self.trajectory)

            ### make additional hints
            additional_hints = []
            if not last_gui_check:
                additional_hints.append("\t- Warning: The last GUI operation was unsuccessful. Careful review is required to avoid a GUI Operation Error.")

            code_error_hint = False

            if self.last_code_step_idx != -1 and len(self.trajectory) - self.last_code_step_idx < 0:
                code_error_hint = True
                additional_hints.append("\t- Warning: The Computer Use Agent might be in the verification stage of the Code Agent. Careful review is required to avoid a Code Error.")

            # loop detection
            from mm_agents.os_symphony.utils.loop_detection import detect_loop
            is_loop, loop_details = detect_loop(full_trajectory=self.trajectory, N=3)
            if is_loop and loop_details:
                match_sequence_indices = loop_details["match_sequence_indices"]
                loop_hint_message = f"\t- Warning: A potential LOOP has been detected between Step {match_sequence_indices[0]} and Step {match_sequence_indices[-1]}. Careful review is required to avoid a Repetitive Behavior Error."
                additional_hints.append(loop_hint_message)

            self.reflection_agent.reset()

            updated_sys_prompt = (
                PROCEDURAL_MEMORY.REFLECTION_SYSTEM_PROMPT + "\n\n" +
                f"---\n- **user instruction**: {self.instruction}\n" +
                "- **existing knowledge**: \n" + "\n".join(self.knowledge_base) +
                "\n- **additional_hints**: " + "\n".join(additional_hints) + "\n---"
            )

            # update system prompt
            self.reflection_agent.add_system_prompt(updated_sys_prompt)

            for i, step in enumerate(self.trajectory):
                text_content = f"""### (Step {i}) history:\nsummary: '''\n{step.summary}\n'''"""
                if i in self.active_img_idx:
                    if i == 0:
                        text_content += "\ninitial screenshot:"
                    else:
                        text_content += "\nscreenshot (after executing action): (attached below)"

                self.reflection_agent.add_message(
                    text_content=text_content,
                    image_content=step.obs['screenshot'] if i in self.active_img_idx else None,
                    role="user",
                    put_text_last=False
                )

            text_content = f"""### (Last Step) CUA's output (has been finished):\n---\n{generator_output}\n---\nStep Summary:\n---\n{step_behavior.summary}\n---\nlatest_screenshot: (attached below)"""
            self.reflection_agent.add_message(
                text_content=text_content,
                image_content=cur_obs['screenshot'],
                role="user",
                put_text_last=False
            )

            required_fields = ["is_milestone", "reflection", "knowledge"]

            format_checkers = [
                partial(JSON_ANSWER_FORMATTER, required_fields)
            ]

            full_response = call_llm_formatted(
                self.reflection_agent,
                format_checkers
            )

            reflection_thought = full_response

            response = parse_code_from_string(full_response)

            try:
                data = json.loads(response)
                reflection = data['reflection']
                is_milestone = data["is_milestone"]
                knowledge = data['knowledge']
            except Exception as e:
                print("[RMA] Errors in dealing with reflection: ", e)
                logger.info("Response is not a JSON object or misses required keys!")
                reflection = response
                is_milestone = False
                knowledge = ""

            if len(knowledge) > 0:
                self.knowledge_base.append(knowledge)

            if isinstance(is_milestone, str):
                is_milestone = "true" in is_milestone.lower()

            # update trajectory and is_milestone
            self._update_trajectory(step_behavior)
            if mode == "gui":  # only a GUI operation can be considered a milestone
                self.trajectory[-1].is_milestone = is_milestone

            reflection_info = {
                "reflection": reflection,
                "reflection_thoughts": reflection_thought,
                "existing_knowledge": "\n".join(self.knowledge_base),
                "is_milestone": is_milestone,
                "new_knowledge": knowledge,
                "step_summary": step_behavior.summary,
                "hint": {
                    "gui_operation_error": not last_gui_check,
                    "lack_of_tutorial": is_loop,
                    "code_error": code_error_hint,
                    "loop_detection": loop_details,
                }
            }

        return reflection_info
@@ -0,0 +1,178 @@
import re
from io import BytesIO
from typing import Tuple, List, Dict, Optional
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pytesseract
from pytesseract import Output
import easyocr


class OCRProcessor:
    """
    OCR processor supporting Tesseract and EasyOCR.
    """
    def __init__(self, use_gpu: bool = False, languages: Optional[List[str]] = None):
        """
        Initialize the processor.

        Args:
            use_gpu (bool): whether EasyOCR should use the GPU
            languages (List[str]): language list for EasyOCR, e.g. ['en', 'ch_sim'].
        """
        self.use_gpu = use_gpu
        # Avoid a mutable default argument; fall back to English.
        self.languages = languages if languages is not None else ['en']
        self.reader = None  # lazy-load the EasyOCR Reader

    def _get_easyocr_reader(self):
        if self.reader is None:
            print(f"Loading EasyOCR model (GPU={self.use_gpu})...")
            self.reader = easyocr.Reader(self.languages, gpu=self.use_gpu)
        return self.reader

    def get_ocr_elements(self, bytes_image_data: bytes, mode: str = 'tesseract') -> Tuple[str, List[Dict]]:
        """
        Execute OCR recognition.

        Args:
            bytes_image_data (bytes): raw image bytes
            mode (str): 'tesseract' (faster) or 'easyocr' (more precise).

        Returns:
            Tuple[str, List]: (textual table string, list of element details)
        """
        try:
            image = Image.open(BytesIO(bytes_image_data))
        except Exception as e:
            print(f"Error decoding or opening image: {e}")
            return "", []

        if mode == 'tesseract':
            return self._process_tesseract(image)
        elif mode == 'easyocr':
            return self._process_easyocr(image)
        else:
            raise ValueError(f"Unknown mode: {mode}. Use 'tesseract' or 'easyocr'.")

    def _process_tesseract(self, image: Image.Image) -> Tuple[str, List[Dict]]:
        """Tesseract processing."""
        data = pytesseract.image_to_data(image, output_type=Output.DICT)

        ocr_elements = []
        ocr_table = "Text Table (Tesseract):\nWord id\tText\n"
        ocr_id = 0

        num_boxes = len(data['text'])
        for i in range(num_boxes):
            # filter text with low confidence
            if int(data['conf'][i]) > 0 and data['text'][i].strip():
                clean_text = re.sub(r"^[^a-zA-Z0-9\s.,!?;:\-\+]+|[^a-zA-Z0-9\s.,!?;:\-\+]+$", "", data['text'][i])
                if not clean_text:
                    continue

                ocr_table += f"{ocr_id}\t{clean_text}\n"

                ocr_elements.append({
                    "id": ocr_id,
                    "text": clean_text,
                    "mode": "tesseract",
                    "left": data["left"][i],
                    "top": data["top"][i],
                    "width": data["width"][i],
                    "height": data["height"][i],
                    "conf": data["conf"][i]
                })
                ocr_id += 1

        return ocr_table, ocr_elements

    def _process_easyocr(self, image: Image.Image) -> Tuple[str, List[Dict]]:
        """EasyOCR processing."""
        reader = self._get_easyocr_reader()

        image_np = np.array(image)

        # detail=1 means returning (bbox, text, conf)
        results = reader.readtext(image_np, detail=1, paragraph=False, width_ths=0.1)

        ocr_elements = []
        ocr_table = "Text Table (EasyOCR):\nWord id\tText\n"
        ocr_id = 0

        for (bbox, text, conf) in results:
            clean_text = re.sub(r"^[^a-zA-Z0-9\s.,!?;:\-\+]+|[^a-zA-Z0-9\s.,!?;:\-\+]+$", "", text)
            if not clean_text.strip():
                continue

            # EasyOCR returns [[x1, y1], [x2, y1], [x2, y2], [x1, y2]];
            # convert the quadrilateral into left, top, width, height
            (tl, tr, br, bl) = bbox
            tl = [int(v) for v in tl]
            tr = [int(v) for v in tr]
            br = [int(v) for v in br]
            bl = [int(v) for v in bl]

            left = min(tl[0], bl[0])
            top = min(tl[1], tr[1])
            right = max(tr[0], br[0])
            bottom = max(bl[1], br[1])

            width = right - left
            height = bottom - top

            ocr_table += f"{ocr_id}\t{clean_text}\n"

            ocr_elements.append({
                "id": ocr_id,
                "text": clean_text,
                "mode": "easyocr",
                "left": left,
                "top": top,
                "width": width,
                "height": height,
                "conf": float(conf)
            })
            ocr_id += 1

        return ocr_table, ocr_elements
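
    # Usage sketch (illustrative; the file path is an assumption):
    #
    #   processor = OCRProcessor(use_gpu=False, languages=['en'])
    #   with open("screenshot.png", "rb") as f:
    #       table, elements = processor.get_ocr_elements(f.read(), mode="easyocr")
    #   OCRProcessor.visualize_ocr_results("screenshot.png", elements, "annotated.png")
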
    @staticmethod
    def visualize_ocr_results(image_path: str, ocr_elements: List[Dict], output_path: str):
        """
        Draw bounding boxes and IDs on the original image.
        """
        try:
            image = Image.open(image_path).convert("RGB")
            draw = ImageDraw.Draw(image)

            try:
                font = ImageFont.truetype("arial.ttf", 16)
            except IOError:
                font = ImageFont.load_default()

            for element in ocr_elements:
                left, top = element["left"], element["top"]
                width, height = element["width"], element["height"]

                color = "green" if element.get("mode") == "easyocr" else "red"

                draw.rectangle([(left, top), (left + width, top + height)], outline=color, width=2)

                text_str = str(element["id"])

                if hasattr(draw, "textbbox"):
                    bbox = draw.textbbox((0, 0), text_str, font=font)
                    text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
                else:
                    text_w, text_h = draw.textsize(text_str, font=font)

                label_bg = [left, top - text_h - 4, left + text_w + 4, top]
                draw.rectangle(label_bg, fill=color)

                draw.text((left + 2, top - text_h - 4), text_str, fill="white", font=font)

            image.save(output_path)
            print(f"Visualization saved to: {output_path}")

        except FileNotFoundError:
            print(f"Error: Image {image_path} not found.")
        except Exception as e:
            print(f"Visualization error: {e}")
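
# A minimal usage sketch (an illustration, not part of the module): it assumes
# get_ocr_elements accepts raw screenshot bytes and an engine name, matching the
# call sites in os_aci.py, and that a local "screen.png" exists.
#
#     processor = OCRProcessor()
#     with open("screen.png", "rb") as f:
#         ocr_table, ocr_elements = processor.get_ocr_elements(f.read(), "easyocr")
#     OCRProcessor.visualize_ocr_results("screen.png", ocr_elements, "screen_annotated.png")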
@ -0,0 +1,575 @@
import logging
import re
from typing import Any, Dict, List, Optional, Tuple

from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.utils.common_utils import call_llm_safe
from mm_agents.os_symphony.agents.coder_agent import CoderAgent
from mm_agents.os_symphony.agents.grounder_agent import GrounderAgent
from mm_agents.os_symphony.agents.searcher_agent import SearcherAgent
from mm_agents.os_symphony.agents.ocr import OCRProcessor


logger = logging.getLogger("desktopenv.agent")


# Agent action decorator
def agent_action(func):
    func.is_agent_action = True
    return func


# GrounderAgent primitives are parameterized by description; coordinate generation uses a pretrained grounding model
class OSACI:
    def __init__(
        self,
        env,
        search_env,
        platform: str,
        client_password: str,
        engine_params_for_ocr: Dict,
        engine_params_for_grounder: Dict,
        engine_params_for_coder: Dict,
        engine_params_for_searcher: Dict,
        screen_width: int = 1920,
        screen_height: int = 1080
    ):
        self.env = env
        self.platform = platform
        self.client_password = client_password

        self.result_dir = ""

        self.grounder_agent = GrounderAgent(engine_params=engine_params_for_grounder, screen_width=screen_width, screen_height=screen_height)

        # Configure text grounding agent
        self.ocr_processor = OCRProcessor()
        self.text_span_agent = LMMAgent(
            engine_params=engine_params_for_ocr,
            system_prompt=PROCEDURAL_MEMORY.PHRASE_TO_WORD_COORDS_PROMPT,
        )

        # Configure code agent
        self.coder_agent = CoderAgent(
            engine_params=engine_params_for_coder,
            platform=self.platform,
            client_password=client_password
        )

        # Configure search agent
        self.searcher_agent = SearcherAgent.create(
            engine_params=engine_params_for_searcher,
            search_env=search_env,
            grounder_agent=self.grounder_agent,
            platform=self.platform,
            client_password=self.client_password
        )

        # Store task instruction for code agent
        self.current_task_instruction = None
        self.last_code_agent_result = None
        self.last_search_agent_result = None
        self.notes: List[str] = []
        # Tutorials are global information rather than per-step context; they are
        # accumulated here and injected into the worker's system prompt.
        self.tutorials = []

    def assign_screenshot(self, obs):
        self.obs = obs

    # Given the state and the worker's text phrase, generate the coords of the first/last word in the phrase
    def generate_text_coords(
        self, phrase: str, obs: Dict, alignment: str = ""
    ) -> List[int]:
        screenshot, global_offset_x, global_offset_y = obs["screenshot"], 0, 0

        ocr_table, ocr_elements = self.ocr_processor.get_ocr_elements(screenshot, "easyocr")

        alignment_prompt = ""
        if alignment == "start":
            alignment_prompt = "**Important**: Output the word id of the FIRST word in the provided phrase.\n"
        elif alignment == "end":
            alignment_prompt = "**Important**: Output the word id of the LAST word in the provided phrase.\n"

        # Load LLM prompt
        self.text_span_agent.reset()
        self.text_span_agent.add_message(
            alignment_prompt + "Phrase: " + phrase + "\n" + ocr_table, role="user"
        )
        self.text_span_agent.add_message(
            "Screenshot:\n", image_content=screenshot, role="user"
        )

        # Obtain the target element
        response = call_llm_safe(self.text_span_agent)
        print("TEXT SPAN AGENT RESPONSE:", response)
        numericals = re.findall(r"\d+", response)
        if len(numericals) > 0:
            text_id = int(numericals[-1])
        else:
            text_id = 0
        elem = ocr_elements[text_id]

        # Compute the element coordinates
        # Note: the 0.15 * elem["height"] offset nudges the point past the trailing
        # edge so the last character is selected more precisely.
        if alignment == "start":
            coords = [elem["left"], elem["top"] + (elem["height"] // 2)]
        elif alignment == "end":
            coords = [elem["left"] + elem["width"] + 0.15 * elem["height"], elem["top"] + (elem["height"] // 2)]
        else:
            # Fallback so coords is always bound: no alignment given, use the element's center
            coords = [elem["left"] + (elem["width"] // 2), elem["top"] + (elem["height"] // 2)]

        print(f'[OCR] output coordinates: {[coords[0] + global_offset_x, coords[1] + global_offset_y]}')
        return [int(coords[0] + global_offset_x), int(coords[1] + global_offset_y)]
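
    # Illustrative call (hypothetical phrase and variables): with alignment="start"
    # this returns the pixel coordinates of the first word of the phrase, e.g.
    #     x, y = aci.generate_text_coords("the quick brown fox", obs, alignment="start")
    # where `aci` is an OSACI instance and `obs` holds the current screenshot bytes.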
    def set_task_instruction(self, task_instruction: str):
        """Set the current task instruction for the code agent."""
        self.current_task_instruction = task_instruction

    @agent_action
    def click(
        self,
        element_description: str,
        num_clicks: int = 1,
        button_type: str = "left",
        hold_keys: List = []
    ):
        """Click on the element
        Args:
            element_description:str, a detailed description of which element to click on. This description needs to be VERY unambiguous. If the page contains many similar elements, ensure the description uniquely identifies the target element.
            num_clicks:int, number of times to click the element
            button_type:str, which mouse button to press; can be "left", "middle", or "right"
            hold_keys:List, list of keys to hold while clicking
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)

        command = "import pyautogui; "

        for k in hold_keys:
            command += f"pyautogui.keyDown({repr(k)}); "
        command += f"pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); "
        for k in hold_keys:
            command += f"pyautogui.keyUp({repr(k)}); "
        # Return pyautogui code to click on the element

        action = {"function": "click", "args": {"x": x, "y": y, "button": button_type, "clicks": num_clicks}}
        return (command, action)

    @agent_action
    def open(self, app_or_filename: str):
        """Open any application or file with name app_or_filename. Use this action to open applications or files on the desktop; do not open them manually.
        Args:
            app_or_filename:str, the name of the application or filename to open

        **Important**:
            Provide only the name of the application or file. Do not include the full path (e.g., "/home/user/Desktop/my_report.docx"). The function works by searching for the name, not by accessing a file path directly.
        """
        action = {"function": "open", "args": {"name": app_or_filename}}
        if self.platform == "linux":
            return (f"import pyautogui; import time; pyautogui.hotkey('win'); time.sleep(1.0); pyautogui.write({repr(app_or_filename)}); time.sleep(1.0); pyautogui.hotkey('enter'); time.sleep(1.0)", action)
        elif self.platform == "macos":
            return (f"import pyautogui; import time; pyautogui.hotkey('command', 'space', interval=0.5); pyautogui.typewrite({repr(app_or_filename)}); pyautogui.press('enter'); time.sleep(1.0)", action)
        elif self.platform == "windows":
            return (f"import pyautogui; import time; pyautogui.hotkey('win'); time.sleep(0.5); pyautogui.write({repr(app_or_filename)}); time.sleep(1.0); pyautogui.press('enter'); time.sleep(0.5)", action)
        else:
            assert (
                False
            ), f"Unsupported platform: {self.platform}. Supported platforms are: macos, linux, windows."

    def _paste(self, is_terminal):
        if self.platform == 'macos':
            return "pyautogui.hotkey('command', 'v');"

        elif self.platform == 'linux':
            if is_terminal:
                return "pyautogui.hotkey('ctrl', 'shift', 'v');"
            else:
                return "pyautogui.hotkey('ctrl', 'v');"

        elif self.platform == 'windows':
            return "pyautogui.hotkey('ctrl', 'v');"

        return ""

    def _clear_all(self, is_terminal):
        """
        Clear the contents of the current line / input field
        """
        # common GUI apps
        if not is_terminal:
            if self.platform == 'macos':
                # macOS GUI: Command + A -> Backspace
                return "pyautogui.hotkey('command', 'a'); pyautogui.press('backspace');"
            else:
                # Windows/Linux GUI: Ctrl + A -> Backspace
                return "pyautogui.hotkey('ctrl', 'a'); pyautogui.press('backspace');"

        # terminal
        else:
            if self.platform == 'windows':
                return "pyautogui.press('esc');"
            else:
                return "pyautogui.hotkey('ctrl', 'e'); pyautogui.hotkey('ctrl', 'u');"

    def _type(
        self,
        text: str,
        is_terminal: bool
    ):
        """
        Use copy-and-paste to input non-ASCII text (e.g., Chinese); otherwise type normally.
        """
        commands = ""
        has_unicode = any(ord(char) > 127 for char in text)
        if has_unicode and self.platform != "macos":
            commands += (
                "original_clipboard = pyperclip.paste();"
                f"pyperclip.copy({repr(text)});"
                "time.sleep(0.1);"
            )
            commands += self._paste(is_terminal=is_terminal)
            commands += "pyperclip.copy(original_clipboard);"
        else:
            commands += f"pyautogui.write({repr(text)}, interval=0.1);"

        return commands
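
    # Sketch of what _type emits for non-ASCII input on linux (values illustrative):
    #     original_clipboard = pyperclip.paste();pyperclip.copy('你好');time.sleep(0.1);
    #     pyautogui.hotkey('ctrl', 'v');pyperclip.copy(original_clipboard);
    # The clipboard is saved and restored so the paste trick leaves no side effects.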
    @agent_action
    def type(
        self,
        element_description: str,
        text: str = "",
        overwrite: bool = False,
        enter: bool = False,
        is_terminal: bool = False
    ):
        """Type text/unicode into a specific element
        Args:
            element_description: str, a detailed description of which element to enter text in. If provided, the agent will click on this element before typing.
            text:str, the text to type
            overwrite:bool, Default is False; assign it to True if the text should overwrite the whole existing text. Using this argument clears all text in the element.
            enter:bool, Assign it to True if the enter key should be pressed after typing all the text, otherwise assign it to False.
            is_terminal:bool, (MANDATORY) You MUST set this to True whenever the target you will type into is a terminal.
        """
        commands = (
            "import os;"
            "import pyautogui;"
            "import pyperclip;"
            "import subprocess;"
            "import time;"
        )

        if self.platform == "linux":
            # Ensure xclip/xsel are present so pyperclip works, reusing any configured proxy
            commands += (
                "p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
                "p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
                "proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
                f"subprocess.run(f'echo \"{self.client_password}\" | sudo -S {{proxy_prefix}}apt-get install -y xclip xsel', shell=True, check=True);"
            )

        x, y = None, None
        if element_description is not None:
            x, y = self.grounder_agent.generate_coords(element_description, self.obs)
            commands += (
                f"pyautogui.click({x}, {y}, clicks=2);"
                f"time.sleep(1.0);"
                f"pyautogui.click({x}, {y});"
            )

        if overwrite:
            commands += self._clear_all(is_terminal=is_terminal)

        commands += self._type(text=text, is_terminal=is_terminal)

        if enter:
            commands += "pyautogui.press('enter');"

        if element_description is not None:
            action = {"function": "type", "args": {"x": x, "y": y, "text": text}}
        else:
            action = {"function": "type", "args": {"text": text}}
        return (commands, action)
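
    # Illustrative call (hypothetical element description and text):
    #     aci.type("the search box at the top of the page", text="weather berlin",
    #              overwrite=True, enter=True)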
    @agent_action
    def drag_and_drop(
        self, starting_description: str, ending_description: str, hold_keys: List = []
    ):
        """Drag from the starting description to the ending description
        Args:
            starting_description:str, a very detailed description of where to start the drag action. This description should be at least a full sentence.
            ending_description:str, a very detailed description of where to end the drag action. This description should be at least a full sentence.
            hold_keys:List, list of keys to hold while dragging
        """
        x1, y1 = self.grounder_agent.generate_coords(starting_description, self.obs)
        x2, y2 = self.grounder_agent.generate_coords(ending_description, self.obs)

        command = "import pyautogui; "

        command += f"pyautogui.moveTo({x1}, {y1}); "
        # TODO: make the drag duration configurable?
        for k in hold_keys:
            command += f"pyautogui.keyDown({repr(k)}); "
        command += f"pyautogui.dragTo({x2}, {y2}, duration=3., button='left'); pyautogui.mouseUp(); "
        for k in hold_keys:
            command += f"pyautogui.keyUp({repr(k)}); "

        # Return pyautogui code to drag and drop the elements
        action = {"function": "drag", "args": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}}
        return (command, action)

    @agent_action
    def highlight_text_span(
        self,
        starting_phrase: str,
        ending_phrase: str,
        button: str = "left",
        text: Optional[str] = None
    ):
        """Highlight a text span between a provided starting phrase and ending phrase. Use this to highlight words, lines, and paragraphs.
        Args:
            starting_phrase: str, the sequence of words that marks the beginning of the text span. Provide a unique sequence of 5 to 10 words.
            ending_phrase: str, the sequence of words that marks the end of the text span. Provide a unique sequence of 5 to 10 words.
            button:str, the button to use to highlight the text span. Defaults to "left". Can be "left", "right", or "middle".
            text: str | None, the text to overwrite the highlighted span with. Providing text here ensures the replacement happens immediately after selection, preventing focus loss.
        """
        x1, y1 = self.generate_text_coords(
            starting_phrase, self.obs, alignment="start"
        )
        x2, y2 = self.generate_text_coords(
            ending_phrase, self.obs, alignment="end"
        )

        command = "import os; import pyautogui; import time; import subprocess; import pyperclip;"
        command += f"pyautogui.moveTo({x1}, {y1}); "
        # Click in advance to focus the text box.
        command += (
            f"pyautogui.click({x1}, {y1}, clicks=2);"
            f"time.sleep(1.0); pyautogui.click({x1}, {y1}); time.sleep(1.0);"
        )
        command += f"pyautogui.dragTo({x2}, {y2}, duration=5., button='{button}'); time.sleep(0.5); pyautogui.mouseUp(); "

        if text:
            if self.platform == "linux":
                # Ensure xclip/xsel are present so pyperclip works, reusing any configured
                # proxy (replaces a previously hardcoded password and proxy address with
                # the same pattern used by type() above)
                command += (
                    "p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
                    "p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
                    "proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
                    f"subprocess.run(f'echo \"{self.client_password}\" | sudo -S {{proxy_prefix}}apt-get install -y xclip xsel', shell=True, check=True);"
                )

            command += (
                "original_clipboard = pyperclip.paste();"
                f"pyperclip.copy({repr(text)});"
            )
            command += self._paste(is_terminal=False)
            command += "pyperclip.copy(original_clipboard);"

        # Return pyautogui code to select (and optionally replace) the span
        action = {"function": "drag", "args": {"x1": x1, "y1": y1, "x2": x2, "y2": y2}}
        return (command, action)
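
    # Illustrative call (hypothetical phrases): selects from the first word of the
    # starting phrase to the last word of the ending phrase, then pastes the replacement:
    #     aci.highlight_text_span("Dear hiring manager at Acme", "look forward to your reply",
    #                             text="Dear Ms. Smith")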
    @agent_action
    def locate_cursor(
        self,
        phrase: str,
        start_or_end: str = "start",
        text: Optional[str] = None
    ):
        """Click at the beginning or end of a specific text phrase to precisely control cursor positioning. Prefer the "click" action in general situations, and use this action only in text-intensive software such as libreoffice_writer, impress, etc.

        Args:
            phrase: str, the text phrase where you want to position the cursor. Provide a unique sequence of 5 to 10 words. Do NOT use single words unless the total text is extremely short.
            start_or_end: str, whether to click at the "start" (beginning) or "end" (trailing edge) of the identified text phrase. Use "start" to position before the text, "end" to position after it.
            text: str | None, the text to enter immediately after positioning the cursor. Use this parameter instead of a separate 'type' action to ensure precise input.
        """
        x, y = self.generate_text_coords(
            phrase, self.obs, alignment=start_or_end
        )
        command = (
            "import os;"
            "import pyautogui;"
            "import time;"
            "import subprocess;"
            "import pyperclip;"
            f"pyautogui.click({x}, {y}, button='left', clicks=2);"
            "time.sleep(1.0);"
            f"pyautogui.click({x}, {y}, button='left');"
        )
        if text:
            if self.platform == "linux":
                # Ensure xclip/xsel are present so pyperclip works, reusing any configured
                # proxy (same pattern as type(); replaces a hardcoded password/proxy)
                command += (
                    "p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
                    "p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
                    "proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
                    f"subprocess.run(f'echo \"{self.client_password}\" | sudo -S {{proxy_prefix}}apt-get install -y xclip xsel', shell=True, check=True);"
                )

            command += self._type(text=text, is_terminal=False)

        if text:
            action = {"function": "type", "args": {"x": x, "y": y, "text": text}}
        else:
            action = {"function": "click", "args": {"x": x, "y": y, "clicks": 1, "button": "left"}}
        return (command, action)

    @agent_action
    def call_code_agent(self, task: str):
        """Calls the code agent to execute a well-defined, self-contained goal that can be completed with code.

        Args:
            task: str, A specific, self-contained goal that the code agent can work on until completion.

        **🚨 CRITICAL GUIDELINES:**

        **Decompose the Main Objective into Logical Goals:**
        - You **MUST** break down the overall mission into distinct, logical goals or stages.
        - Your role is to define *what* needs to be done for a specific stage. The code agent's role is to figure out *how* to do it with code.
        - Pass only one logical goal at a time. The `task` parameter is **REQUIRED**.

        **Define a Self-Contained, Continuous Goal:**
        - The `task` you provide should be a single, continuous goal. The code agent is capable of handling a multi-step process internally (e.g., opening a file, processing its data, and then saving it) to achieve this one goal.
        - **Crucially, do not pass a task that combines multiple distinct objectives.** For example, instead of passing "Analyze the sales data, AND email the result," you should first pass the self-contained goal: "Analyze the sales data." After that goal is complete, you can proceed with the next logical goal (e.g., emailing the result) in a subsequent step.
        - **If unsure, err on the side of caution.** If a task feels like it has two separate parts, break it down and pass only the first part.
        - Your instruction must describe the desired end-state, NOT the recipe to get there. Do not specify any solution!

        **Goal Purity is Essential:**
        - **NEVER** rephrase, paraphrase, or modify the subtask instruction you have decided on. Pass the exact, original wording of the subtask to prevent instruction drift and hallucination.

        Use this for tasks that can be fully accomplished through code execution, particularly for:
        - Spreadsheet applications: data processing, filtering, sorting, calculations, formulas, data analysis
        - Document editors: text processing, content editing, formatting, document manipulation
        - Code editors: code editing, file processing, text manipulation, configuration
        - Data analysis tools: statistical analysis, data transformation, reporting
        - File management: bulk operations, file processing, content extraction
        - System utilities: configuration, setup, automation
        """
        logger.info("=" * 50)
        logger.info("ACI: Calling Code Agent")
        logger.info("=" * 50)
        task_to_execute = task
        logger.info(f"Executing SUBTASK: {task_to_execute}")

        print("obs keys: ", self.obs.keys())
        screenshot = self.obs.get("screenshot", "") if self.obs else ""
        logger.info(f"Screenshot available: {'Yes' if screenshot else 'No'}")

        logger.info("Executing code agent...")

        result = self.coder_agent.execute(
            task_to_execute, screenshot, self.env.controller
        )

        # Store the result for the worker to access
        self.last_code_agent_result = result

        logger.info("Code agent execution completed")
        logger.info(f"Result - Completion reason: {result['completion_reason']}")
        logger.info(f"Steps executed: {result['steps_executed']}")
        logger.info(f"Summary: {result['summary']}")

        logger.info("=" * 50)
        logger.info("ACI: Code Agent Call Finished")
        logger.info("=" * 50)

        action = {"function": "call_code_agent", "args": {"query": task, "result": result["completion_reason"] == "DONE"}}
        # Return a no-op sleep; the code agent has already acted on the environment
        return ("import time; time.sleep(2.222)", action)
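
    # Illustrative decomposition (hypothetical task): pass one self-contained goal at a time:
    #     aci.call_code_agent("Sort the 'Revenue' column of the open LibreOffice Calc sheet in descending order")
    # and only after it completes, call it again for the next goal (e.g., saving as xlsx).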
    @agent_action
    def scroll(self, element_description: str, clicks: int, shift: bool = False):
        """Scroll the element in the specified direction
        Args:
            element_description:str, a very detailed description of which element to scroll in. This description should be at least a full sentence.
            clicks:int, the number of clicks to scroll; can be positive (up) or negative (down).
            shift:bool, whether to use shift+scroll for horizontal scrolling
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)
        action = {"function": "scroll", "args": {"x": x, "y": y, "clicks": clicks, "shift": shift}}
        if shift:
            return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.hscroll({clicks})", action)
        else:
            return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.vscroll({clicks})", action)

    @agent_action
    def hotkey(self, keys: List):
        """Press a hotkey combination (can press a single key as well)
        Args:
            keys:List, the keys to press in combination in a list format (e.g. ['ctrl', 'c'], ['enter'])
        """
        # add quotes around the keys
        keys = [f"'{key}'" for key in keys]
        keys_string = " ".join(keys)
        action = {"function": "key", "args": {"keys": keys_string}}
        return (f"import pyautogui; pyautogui.hotkey({', '.join(keys)});", action)

    @agent_action
    def hold_and_press(self, hold_keys: List, press_keys: List):
        """Hold a list of keys and press a list of keys
        Args:
            hold_keys:List, list of keys to hold
            press_keys:List, list of keys to press in a sequence
        """
        press_keys_str = "[" + ", ".join([f"'{key}'" for key in press_keys]) + "]"
        command = "import pyautogui; "
        for k in hold_keys:
            command += f"pyautogui.keyDown({repr(k)}); "
        command += f"pyautogui.press({press_keys_str}); "
        for k in hold_keys:
            command += f"pyautogui.keyUp({repr(k)}); "

        hold_keys_string = " ".join(hold_keys)
        press_keys_string = " ".join(press_keys)
        action = {"function": "key", "args": {"keys": hold_keys_string + ";" + press_keys_string}}
        return (command, action)

    @agent_action
    def wait(self, time: float):
        """Wait for a specified amount of time
        Args:
            time:float, the amount of time to wait in seconds
        """
        return (f"import time; time.sleep({time});", {"function": "wait", "args": {}})

    @agent_action
    def done(
        self,
    ):
        """
        End the current task with a success. Use this when you believe the entire task has been fully completed. You must ensure all visual information aligns with the user's true intent.
        """
        return ("DONE", {"function": "done", "args": {}})

    @agent_action
    def fail(self):
        """End the current task with a failure. Use this when you believe the entire task is impossible to complete."""
        return ("FAIL", {"function": "fail", "args": {}})

    @agent_action
    def call_search_agent(
        self,
        query: str,
    ):
        """
        Calls a specialized 'Searcher Agent' to find a detailed, step-by-step tutorial on the internet for a specific GUI action.
        Args:
            query:str, the search phrase or question for the tutorial. The formulation of this query is critical for success and must follow the guidelines below.

        **Query Formulation Guidelines:**

        Your query must be a well-defined question targeting a **single, specific action** within a **specific application**. To get the best results, adhere to these rules:

        1. **Start with "How to":** Your query must begin with the phrase "How to" to frame it as a request for instructions.
        2. **Include the Application Name:** Always specify the name of the software you are working in (e.g., "GIMP", "Google Chrome", "Libreoffice Writer").
        3. **Focus on a Single Intent:** The query should represent one clear goal. Do not combine multiple steps or tasks into one query.
        4. **Be Specific, Not Abstract:** Ask a concrete question. Avoid repeating the user's high-level or abstract instructions.
        5. **Decompose Complex Tasks:** If the user's overall instruction involves multiple actions (e.g., "download a file and then email it"), and you are stuck on one part, search *only for that specific part*.

        **Examples:**

        * **User's Overall Instruction:** "Please help me download my latest bank statement and then send it to my accountant."
        * **Correct Query (if stuck on downloading):** "How to download a bank statement from the Bank of America website?"
        * **Correct Query (if stuck on attaching a file):** "How to attach a file to an email in Gmail?"
        * **Incorrect Query:** "Download my bank statement and email it to my accountant" *(This query is too broad, contains multiple sub-tasks, and does not start with "How to".)*
        """
        logger.info("=" * 50)
        logger.info(f"ACI: Calling Search Agent (query={query})")
        logger.info("=" * 50)
        self.searcher_agent.result_dir = self.result_dir
        result = self.searcher_agent.search(query=query, main_obs=self.obs)
        self.last_search_agent_result = result
        if result["completion_reason"] == "DONE":
            self.tutorials.append(result["final_answer"])
        action = {"function": "call_search_agent", "args": {"query": query, "result": result["completion_reason"] == "DONE"}}
        return ("import time; time.sleep(2.222)", action)

@ -0,0 +1,61 @@
import logging
import platform
from typing import Dict, List, Tuple
from mm_agents.os_symphony.agents.os_aci import OSACI
from mm_agents.os_symphony.agents.searcher_agent import VLMSearcherAgent
from mm_agents.os_symphony.agents.worker import Worker

logger = logging.getLogger("desktopenv.agent")

class OSSymphony:
    def __init__(
        self,
        engine_params_for_orchestrator: Dict,
        engine_params_for_memoryer: Dict,
        os_aci: OSACI,
        platform: str = platform.system().lower(),
        client_password: str = "",
        max_trajectory_length: int = 8,
        enable_reflection: bool = True,
    ):
        """
        Args:
            engine_params_for_orchestrator: Configuration parameters for the orchestrator (worker) agent
            engine_params_for_memoryer: Configuration parameters for the reflection/memory agent
            os_aci: Instance of the ACI class for UI interaction
            platform: Operating system platform (darwin, linux, windows)
            client_password: Password for privileged operations on the client VM
            max_trajectory_length: Maximum number of image turns to keep
            enable_reflection: Creates a reflection agent to assist the worker agent
        """
        self.engine_params_for_orchestrator = engine_params_for_orchestrator
        self.engine_params_for_memoryer = engine_params_for_memoryer
        self.os_aci: OSACI = os_aci
        self.platform = platform
        self.client_password = client_password
        self.max_trajectory_length = max_trajectory_length
        self.enable_reflection = enable_reflection

    def reset(self, result_dir) -> None:
        """Reset agent state and initialize components"""
        # Reset the per-task search counter via the result directory
        self.os_aci.result_dir = result_dir
        self.executor = Worker(
            engine_params_for_orchestrator=self.engine_params_for_orchestrator,
            engine_params_for_memoryer=self.engine_params_for_memoryer,
            os_aci=self.os_aci,
            platform=self.platform,
            client_password=self.client_password,
            max_trajectory_length=self.max_trajectory_length,
            enable_reflection=self.enable_reflection,
        )

    def predict(self, instruction: str, observation: Dict, is_last_step: bool) -> Tuple[Dict, List[str]]:
        executor_info, actions = self.executor.generate_next_action(
            instruction=instruction, obs=observation, is_last_step=is_last_step
        )

        # Copy the executor info so callers get a plain dict
        info = dict(executor_info or {})

        return info, actions
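
# A minimal driving-loop sketch (hypothetical config values; `env`, `task_config`,
# and the engine-param dicts are assumptions for illustration, and env.step is
# assumed to return an OSWorld-style (obs, reward, done, info) tuple):
#
#     agent = OSSymphony(engine_params_for_orchestrator={"engine_type": "openai", "model": "gpt-4o"},
#                        engine_params_for_memoryer={"engine_type": "openai", "model": "gpt-4o"},
#                        os_aci=aci, platform="linux", client_password="password")
#     agent.reset(result_dir="results/task_0")
#     obs = env.reset(task_config=task_config)
#     for step in range(max_steps):
#         info, actions = agent.predict(instruction, obs, is_last_step=(step == max_steps - 1))
#         for action in actions:
#             obs, reward, done, _ = env.step(action)
#         if done:
#             break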
@ -0,0 +1,478 @@
import logging
import urllib.parse
from typing import Any, Dict, List, Optional
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import (
    draw_coordinates,
    call_llm_formatted,
    parse_code_from_string,
    create_pyautogui_code
)
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.agents.grounder_agent import GrounderAgent
import os
import time
import json


logger = logging.getLogger("desktopenv.searcher_agent")


# Agent action decorator
def searcher_agent_action(func):
    func.is_searcher_agent_action = True
    return func

# --- Abstract Base Class and Factory ---
class SearcherAgent:
    def __init__(self, engine_params: Dict, platform: str):
        self.engine_params = engine_params
        self.result_dir = ""
        self.tutorial_or_hint = ""
        self.tutorial_notes = []
        self.max_trajectory_length = 8
        self.platform = platform
        self.budget = engine_params.get("budget", 20)

    @staticmethod
    def create(engine_params: Dict, search_env, grounder_agent: GrounderAgent, platform: str, client_password: str = "password"):
        searcher_type = engine_params.get("type", "vlm")
        if searcher_type == "vlm":
            return VLMSearcherAgent(engine_params=engine_params, search_env=search_env, grounder_agent=grounder_agent, platform=platform, client_password=client_password)
        else:
            raise NotImplementedError

    def _get_search_time(self) -> int:
        """Return the next search index, used to name the result directory."""
        if not self.result_dir:
            return 1
        search_times: list[int] = []
        try:
            if not os.path.exists(self.result_dir):
                return 1
            for item_name in os.listdir(self.result_dir):
                full_path = os.path.join(self.result_dir, item_name)
                if os.path.isdir(full_path) and item_name.startswith("search_"):
                    try:
                        time_val = int(item_name.split('_', 1)[1])
                        search_times.append(time_val)
                    except (ValueError, IndexError):
                        continue
        except Exception:
            return 1
        if not search_times:
            return 1
        return max(search_times) + 1

    def search(self, query: str, obs) -> str:
        """
        Args:
            query: Formatted like "How to xxxx?"; must be a detailed subtask
            obs: Current screenshot
        """
        raise NotImplementedError("Subclasses must implement the 'search' method")

class VLMSearcherAgent(SearcherAgent):
    """
    Start a new, isolated VM, and open Chrome in advance
    """
    def __init__(self, engine_params: Dict, search_env, grounder_agent: GrounderAgent, platform: str, client_password: str):
        SearcherAgent.__init__(self, engine_params=engine_params, platform=platform)

        self.grounder_agent = grounder_agent
        self.client_password = client_password
        self.env = search_env

        self.use_thinking = engine_params.get("model", "") in [
            "claude-opus-4-20250514",
            "claude-sonnet-4-20250514",
            "claude-3-7-sonnet-20250219",
            "claude-sonnet-4-5-20250929",
        ]

        self.engine = engine_params.get("engine", "google")

        # Reuse OSWorld's initialization script to set up Chrome, then directly perform
        # a search with the query; "GOOGLE_SEARCH_URL" below is a placeholder that is
        # substituted with the real search URL in reset().
        self.task_config = {
            "id": "searcher",
            "instruction": "searcher",
            "config": [
                {
                    "type": "launch",
                    "parameters": {
                        "command": [
                            "google-chrome",
                            "--remote-debugging-port=1337"
                        ]
                    }
                },
                {
                    "type": "launch",
                    "parameters": {
                        "command": [
                            "socat",
                            "tcp-listen:9222,fork",
                            "tcp:localhost:1337"
                        ]
                    }
                },
                {
                    "type": "chrome_open_tabs",
                    "parameters": {
                        "urls_to_open": [
                            "GOOGLE_SEARCH_URL"
                        ]
                    }
                },
                {
                    "type": "activate_window",
                    "parameters": {
                        "window_name": "Google Chrome"
                    }
                }
            ],
            "proxy": True
        }
        self.obs = None

    def reset(self, query):
        # A new searcher conversation is created per search() call; the environment is
        # instantiated only on the first call, but it must be reset on every invocation.
        self.tutorial_notes = []
        self.tutorial_or_hint = ""
        self.system_prompt = PROCEDURAL_MEMORY.construct_vlm_searcher_procedural_memory(
            agent_class=type(self)
        ).replace("CURRENT_OS", self.platform).replace("QUERY", query)
        self.searcher_agent = LMMAgent(
            engine_params=self.engine_params,
            system_prompt=self.system_prompt
        )
        self.env.start()
        # Configure the URL and initialize the search environment (google/duckduckgo)
        search_url = "https://www.google.com/search?q=" + urllib.parse.quote_plus(query) if self.engine == "google" else "https://www.duckduckgo.com/?q=" + urllib.parse.quote_plus(query)
        self.task_config["config"][2]["parameters"]["urls_to_open"][0] = search_url

        self.env.reset(task_config=self.task_config)
        print("[Searcher] sleeping...")
        time.sleep(5)

    def flush_messages(self):
        """Flush messages based on the model's context limits.

        This method ensures that the agent's message history does not exceed the maximum trajectory length.

        Side Effects:
            - Trims the searcher agent's messages to fit within the context limits.
        """
        engine_type = self.engine_params.get("engine_type", "")

        # Flush strategy for long-context models: keep all text, only keep the latest images
        if engine_type in ["anthropic", "openai", "gemini"]:
            max_images = self.max_trajectory_length
            for agent in [self.searcher_agent]:
                if agent is None:
                    continue
                # Keep the latest k images; the range stops at index 1 so the first
                # main-agent screenshot is always retained.
                img_count = 0
                for i in range(len(agent.messages) - 1, 1, -1):
                    for j in range(len(agent.messages[i]["content"]) - 1, -1, -1):
                        if "image" in agent.messages[i]["content"][j].get("type", ""):
                            img_count += 1
                            if img_count > max_images:
                                del agent.messages[i]["content"][j]

        # Flush strategy for non-long-context models: drop full turns
        else:
            # generator msgs alternate [user, assistant], so 2 per round
            if len(self.searcher_agent.messages) > 2 * self.max_trajectory_length + 1:
                self.searcher_agent.messages.pop(1)
                self.searcher_agent.messages.pop(1)

    def assign_screenshot(self, obs):
        self.obs = obs

    def search(self, query: str, main_obs):
        # only create the VM when search is called
        self.reset(query=query)
        search_result_dir = os.path.join(self.result_dir, f"search_{self._get_search_time()}")
        os.makedirs(search_result_dir, exist_ok=True)

        obs = self.env._get_obs()  # Get the initial observation
        step_idx = 0
        initial_state_text = (
            "This screenshot shows the current visual context of the main GUI Agent you are assisting. "
            "Use this image to understand the application, the current view, and the overall environment. "
            "Your primary goal is to find a tutorial that is highly relevant and well-aligned with this specific context, "
            "ensuring the instructions you find are applicable to what the main agent is currently seeing."
        )
        self.searcher_agent.add_message(
            text_content=initial_state_text,
            image_content=main_obs["screenshot"],
            role="user"
        )
        execution_history = []
        completion_reason = ""
        final_answer = ""

        while step_idx < self.budget:
            # update the system prompt dynamically
            tutorial_notes_str = ""
            if len(self.tutorial_notes) > 0:
                for i, note in enumerate(self.tutorial_notes, 1):
                    tutorial_notes_str += f"Tutorial Note {i}: {note}\n\n"

            if step_idx == self.budget - 1:
                # eager mode: the last step must decide done / fail
                self.system_prompt = PROCEDURAL_MEMORY.construct_searcher_eager_mode_procedural_memory(
                    agent_class=type(self)
                ).replace("CURRENT_OS", self.platform).replace("QUERY", query)

            system_prompt = self.system_prompt.replace("TUTORIAL_PLACEHOLDER", tutorial_notes_str)
            self.searcher_agent.add_system_prompt(system_prompt=system_prompt)

            # start a new turn
            self.assign_screenshot(obs=obs)
            generator_message = ""

            self.searcher_agent.add_message(
                generator_message, image_content=obs["screenshot"], role="user"
            )
            format_checkers = []

            # predict the next action
            plan = call_llm_formatted(
                self.searcher_agent,
                format_checkers,
                temperature=self.engine_params.get("temperature", 0.1),
                use_thinking=self.use_thinking,
            )

            self.searcher_agent.add_message(plan, role="assistant")
            execution_history.append(plan)
            logger.info("SEARCHER PLAN:\n %s", plan)

            plan_code = parse_code_from_string(plan)
            coords = None  # initialized up front so the drawing step below cannot hit an unbound name
            try:
                assert plan_code, "Plan code should not be empty"
                # exec_code e.g. import pyautogui; pyautogui.click(1, 2);
                exec_code, coords = create_pyautogui_code(self, plan_code, obs)
            except Exception as e:
                logger.error(
                    f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}"
                )
                exec_code = self.wait(
                    1.333
                )  # Skip a turn if the code cannot be evaluated

            self.flush_messages()

            # execute the action
            action = exec_code
            logger.info("Step %d: %s", step_idx + 1, action)

            # Save screenshot and trajectory information
            with open(os.path.join(search_result_dir, f"step_{step_idx + 1}.png"),
                      "wb") as _f:
                _f.write(obs['screenshot'])

            if coords is not None and isinstance(coords, list):
                draw_coordinates(
                    image_bytes=obs['screenshot'],
                    coordinates=coords,
                    save_path=os.path.join(search_result_dir, f"step_{step_idx + 1}_draw.png")
                )

            with open(os.path.join(search_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
                f.write(json.dumps({
                    "query": query,
                    "step_num": step_idx + 1,
                    "action": action,
                    "response": {
                        "plan": plan,
                        "plan_code": plan_code,
                        "coordinates": coords
                    },
                    "screenshot_file": f"step_{step_idx + 1}.png"
                }, ensure_ascii=False))
                f.write("\n")

            with open(os.path.join(search_result_dir, f"traj_{step_idx + 1}.json"), "w", encoding="utf-8") as f:
                json.dump({
                    "query": query,
                    "step_num": step_idx + 1,
                    "action": action,
                    "response": {
                        "plan": plan,
                        "plan_code": plan_code,
                        "coordinates": coords
                    },
                    "screenshot_file": f"step_{step_idx + 1}.png"
                }, f, indent=4, ensure_ascii=False)

            if exec_code in ["DONE", "FAIL"]:
                # terminate the loop
                completion_reason = exec_code
                final_answer = self.tutorial_or_hint
                break
            else:
                obs, _, _, _ = self.env.step(action, 5)

            step_idx += 1

        if completion_reason == "":
            completion_reason = "BUDGET_EXHAUSTED"
            final_answer = "Sorry, no useful tutorial could be found for the GUI task you provided."

        return {
            "query": query,
            "completion_reason": completion_reason,
            "tutorial_notes": self.tutorial_notes,
            "execution_history": execution_history,
            "steps_executed": step_idx,
            "budget": self.budget,
            "final_answer": final_answer,
        }
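
    # Illustrative call (hypothetical query): runs the browsing loop in the isolated
    # search VM and returns the result dict shown above, e.g.
    #     result = searcher.search("How to add a watermark in LibreOffice Writer?", main_obs=obs)
    #     if result["completion_reason"] == "DONE":
    #         tutorial = result["final_answer"]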
    @searcher_agent_action
    def click(
        self,
        element_description: str,
        num_clicks: int = 1,
        button_type: str = "left",
    ):
        """Click on the element
        Args:
            element_description:str, a detailed description of which element to click on. This description should be at least a full sentence.
            num_clicks:int, number of times to click the element
            button_type:str, which mouse button to press; can be "left", "middle", or "right"
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)
        command = "import pyautogui; "
        command += f"pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); "

        # Return pyautogui code to click on the element
        return (command, [x, y])

    @searcher_agent_action
    def type(
        self,
        element_description: Optional[str] = None,
        text: str = "",
        overwrite: bool = True,
        enter: bool = False
    ):
        """Type text/unicode into a specific element
        Args:
            element_description:str, a detailed description of which element to enter text in. This description should be at least a full sentence.
            text:str, the text to type
            overwrite:bool, Default is True; assign it to False if the text should not overwrite the existing text. Using this argument clears all text in the element.
            enter:bool, Assign it to True if the enter key should be pressed after typing the text, otherwise assign it to False.
        """
        commands = (
            "import os;"
            "import pyautogui;"
            "import pyperclip;"
            "import subprocess;"
            "import time;"
            "p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
            "p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
            "proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
            f"subprocess.run(f'echo \"{self.client_password}\" | sudo -S {{proxy_prefix}}apt-get install -y xclip xsel', shell=True, check=True);"
        )

        click_coords = None
        if element_description is not None:
            x, y = self.grounder_agent.generate_coords(element_description, self.obs)
            click_coords = [x, y]

            commands += f"pyautogui.click({x}, {y});"

        if overwrite:
            commands += (
                "pyautogui.hotkey('ctrl', 'a');"
                "pyautogui.press('backspace');"
            )

        # use paste to input the text
        commands += (
            "original_clipboard = pyperclip.paste();"
            f"pyperclip.copy({repr(text)});"
            "pyautogui.hotkey('ctrl', 'v');"
            "pyperclip.copy(original_clipboard);"
        )

        if enter:
            commands += "pyautogui.press('enter');"

        if click_coords is not None:
            return (commands, click_coords)
        else:
            return commands

    @searcher_agent_action
    def scroll(self, element_description: str, clicks: int, shift: bool = False):
        """Scroll the element in the specified direction
        Args:
            element_description:str, a very detailed description of which element to scroll in. This description should be at least a full sentence.
            clicks:int, the number of clicks to scroll; can be positive (up) or negative (down).
            shift:bool, whether to use shift+scroll for horizontal scrolling
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)

        if shift:
            return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.hscroll({clicks})", [x, y])
        else:
            return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.vscroll({clicks})", [x, y])

    @searcher_agent_action
    def hotkey(self, keys: List):
        """Press a hotkey combination (can press a single key as well)
        Args:
            keys: List, the keys to press in combination in a list format (e.g. ['ctrl', 'c'], ['enter'])
        """
        # add quotes around the keys
        keys = [f"'{key}'" for key in keys]
        return f"import pyautogui; pyautogui.hotkey({', '.join(keys)})"

    @searcher_agent_action
    def save_to_tutorial_notes(self, text: str):
        """Save high-quality and useful information to a long-term knowledge bank for reuse during this search task.
        Args:
            text:str, the text to save to the tutorial notes
        """
        self.tutorial_notes.append(text)
        return "WAIT"

    @searcher_agent_action
    def wait(self, time: float):
        """Wait for a specified amount of time
        Args:
            time:float, the amount of time to wait in seconds
        """
        return f"import time; time.sleep({time})"

    @searcher_agent_action
    def done(
        self,
        tutorial: str
    ):
        """End the current task with a success. Use this when you believe the entire task has been fully completed.
        Args:
            tutorial:str, a detailed, step-by-step tutorial compiled from the search results to be passed to the main agent.
        """
        self.tutorial_or_hint = tutorial
        return "DONE"

    @searcher_agent_action
    def fail(
        self,
        hint: str
    ):
        """End the current task with a failure. Use this when you believe the entire task is impossible to complete.
        Args:
            hint:str, a hint or reason explaining why the search failed, or what kind of information was missing.
        """
        self.tutorial_or_hint = hint
        return "FAIL"
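
# A minimal construction sketch (hypothetical engine params; search_env is assumed
# to be a DesktopEnv-like object as used by search() above):
#
#     searcher = SearcherAgent.create(
#         engine_params={"type": "vlm", "engine_type": "openai", "model": "gpt-4o", "budget": 20},
#         search_env=search_env, grounder_agent=grounder, platform="linux",
#     )
#     searcher.result_dir = "results/task_0"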
@ -0,0 +1,340 @@
from functools import partial
import logging
from typing import Dict, List, Tuple

from mm_agents.os_symphony.agents.memoryer_agent import ReflectionMemoryAgent
from mm_agents.os_symphony.agents.os_aci import OSACI
from mm_agents.os_symphony.core.module import BaseModule
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import (
    call_llm_formatted,
    extract_coords_from_action_dict,
    parse_action_from_string,
    parse_code_from_string,
    create_pyautogui_code,
)
from mm_agents.os_symphony.utils.formatters import (
    SINGLE_ACTION_FORMATTER,
    CODE_VALID_FORMATTER,
)


logger = logging.getLogger("desktopenv.agent")

class Worker(BaseModule):
    def __init__(
        self,
        engine_params_for_orchestrator: Dict,
        engine_params_for_memoryer: Dict,
        os_aci: OSACI,
        platform: str,
        client_password: str,
        max_trajectory_length: int = 8,
        enable_reflection: bool = True,
    ):
        """
        Worker receives the main task and generates actions, without the need for hierarchical planning
        Args:
            engine_params_for_orchestrator: Dict
                Parameters for the orchestrator (worker) agent
            engine_params_for_memoryer: Dict
                Parameters for the reflection/memory agent
            os_aci: OSACI
                The grounding agent to use
            platform: str
                OS platform the agent runs on (darwin, linux, windows)
            client_password: str
                Password for privileged operations on the client VM
            max_trajectory_length: int
                The number of image turns to keep
            enable_reflection: bool
                Whether to enable reflection
        """
        super().__init__(platform=platform)
        self.client_password = client_password

        self.temperature = engine_params_for_orchestrator.get("temperature", 0.0)
        self.tool_config = engine_params_for_orchestrator.get("tool_config", "")
        self.use_thinking = engine_params_for_orchestrator.get("model", "") in [
            "claude-opus-4-20250514",
            "claude-sonnet-4-20250514",
            "claude-3-7-sonnet-20250219",
            "claude-sonnet-4-5-20250929",
        ]
        self.engine_params_for_orchestrator = engine_params_for_orchestrator
        self.engine_params_for_memoryer = engine_params_for_memoryer
        self.os_aci: OSACI = os_aci

        # Reserve one slot when the first image is pinned, so the total image count stays within budget
        self.max_trajectory_length = max_trajectory_length if not self.engine_params_for_orchestrator.get("keep_first_image", False) else max_trajectory_length - 1
        self.enable_reflection = enable_reflection
        self.reset()

    def reset(self):
        # set_cell_values only exists on linux; likewise there is no fail option in the other benchmarks
        if self.platform in ["windows", "macos"]:
            skipped_actions = ["set_cell_values", "fail"]
        else:
            skipped_actions = []

        # Hide the code agent action entirely if no env/controller is available
        if not getattr(self.os_aci, "env", None) or not getattr(
            getattr(self.os_aci, "env", None), "controller", None
        ):
            skipped_actions.append("call_code_agent")

        self.orchestrator_sys_prompt = PROCEDURAL_MEMORY.construct_simple_worker_procedural_memory(
            agent_class=type(self.os_aci),
            skipped_actions=skipped_actions,
            tool_config=self.tool_config,
            platform=self.platform
        ).replace("CURRENT_OS", self.platform).replace("CLIENT_PASSWORD", self.client_password)

        # Worker contains the orchestrator and the reflection agent
        self.orchestrator_agent = self._create_agent(
            engine_params=self.engine_params_for_orchestrator,
            system_prompt=self.orchestrator_sys_prompt
        )
        self.memoryer_agent = ReflectionMemoryAgent(self.engine_params_for_memoryer)

        self.instruction = None
        self.turn_count = 0
        self.worker_history = []
        self.coords_history = []

        # For loop detection
        self.action_dict_history = []

    def flush_messages(self):
        """Flush messages based on the model's context limits.

        This method ensures that the agent's message history does not exceed the maximum trajectory length.

        Side Effects:
            - Trims the orchestrator agent's messages to fit within the context limits.
        """
        engine_type = self.engine_params_for_orchestrator.get("engine_type", "")

        # Flush strategy for long-context models: keep all text, only keep the latest images
        if engine_type in ["anthropic", "openai", "gemini", "vllm"]:
            max_images = self.max_trajectory_length
            for agent in [self.orchestrator_agent]:
                if agent is None:
                    continue
                # Keep the latest k images; optionally pin the first image
                img_count = 0
                stop_idx = 1 if self.engine_params_for_orchestrator.get("keep_first_image", False) else -1
                for i in range(len(agent.messages) - 1, stop_idx, -1):
                    for j in range(len(agent.messages[i]["content"]) - 1, -1, -1):
                        if "image" in agent.messages[i]["content"][j].get("type", ""):
                            img_count += 1
                            if img_count > max_images:
                                del agent.messages[i]["content"][j]

        # Flush strategy for non-long-context models: drop full turns
        else:
            # generator msgs alternate [user, assistant], so 2 per round
            if len(self.orchestrator_agent.messages) > 2 * self.max_trajectory_length + 1:
                self.orchestrator_agent.messages.pop(1)
                self.orchestrator_agent.messages.pop(1)
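
    # Worked example (illustrative numbers): with max_trajectory_length = 8 and a
    # 12-turn history, the 4 oldest screenshots are deleted from the message contents
    # while every text part is kept, so the model still sees the full textual trajectory.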
    def generate_next_action(self, instruction: str, obs: Dict, is_last_step: bool) -> Tuple[Dict, List]:
        """
        Predict the next action(s) based on the current observation.
        """
        print("=" * 30, f"Turn {self.turn_count + 1}", "=" * 30)

        print("=" * 10)
        print(instruction)
        print("=" * 10)

        self.os_aci.assign_screenshot(obs)
        self.os_aci.set_task_instruction(instruction)

        generator_message = (
            ""
            if self.turn_count > 0
            else "The initial screen is provided. No action has been taken yet."
        )

        # Load the task into the system prompt
        if is_last_step:
            # Eager mode: must decide done / fail
            prompt_with_instructions = PROCEDURAL_MEMORY.construct_eager_mode_procedural_memory(agent_class=type(self.os_aci)).replace(
                "TASK_DESCRIPTION", instruction
            ).replace(
                "CURRENT_OS", self.platform
            )
            print(f'Eager Mode Started, Instruction: {prompt_with_instructions}')
            self.orchestrator_agent.add_system_prompt(prompt_with_instructions)
            generator_message += "Note: 'EAGER MODE' is enabled. You must determine whether the task is done or failed in this step!!!"
        else:
            tutorials = ""
            for idx, t in enumerate(self.os_aci.tutorials, start=1):
                tutorials += f"### Tutorial {idx}:\n {t}\n"

            prompt_with_instructions = self.orchestrator_sys_prompt.replace(
                "TASK_DESCRIPTION", instruction
            ).replace(
                "TUTORIAL_PLACEHOLDER", tutorials
            )

            self.orchestrator_agent.add_system_prompt(prompt_with_instructions)

        ### Reflection Part
        reflection_info = {}
        if self.enable_reflection:
            # pass the instruction to the memory agent
            self.memoryer_agent.add_instruction(instruction)
            reflection = None
            # Differentiate the operation mode of the last step
            last_code_summary = ""
            mode = "gui"
            if (
                hasattr(self.os_aci, "last_code_agent_result")
                and self.os_aci.last_code_agent_result is not None
            ):
                # If the code agent was called last step, use its execution result as the step behavior.
                code_result = self.os_aci.last_code_agent_result
                mode = "code"
                last_code_summary += f"Subtask Instruction: {code_result['task_instruction']}\nSteps Completed: {code_result['steps_executed']}\nCompletion Reason: {code_result['completion_reason']}\nExec Summary: {code_result['summary']}\n"

            if (
                hasattr(self.os_aci, "last_search_agent_result")
                and self.os_aci.last_search_agent_result is not None
            ):
                mode = "search"
            # retrieve the reflection
            reflection_info = self.memoryer_agent.get_reflection(
                cur_obs=obs,
                # only use the string after "(next action)" in the orchestrator's output
                generator_output=parse_action_from_string(self.worker_history[-1]) if self.turn_count != 0 else "",
                coordinates=self.coords_history[-1] if self.turn_count != 0 else [],
                mode=mode,
                code_exec_summary=last_code_summary,
                action_dict=self.action_dict_history[-1] if self.turn_count != 0 else {}
            )
            reflection = reflection_info['reflection']
            logger.info(f'[Reflection]: {reflection}')
            if reflection:
                generator_message += f"REFLECTION: You MUST use this reflection on the latest action:\n{reflection}\n"
            else:
                generator_message += "You should go on with your plan.\n"
        else:
            generator_message += "You should go on with your plan.\n"

        # Add the code agent result from the previous step if available (from full task or subtask execution)
        if (
            hasattr(self.os_aci, "last_code_agent_result")
            and self.os_aci.last_code_agent_result is not None
        ):
            code_result = self.os_aci.last_code_agent_result
            generator_message += "\nCODE AGENT RESULT:\n"
            generator_message += (
                f"Task/Subtask Instruction: {code_result['task_instruction']}\n"
            )
            generator_message += f"Steps Completed: {code_result['steps_executed']}\n"
            generator_message += f"Max Steps: {code_result['budget']}\n"
            generator_message += (
                f"Completion Reason: {code_result['completion_reason']}\n"
            )
            generator_message += f"Summary: {code_result['summary']}\n"
            generator_message += "\n"
            # Reset the code agent result after adding it to the context
            self.os_aci.last_code_agent_result = None

        if (
            hasattr(self.os_aci, "last_search_agent_result")
            and self.os_aci.last_search_agent_result is not None
        ):
            # Retrieve the result dictionary
            search_result = self.os_aci.last_search_agent_result

            # Add a clear, distinct header for this section in the prompt
            generator_message += "\nSEARCH AGENT RESULT:\n"

            # Add contextual metadata from the search task
            generator_message += f"Search Query: {search_result['query']}\n"
            generator_message += f"Search Completion Reason: {search_result['completion_reason']}\n"
            generator_message += "Search Result: "
            # Add the most important part: the tutorial found by the agent.
            # This is given a prominent sub-header so the LLM knows to pay close attention.
            if search_result["completion_reason"] == "DONE":
                generator_message += 'The search completed; the tutorial it found has already been added to your system prompt.\n'
            elif search_result["completion_reason"] == "FAIL":
                generator_message += f"The search failed; the failure reason or hint is as follows: {search_result['final_answer']}\n"

            # CRITICAL: Reset the search agent result after adding it to the context.
            # This prevents it from being added to the prompt again in the next turn.
            self.os_aci.last_search_agent_result = None

        # Finalize the generator message
        self.orchestrator_agent.add_message(
            generator_message, image_content=obs["screenshot"], role="user", put_text_last=True
        )

        # Generate the plan and next action
        format_checkers = [
            SINGLE_ACTION_FORMATTER,
            partial(CODE_VALID_FORMATTER, self.tool_config),
        ]
        plan = call_llm_formatted(
            self.orchestrator_agent,
            format_checkers,
            temperature=self.engine_params_for_orchestrator.get("temperature", 0.1),
            use_thinking=self.use_thinking,
        )
        self.worker_history.append(plan)
        self.orchestrator_agent.add_message(plan, role="assistant")
        logger.info("PLAN:\n %s", plan)

        # Extract the next action from the plan
        # At this point plan_code looks like e.g. agent.click('xxxxx', 1)
        plan_code = parse_code_from_string(plan)
        action_dict, coordinates = None, None
        try:
            assert plan_code, "Plan code should not be empty"
            # exec_code e.g. import pyautogui; pyautogui.click(1, 2);
            exec_code, action_dict = create_pyautogui_code(self.os_aci, plan_code, obs)
|
||||
coordinates = extract_coords_from_action_dict(action_dict)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}"
|
||||
)
|
||||
exec_code, action_dict = self.os_aci.wait(
|
||||
1.333
|
||||
) # Skip a turn if the code cannot be evaluated
|
||||
|
||||
self.action_dict_history.append(action_dict)
|
||||
|
||||
executor_info = {
|
||||
"refined_instruction": self.instruction,
|
||||
"plan": plan,
|
||||
"plan_code": plan_code,
|
||||
"exec_code": exec_code,
|
||||
"coordinates": coordinates,
|
||||
"reflection": reflection_info,
|
||||
"code_agent_output": (
|
||||
self.os_aci.last_code_agent_result
|
||||
if hasattr(self.os_aci, "last_code_agent_result")
|
||||
and self.os_aci.last_code_agent_result is not None
|
||||
else None
|
||||
),
|
||||
"search_agent_output": (
|
||||
self.os_aci.last_search_agent_result
|
||||
if hasattr(self.os_aci, "last_search_agent_result")
|
||||
and self.os_aci.last_search_agent_result is not None
|
||||
else None
|
||||
)
|
||||
}
|
||||
self.turn_count += 1
|
||||
self.coords_history.append(coordinates)
|
||||
self.flush_messages()
|
||||
return executor_info, [exec_code]
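
    # --- Usage sketch (illustrative, not part of the agent API) ---
    # A minimal driver loop, assuming a hypothetical `env` that returns
    # observations with a "screenshot" key and an `agent` instance of this
    # class; all names below are assumptions for illustration only.
    #
    #     obs = env.reset()
    #     for step in range(max_steps):
    #         info, actions = agent.generate_next_action(
    #             instruction, obs, is_last_step=(step == max_steps - 1)
    #         )
    #         for code in actions:
    #             obs = env.step(code)  # execute the generated pyautogui code
    #         if info["plan_code"] in ("agent.done()", "agent.fail()"):
    #             break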

@ -0,0 +1,480 @@
import copy
import json
import logging
import os
import base64

import backoff
from anthropic import Anthropic
from openai import (
    APIConnectionError,
    APIError,
    AzureOpenAI,
    OpenAI,
    RateLimitError,
)

logger = logging.getLogger("desktopenv.agents.engine")


class LMMEngine:
    pass


class LMMEngineOpenAI(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        model=None,
        rate_limit=-1,
        temperature=None,
        organization=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.organization = organization
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature  # Can pin the temperature to a fixed value (e.g., o3 requires temperature to be 1)

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY"
            )
        organization = self.organization or os.getenv("OPENAI_ORG_ID")

        # H-cluster authentication -- remove before release!
        if self.model.lower().startswith(("ui", "qwen", "scale", "holo")):
            custom_headers = {
                "Authorization": "Basic NWFkMzQxMDBlZTA1NWE0YmFlNjYzNzBhNWU2ODNiYWM6NjA3ZGU4MjQ5NjU3YTNiM2JkMDM2ZGM5NmQ0YzBiMmY="
            }
        else:
            custom_headers = {}
        if not self.llm_client:
            if not self.base_url:
                self.llm_client = OpenAI(
                    api_key=api_key,
                    organization=organization,
                    default_headers=custom_headers,
                )
            else:
                self.llm_client = OpenAI(
                    base_url=self.base_url,
                    api_key=api_key,
                    organization=organization,
                    default_headers=custom_headers,
                )

        payload_size = len(json.dumps(messages)) / 1024 / 1024
        logger.info(f"Payload size: {payload_size:.2f} MB")
        if payload_size > 30:
            logger.warning("Payload size exceeds 30 MB!")

        result = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            # max_completion_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=(
                temperature if self.temperature is None else self.temperature
            ),
            **kwargs,
        )
        usage = result.usage
        response = result.choices[0].message.content
        return (response, usage)
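
# --- Usage sketch (illustrative) ---
# A minimal example of driving an engine directly; the model name and the
# OPENAI_API_KEY environment variable are assumptions for illustration.
#
#     engine = LMMEngineOpenAI(model="gpt-4o")
#     messages = [
#         {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
#         {"role": "user", "content": [{"type": "text", "text": "Say hi."}]},
#     ]
#     text, usage = engine.generate(messages)  # returns (response, usage)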


class LMMEngineAnthropic(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        model=None,
        thinking=False,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.thinking = thinking
        self.api_key = api_key
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("ANTHROPIC_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
            )
        self.llm_client = Anthropic(api_key=api_key)
        # Use the instance temperature if set, otherwise fall back to the argument
        temp = self.temperature if self.temperature is not None else temperature
        if self.thinking:
            full_response = self.llm_client.messages.create(
                system=messages[0]["content"][0]["text"],
                model=self.model,
                messages=messages[1:],
                max_tokens=8192,
                thinking={"type": "enabled", "budget_tokens": 4096},
                **kwargs,
            )
            # content[0] holds the thinking block, content[1] the final answer
            thoughts = full_response.content[0].thinking
            return full_response.content[1].text
        return (
            self.llm_client.messages.create(
                system=messages[0]["content"][0]["text"],
                model=self.model,
                messages=messages[1:],
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temp,
                **kwargs,
            )
            .content[0]
            .text
        )

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    # Compatible with Claude-3.7 Sonnet thinking mode
    def generate_with_thinking(
        self, messages, temperature=0.0, max_new_tokens=None, **kwargs
    ):
        """Generate the next message based on previous messages, keeping the thinking tokens"""
        api_key = self.api_key or os.getenv("ANTHROPIC_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named ANTHROPIC_API_KEY"
            )
        self.llm_client = Anthropic(api_key=api_key)
        full_response = self.llm_client.messages.create(
            system=messages[0]["content"][0]["text"],
            model=self.model,
            messages=messages[1:],
            max_tokens=8192,
            thinking={"type": "enabled", "budget_tokens": 4096},
            **kwargs,
        )

        thoughts = full_response.content[0].thinking
        answer = full_response.content[1].text
        full_response = (
            f"<thoughts>\n{thoughts}\n</thoughts>\n\n<answer>\n{answer}\n</answer>\n"
        )
        return full_response
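
# --- Usage sketch (illustrative) ---
# generate_with_thinking() returns a single tagged string; a caller could
# split it back apart like this (regex and variable names are assumptions):
#
#     import re
#     tagged = engine.generate_with_thinking(messages)
#     thoughts = re.search(r"<thoughts>\n(.*?)\n</thoughts>", tagged, re.S).group(1)
#     answer = re.search(r"<answer>\n(.*?)\n</answer>", tagged, re.S).group(1)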


class LMMEngineGemini(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        model=None,
        rate_limit=-1,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("GEMINI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named GEMINI_API_KEY"
            )
        base_url = self.base_url or os.getenv("GEMINI_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named GEMINI_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temp,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


class LMMEngineOpenRouter(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        model=None,
        rate_limit=-1,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("OPENROUTER_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENROUTER_API_KEY"
            )
        base_url = self.base_url or os.getenv("OPEN_ROUTER_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named OPEN_ROUTER_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temp,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


class LMMEngineAzureOpenAI(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        azure_endpoint=None,
        model=None,
        api_version=None,
        rate_limit=-1,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.api_version = api_version
        self.api_key = api_key
        self.azure_endpoint = azure_endpoint
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.cost = 0.0
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
        if api_key is None:
            raise ValueError(
                "An API Key needs to be provided in either the api_key parameter or as an environment variable named AZURE_OPENAI_API_KEY"
            )
        api_version = self.api_version or os.getenv("OPENAI_API_VERSION")
        if api_version is None:
            raise ValueError(
                "api_version must be provided either as a parameter or as an environment variable named OPENAI_API_VERSION"
            )
        azure_endpoint = self.azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if azure_endpoint is None:
            raise ValueError(
                "An Azure API endpoint needs to be provided in either the azure_endpoint parameter or as an environment variable named AZURE_OPENAI_ENDPOINT"
            )
        if not self.llm_client:
            self.llm_client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        completion = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temp,
            **kwargs,
        )
        total_tokens = completion.usage.total_tokens
        # Rough running-cost estimate (assumes a flat rate of $0.02 per 1K tokens)
        self.cost += 0.02 * ((total_tokens + 500) / 1000)
        return completion.choices[0].message.content


class LMMEnginevLLM(LMMEngine):
    def __init__(
        self,
        base_url=None,
        api_key=None,
        model=None,
        rate_limit=-1,
        temperature=None,
        **kwargs,
    ):
        assert model is not None, "model must be provided"
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None
        self.temperature = temperature

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(
        self,
        messages,
        temperature=0.0,
        top_p=0.8,
        repetition_penalty=1.05,
        max_new_tokens=4096,
        **kwargs,
    ):
        api_key = self.api_key or os.getenv("vLLM_API_KEY")
        if api_key is None:
            raise ValueError(
                "A vLLM API key needs to be provided in either the api_key parameter or as an environment variable named vLLM_API_KEY"
            )
        base_url = self.base_url or os.getenv("vLLM_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "An endpoint URL needs to be provided in either the endpoint_url parameter or as an environment variable named vLLM_ENDPOINT_URL"
            )
        if not self.llm_client:
            # Cluster access additionally requires an HTTP Basic auth header
            USERNAME = "5ad34100ee055a4bae66370a5e683bac"
            PASSWORD = "607de8249657a3b3bd036dc96d4c0b2f"
            auth_string = f"{USERNAME}:{PASSWORD}".encode("utf-8")
            basic_auth_encoded = base64.b64encode(auth_string).decode("utf-8")
            basic_auth_header = f"Basic {basic_auth_encoded}"
            self.llm_client = OpenAI(
                base_url=base_url,
                api_key=api_key,
                default_headers={"Authorization": basic_auth_header},
            )
        # Use self.temperature if set, otherwise use the temperature argument
        temp = self.temperature if self.temperature is not None else temperature
        completion = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_new_tokens if max_new_tokens else 4096,
            temperature=temp,
            top_p=top_p,
            extra_body={"repetition_penalty": repetition_penalty},
        )

        usage = completion.usage
        response = completion.choices[0].message.content
        return (response, usage)


class LMMEngineHuggingFace(LMMEngine):
    def __init__(self, base_url=None, api_key=None, rate_limit=-1, **kwargs):
        self.base_url = base_url
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("HF_TOKEN")
        if api_key is None:
            raise ValueError(
                "A HuggingFace token needs to be provided in either the api_key parameter or as an environment variable named HF_TOKEN"
            )
        base_url = self.base_url or os.getenv("HF_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "HuggingFace endpoint must be provided as the base_url parameter or as an environment variable named HF_ENDPOINT_URL."
            )
        if not self.llm_client:
            self.llm_client = OpenAI(base_url=base_url, api_key=api_key)
        return (
            self.llm_client.chat.completions.create(
                model="tgi",
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temperature,
                **kwargs,
            )
            .choices[0]
            .message.content
        )


class LMMEngineParasail(LMMEngine):
    def __init__(
        self, base_url=None, api_key=None, model=None, rate_limit=-1, **kwargs
    ):
        assert model is not None, "Parasail model id must be provided"
        self.base_url = base_url
        self.model = model
        self.api_key = api_key
        self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit
        self.llm_client = None

    @backoff.on_exception(
        backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60
    )
    def generate(self, messages, temperature=0.0, max_new_tokens=None, **kwargs):
        api_key = self.api_key or os.getenv("PARASAIL_API_KEY")
        if api_key is None:
            raise ValueError(
                "A Parasail API key needs to be provided in either the api_key parameter or as an environment variable named PARASAIL_API_KEY"
            )
        base_url = self.base_url or os.getenv("PARASAIL_ENDPOINT_URL")
        if base_url is None:
            raise ValueError(
                "Parasail endpoint must be provided as the base_url parameter or as an environment variable named PARASAIL_ENDPOINT_URL"
            )
        if not self.llm_client:
            self.llm_client = OpenAI(
                base_url=base_url if base_url else "https://api.parasail.io/v1",
                api_key=api_key,
            )
        return (
            self.llm_client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_new_tokens if max_new_tokens else 4096,
                temperature=temperature,
                **kwargs,
            )
            .choices[0]
            .message.content
        )

@ -0,0 +1,308 @@
import base64

import numpy as np

from mm_agents.os_symphony.core.engine import (
    LMMEngineAnthropic,
    LMMEngineAzureOpenAI,
    LMMEngineHuggingFace,
    LMMEngineOpenAI,
    LMMEngineOpenRouter,
    LMMEngineParasail,
    LMMEnginevLLM,
    LMMEngineGemini,
)


class LMMAgent:
    def __init__(self, engine_params: dict, system_prompt=None, engine=None):
        if engine is None:
            if engine_params is not None:
                engine_type = engine_params.get("engine_type")
                if engine_type == "openai":
                    self.engine = LMMEngineOpenAI(**engine_params)
                elif engine_type == "anthropic":
                    self.engine = LMMEngineAnthropic(**engine_params)
                elif engine_type == "azure":
                    self.engine = LMMEngineAzureOpenAI(**engine_params)
                elif engine_type == "vllm":
                    self.engine = LMMEnginevLLM(**engine_params)
                elif engine_type == "huggingface":
                    self.engine = LMMEngineHuggingFace(**engine_params)
                elif engine_type == "gemini":
                    self.engine = LMMEngineGemini(**engine_params)
                elif engine_type == "open_router":
                    self.engine = LMMEngineOpenRouter(**engine_params)
                elif engine_type == "parasail":
                    self.engine = LMMEngineParasail(**engine_params)
                else:
                    raise ValueError(f"engine_type '{engine_type}' is not supported")
            else:
                raise ValueError("engine_params must be provided")
        else:
            self.engine = engine

        self.messages = []  # Empty message history
        self.agent_name = engine_params.get("agent_name")
        if system_prompt:
            self.add_system_prompt(system_prompt)
        else:
            self.add_system_prompt("You are a helpful assistant.")

    def encode_image(self, image_content):
        """Base64-encode an image given either a file path (str) or raw bytes."""
        if isinstance(image_content, str):
            with open(image_content, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        else:
            return base64.b64encode(image_content).decode("utf-8")

    def reset(self):
        self.messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_prompt}],
            }
        ]

    def add_system_prompt(self, system_prompt):
        self.system_prompt = system_prompt
        if len(self.messages) > 0:
            self.messages[0] = {
                "role": "system",
                "content": [{"type": "text", "text": self.system_prompt}],
            }
        else:
            self.messages.append(
                {
                    "role": "system",
                    "content": [{"type": "text", "text": self.system_prompt}],
                }
            )

    def remove_message_at(self, index):
        """Remove a message at a given index"""
        if index < len(self.messages):
            self.messages.pop(index)

    def replace_message_at(
        self, index, text_content, image_content=None, image_detail="high"
    ):
        """Replace a message at a given index"""
        if index < len(self.messages):
            self.messages[index] = {
                "role": self.messages[index]["role"],
                "content": [{"type": "text", "text": text_content}],
            }
            if image_content:
                base64_image = self.encode_image(image_content)
                self.messages[index]["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64_image}",
                            "detail": image_detail,
                        },
                    }
                )

    def add_message(
        self,
        text_content,
        image_content=None,
        role=None,
        image_detail="high",
        put_text_last=True,
    ):
        """Add a new message to the list of messages"""

        # API-style inference for OpenAI and AzureOpenAI
        if isinstance(
            self.engine,
            (
                LMMEngineOpenAI,
                LMMEngineAzureOpenAI,
                LMMEngineHuggingFace,
                LMMEngineGemini,
                LMMEngineOpenRouter,
                LMMEngineParasail,
            ),
        ):
            # Infer the role from the previous message
            if role != "user":
                if self.messages[-1]["role"] == "system":
                    role = "user"
                elif self.messages[-1]["role"] == "user":
                    role = "assistant"
                elif self.messages[-1]["role"] == "assistant":
                    role = "user"

            message = {
                "role": role,
                "content": [{"type": "text", "text": text_content}],
            }

            if isinstance(image_content, np.ndarray) or image_content:
                # image_content may be a list of images or a single image
                if isinstance(image_content, list):
                    # Loop through each image in the list
                    for image in image_content:
                        base64_image = self.encode_image(image)
                        message["content"].append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{base64_image}",
                                    "detail": image_detail,
                                },
                            }
                        )
                else:
                    # Handle a single image directly
                    base64_image = self.encode_image(image_content)
                    message["content"].append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}",
                                "detail": image_detail,
                            },
                        }
                    )

            # Rotate the text to be the last content item if desired
            if put_text_last:
                text_content = message["content"].pop(0)
                message["content"].append(text_content)

            self.messages.append(message)

        # API-style inference for Anthropic
        elif isinstance(self.engine, LMMEngineAnthropic):
            # Infer the role from the previous message
            if role != "user":
                if self.messages[-1]["role"] == "system":
                    role = "user"
                elif self.messages[-1]["role"] == "user":
                    role = "assistant"
                elif self.messages[-1]["role"] == "assistant":
                    role = "user"

            message = {
                "role": role,
                "content": [{"type": "text", "text": text_content}],
            }

            if image_content:
                # image_content may be a list of images or a single image
                if isinstance(image_content, list):
                    for image in image_content:
                        base64_image = self.encode_image(image)
                        message["content"].append(
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/png",
                                    "data": base64_image,
                                },
                            }
                        )
                else:
                    base64_image = self.encode_image(image_content)
                    message["content"].append(
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": base64_image,
                            },
                        }
                    )
            self.messages.append(message)

        # Locally hosted vLLM model inference
        elif isinstance(self.engine, LMMEnginevLLM):
            # Infer the role from the previous message
            if role != "user":
                if self.messages[-1]["role"] == "system":
                    role = "user"
                elif self.messages[-1]["role"] == "user":
                    role = "assistant"
                elif self.messages[-1]["role"] == "assistant":
                    role = "user"

            message = {
                "role": role,
                "content": [{"type": "text", "text": text_content}],
            }

            if image_content:
                # image_content may be a list of images or a single image
                if isinstance(image_content, list):
                    for image in image_content:
                        base64_image = self.encode_image(image)
                        message["content"].append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image;base64,{base64_image}"
                                },
                            }
                        )
                else:
                    base64_image = self.encode_image(image_content)
                    message["content"].append(
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image;base64,{base64_image}"},
                        }
                    )

            if put_text_last:
                text_content = message["content"].pop(0)
                message["content"].append(text_content)
            self.messages.append(message)
        else:
            raise ValueError("engine_type is not supported")

    def get_response(
        self,
        user_message=None,
        messages=None,
        temperature=0.0,
        max_new_tokens=32168,
        use_thinking=False,
        **kwargs,
    ):
        """Generate the next response based on previous messages"""
        if messages is None:
            messages = self.messages
        if user_message:
            messages.append(
                {"role": "user", "content": [{"type": "text", "text": user_message}]}
            )

        # Regular generation
        # if use_thinking:
        #     return self.engine.generate_with_thinking(
        #         messages,
        #         temperature=temperature,
        #         max_new_tokens=max_new_tokens,
        #         **kwargs,
        #     )

        return self.engine.generate(
            messages,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )
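
# --- Usage sketch (illustrative) ---
# Constructing an agent and running one multimodal turn; the engine_type,
# model name, and screenshot path below are assumptions for illustration.
#
#     agent = LMMAgent({"engine_type": "openai", "model": "gpt-4o"})
#     agent.add_system_prompt("You are a GUI automation expert.")
#     agent.add_message("Describe the screen.", image_content="screenshot.png", role="user")
#     response, usage = agent.get_response()  # OpenAI-style engines return (text, usage)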

@ -0,0 +1,17 @@
from typing import Dict, Optional

from mm_agents.os_symphony.core.mllm import LMMAgent


class BaseModule:
    def __init__(self, engine_params: Dict = None, platform: str = "Linux"):
        self.engine_params = engine_params
        self.platform = platform

    def _create_agent(
        self, system_prompt: str = None, engine_params: Optional[Dict] = None
    ) -> LMMAgent:
        """Create a new LMMAgent instance"""
        agent = LMMAgent(engine_params or self.engine_params)
        if system_prompt:
            agent.add_system_prompt(system_prompt)
        return agent
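
# --- Usage sketch (illustrative) ---
# A hypothetical module subclass; the class name, prompt text, and engine
# params are assumptions for illustration only.
#
#     class Reflector(BaseModule):
#         def __init__(self, engine_params):
#             super().__init__(engine_params, platform="Linux")
#             self.agent = self._create_agent("You review agent trajectories.")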

@ -0,0 +1,995 @@
import inspect
import textwrap

import yaml


class PROCEDURAL_MEMORY:

    FORMATTING_FEEDBACK_PROMPT = textwrap.dedent(
        """
    Your previous response was not formatted correctly. You must respond again to replace your previous response. Do not make reference to this message while fixing the response. Please address the following issues below to improve the previous response:
    FORMATTING_FEEDBACK
    """
    )

    @staticmethod
    def construct_eager_mode_procedural_memory(agent_class):
        procedural_memory = textwrap.dedent(
            f"""
        You are an expert in graphical user interfaces. Your budget for this task is now EXHAUSTED.
        This is your FINAL opportunity to act. You must make a definitive judgment.

        You are responsible for executing the task: `TASK_DESCRIPTION`.
        You are working in CURRENT_OS.


        # GUIDELINES

        ## Final Judgment Mode
        1. **Analyze the final state**: Carefully examine the current screenshot and your action history.
        2. **Make a decision**: Determine if the task has been successfully and fully completed.
        3. **Choose one of two actions**: You can ONLY use `agent.done()` or `agent.fail()`. No other actions are permitted.

        ### END OF GUIDELINES

        You are provided with:
        1. The final screenshot of the UI.
        2. The complete history of your previous interactions.
        3. Access to ONLY the following two methods for your final decision:
        class Agent:
        """
        )

        eager_tools = ["done", "fail"]
        for tool_name in eager_tools:
            attr = getattr(agent_class, tool_name, None)

            if not (attr and callable(attr) and hasattr(attr, "is_agent_action")):
                raise AttributeError(
                    f"Eager mode requires the method '{tool_name}' to be defined in '{agent_class.__name__}' and decorated with @agent_action."
                )

            signature = inspect.signature(attr)
            procedural_memory += textwrap.dedent(f"""
    def {tool_name}{signature}:
    '''{attr.__doc__}'''
    """)

        procedural_memory += textwrap.dedent(
            """
        Your response must be formatted like this:

        (Final State Analysis)
        Closely examine the screenshot and your history. Describe whether the final state of the UI confirms that the task `TASK_DESCRIPTION` is complete. Provide your reasoning.

        (Final Judgment)
        State your final decision in natural language. For example: "The task is complete because the file has been saved and closed." or "The task has failed because the required text is not present."

        (Grounded Action)
        Translate your final judgment into ONE of the two available commands.

        **CRITICAL**: You MUST choose one of the following two actions. No other actions are allowed.
        - If the task is fully completed, use `agent.done()`.
        - If the task is not completed or has failed, use `agent.fail()`.

        Example for success:
        ```python
        agent.done()
        ```

        Example for failure:
        ```python
        agent.fail()
        ```
        """
        )

        return procedural_memory.strip()
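
    # --- Sketch (illustrative): the `is_agent_action` contract ---
    # The prompt builders in this class only expose methods that carry an
    # `is_agent_action` attribute. A minimal decorator that satisfies this
    # contract could look like the following (an assumption -- the real
    # decorator lives elsewhere in the codebase):
    #
    #     def agent_action(func):
    #         func.is_agent_action = True
    #         return func
    #
    #     class Agent:
    #         @agent_action
    #         def done(self):
    #             '''Declare the task successfully completed.'''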

    @staticmethod
    def construct_simple_worker_procedural_memory(
        agent_class,
        skipped_actions,
        tool_config,
        platform="linux",
    ):
        procedural_memory = textwrap.dedent(
            f"""\
        You are an expert in graphical user interfaces, web search and Python code. You are responsible for executing the task using the provided actions.
        The TASK DESCRIPTION: `TASK_DESCRIPTION`.
        The OS you are working in: CURRENT_OS.
        # 1. **AGENT WORKFLOW & TOOLS**
        You have at most three tool agents: GUI, Code and Search. You must choose the correct one for the job. You also have a reflection agent that provides useful feedback at each step; follow its feedback and adjust your plan.

        ---
        """
        )

        # Load the tool YAML config
        try:
            with open(tool_config, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)
        except Exception as e:
            raise Exception(f"Tool config could not be loaded, error: {e}")

        has_search_agent = "call_search_agent" in config.get("tools", {}).keys() and config["tools"]["call_search_agent"].get("enabled", False)
        has_code_agent = "call_code_agent" in config.get("tools", {}).keys() and config["tools"]["call_code_agent"].get("enabled", False)

        gui_section = textwrap.dedent(
            f"""
        ## 1.1 GUI Agent
        * **Use for**: All direct UI interactions (clicking, typing, dragging). Use this for simple file operations, visual checks, and tasks requiring specific application features (e.g., charts, pivot tables, print settings, and **other visual elements**).

        """
        )

        search_section = textwrap.dedent(
            f"""
        ## 1.2 Search Agent
        You have access to a search agent that can browse the web to find tutorials.
        * **Use for**: Use the Search Agent **when you are unsure how to perform a GUI-based task**. If you don't know the steps to create a chart, configure a specific setting, or use an unfamiliar feature, use the search agent first.
        * **Usage Strategy**:
        * **CRITICAL**: Call the search agent with a clear, concise "how-to" query. For example: `agent.call_search_agent("How to create a pivot table in LibreOffice Calc?")`.
        * **CRITICAL**: Before searching, evaluate if a tutorial is likely to exist. Well-documented software features always have tutorials. In contrast, tasks tied to a specific website's unique design (e.g., booking a flight, purchasing an item) typically do not have formal, universal tutorials.
        * **Result Interpretation**:
        * **DONE**: The Search Agent finds a step-by-step and **complete** tutorial, often starting from the very beginning. This means the returned guide may contain steps you have already completed. It is **your responsibility** to analyze the tutorial in conjunction with your current screen context to determine the correct step to begin with. **Do not blindly follow the tutorial from step 1.**
        * **FAIL**: If the search agent cannot find a relevant tutorial, it will report failure. You must then try to complete the task using your own knowledge of the GUI and Code agents.
        * **Search Agent Verification**: If the result is DONE, it is highly recommended to follow the tutorial with **GUI operations** in the next several steps to verify the tutorial's validity.

        """
        ) if has_search_agent else ""

        code_section = textwrap.dedent(
            f"""
        ## 1.3 Code Agent
        You have access to a code agent that can execute Python/Bash code in the task environment.
        * **Use for**: Complex, non-UI tasks. This includes large-scale table manipulation, calculations, bulk operations, file content modifications, system operations, or precise data handling tasks (such as filtering or row-matching) involving complex tables where visual alignment is ambiguous or difficult to verify.
        * **Usage Strategy**:
        * **Subtask**: Use `agent.call_code_agent("specific subtask")` for focused data tasks. Please refer to the argument explanation of the function `call_code_agent`.
        * **When To Use**:
        * **Spreadsheet Automation (Strongly Recommended)**: For LibreOffice Calc or Excel tasks, specifically when filling entire rows/columns, performing batch data entry, or running calculations.
        * **Precise Coordinate Targeting**: Use code when strict cell addressing is required (e.g., writing specifically to cell D2). The GUI agent often struggles to visually distinguish between adjacent cells or columns in dense grids. Code actions ensure 100% address accuracy.
        * **When NOT to Use**: NEVER use the code agent for charts, graphs, **pivot tables**, or visual elements. Always use the GUI for those.

        * **Code Agent Verification (MANDATORY)**
        * The code agent works in the background. You CANNOT trust its output report alone. Your job is to verify its work via the GUI.
        * **Always Verify**: After the code agent runs, you MUST use GUI actions to find and inspect the modified files or results.
        * **MANDATORY RESTART**: Files modified by the code agent will not show changes in already-open applications. You **MUST close and reopen the entire application** to verify changes. Reloading the file or page is NOT sufficient.
        * **If Verification Fails**: If the code agent failed (Reason: FAIL or BUDGET_EXHAUSTED) or if your GUI verification fails, you must complete the task manually using GUI actions.
        * **Infeasible Tasks**: Sometimes the code agent will report that the task is impossible to solve. In this case, if you have verified it is correct, just call `agent.fail()`!

        """
        ) if has_code_agent else ""

        reflection_section = textwrap.dedent(
            f"""
        ## 1.4 Reflection Agent (Handling Feedback)
        * **Use for**: The `Reflection` input is your primary source for error correction and guidance. You **MUST** read it first at every step and adjust your plan accordingly.
        * **Usage Strategy**:
        * **If `Off-Track` (GUI Operation Error)**: The reflection indicates your last action failed (e.g., a bad click or type). Your next action should most likely retry that operation with a more specific description (e.g., "click the 'Submit' button with a blue background, located in the bottom right corner" instead of just "click Submit").
        * **If `Off-Track` (Lack of Tutorial)**: The reflection indicates you are stuck, looping, or don't know the steps. You are missing information, so you should call the search agent.
        * **If `Off-Track` (Code Error)**: The code agent failed to finish the task; recover from any errors or side effects caused by the failed code execution and continue the task with GUI operations.
        * **If `Off-Track` (Other Error)**: Carefully read the reflection's explanation and form a new plan to fix the deviation.
        * **If `On-Track`**: Continue with your original plan.
        * **If `Task Completed` / `Task Infeasible`**: You probably need to call `agent.done()` or `agent.fail()`.

        """
        )

        first_section = gui_section + search_section + code_section + reflection_section
        procedural_memory += first_section

        if platform == "linux":
            procedural_memory += textwrap.dedent(
                f"""\
            ---
            # 2. ACTION RULES
            ## 2.1 Core Execution Constraints
            - **Use One Provided Action at a Time**: Execute only one grounded action per turn. Only use the methods provided in the Agent class. Do not invent new methods.
            - **No Interaction with User**: You MUST complete the task individually. There is **NO** additional input from someone else.
            - **Password**: Your sudo password is "CLIENT_PASSWORD".
            - **User**: Your username is "user".
            - **Home**: Your home path is "/home/user".

            ## 2.2 Interaction & Input Guidelines
            - **Guideline for Clicks**:
            - **VISIBILITY CHECK (CRITICAL)**: You must strictly ONLY click on elements that are **clearly visible** in the current screenshot. Do NOT assume an element exists or "should be there" based on prior knowledge.
            - The `element_description` for `agent.click()` must be unambiguous. If similar elements exist, be specific to avoid confusion. Describe the target using its appearance, position, and your purpose.
            - **Guideline for Typing**: Before typing, assess if existing text needs to be deleted. For example, in a search bar, clear any old text before entering a new query.
            - **Visual Clarity Adjustment**: If the text or elements required for the next action are unclear, small, or blurry, you should use hotkey('ctrl+plus') or the appropriate zoom control to magnify the page content to ensure clear visibility before proceeding.

            ## 2.3 Efficiency & Tool Usage
            - **Efficiency is Key**:
            - Prefer `agent.hotkey()` over mouse clicks for shortcuts.
            - Prefer the software's built-in FEATURES (LibreOffice, etc.) over executing a series of complex steps.
            - **Code Usage**: For tasks that are clearly achievable via GUI software, you can take a shortcut and use the Code Agent (e.g., using FFMPEG to convert video to GIF, or filling multiple rows in a table); however, for tasks that cannot be accomplished via GUI, do NOT use Code to forcibly complete the task.
            - You MUST use the Code agent when filling tables (LibreOffice Calc), instead of manually clicking and typing in spreadsheets.
            - You MUST use the Code agent when modifying VS Code settings JSON files or code files such as Python, to minimize syntax errors!
            """
            )

        elif platform == "windows":
            procedural_memory += textwrap.dedent(
                f"""\
            ---
            # 2. ACTION RULES
            ## 2.1 Core Execution Constraints
            - **Use One Provided Action at a Time**: Execute only one grounded action per turn. Only use the methods provided in the Agent class. Do not invent new methods.
            - **No Interaction with User**: You MUST complete the task individually. There is **NO** additional input from someone else.
            - **User**: Your username is "Docker".
            - **Home**: Your home path is "C:\\Users\\Docker"

            ## 2.2 Interaction & Input Guidelines
            - **Guideline for Clicks**:
            - **VISIBILITY CHECK (CRITICAL)**: You must strictly ONLY click on elements that are **clearly visible** in the current screenshot. Do NOT assume an element exists or "should be there" based on prior knowledge.
            - The `element_description` for `agent.click()` must be unambiguous. If similar elements exist, be specific to avoid confusion. Describe the target using its appearance, position, and your purpose.
            - **Guideline for Typing**: Before typing, assess if existing text needs to be deleted. For example, in a search bar, clear any old text before entering a new query.
            - **Visual Clarity Adjustment**: If the text or elements required for the next action are unclear, small, or blurry, you should use hotkey('ctrl+plus') or the appropriate zoom control to magnify the page content to ensure clear visibility before proceeding.

            ## 2.3 Efficiency & Tool Usage
            - **Efficiency is Key**:
            - Prefer `agent.hotkey()` over mouse clicks for shortcuts.
            - Prefer the software's built-in FEATURES (LibreOffice, etc.) over executing a series of complex steps.
            - **Code Usage**: For tasks that are clearly achievable via GUI software, you can take a shortcut and use the Code Agent (e.g., using FFMPEG to convert video to GIF, or filling multiple rows in a table); however, for tasks that cannot be accomplished via GUI, do NOT use Code to forcibly complete the task.
            - You MUST use the Code agent when filling tables (LibreOffice Calc), instead of manually clicking and typing in spreadsheets.
            - You MUST use the Code agent when modifying VS Code settings JSON files or code files such as Python, to minimize syntax errors!
            """
            )
        elif platform == "macos":
            procedural_memory += textwrap.dedent(
                f"""\
            ---
            # 2. ACTION RULES
            ## 2.1 Core Execution Constraints
            - **Use One Provided Action at a Time**: Execute only one grounded action per turn. Only use the methods provided in the Agent class. Do not invent new methods.
            - **No Interaction with User**: You MUST complete the task individually. There is **NO** additional input from someone else.
            - **User**: Your username is "pipiwu".
            - **Password**: Your password is "1234".
            - **Home**: Your home path is "/Users/pipiwu"

            ## 2.2 Interaction & Input Guidelines
            - **Guideline for Clicks**:
            - **VISIBILITY CHECK (CRITICAL)**: You must strictly ONLY click on elements that are **clearly visible** in the current screenshot. Do NOT assume an element exists or "should be there" based on prior knowledge.
            - The `element_description` for `agent.click()` must be unambiguous. If similar elements exist, be specific to avoid confusion. Describe the target using its appearance, position, and your purpose.
            - **Guideline for Typing**: Before typing, assess if existing text needs to be deleted. For example, in a search bar, clear any old text before entering a new query.
            - **Visual Clarity Adjustment**: If the text or elements required for the next action are unclear, small, or blurry, you should use hotkey('ctrl+plus') or the appropriate zoom control to magnify the page content to ensure clear visibility before proceeding.

            ## 2.3 Efficiency & Tool Usage
            - **Efficiency is Key**:
            - Prefer `agent.hotkey()` over mouse clicks for shortcuts.
            - Prefer the software's built-in FEATURES (LibreOffice, etc.) over executing a series of complex steps.
            - You MUST use the Code agent when filling tables (LibreOffice Calc), instead of manually clicking and typing in spreadsheets.
            - **Code Usage**: For tasks that are clearly achievable via GUI software, you can take a shortcut and use the Code Agent (e.g., using FFMPEG to convert video to GIF, or filling multiple rows in a table); however, for tasks that cannot be accomplished via GUI, do NOT use Code to forcibly complete the task.
            """
            )
        else:
            pass

        procedural_memory += textwrap.dedent(
            """
        - **Search Usage**: When the overall execution logic appears flawed, or if you are unable to accomplish the task after multiple attempts (indicating a lack of specific know-how), or if the Reflection Agent reports a "Lack of Tutorial" error, invoke the Search Agent to retrieve detailed online tutorials for further guidance.
        """
        ) if has_search_agent else ""

        procedural_memory += textwrap.dedent(
            """

        ## 2.4 Task Flow & Verification
        - **Task Initial State**: The file you need to operate on is usually already open. Please align the screenshot with the task description. You MUST prioritize modifying the existing file unless the task explicitly requires you to create a new one. Avoid creating new files unnecessarily.
        - **Default Sheet Names**: If creating a new sheet and no name is specified, use default names (e.g., "Sheet1", "Sheet2").
        - **Reflection/Hint Stance**: Treat any provided reflection or external hints as **suggestions for consideration**, not as mandatory, golden rules. Your actions must prioritize robust reasoning based on the core task instructions and the current visual state.
        - **Infeasible**: Use `agent.fail()` if the task is infeasible (e.g., a required file is missing, or the OS/software lacks a feature necessary to complete the task).
        - **Completion**: Only use `agent.done()` when you have **actively verified** via the GUI that the task is 100% complete and correct. **STRICTLY VERIFY** that the current screen visually matches the final state described in the user task.
        - **Error Recovery (Application Missteps)**: If a misoperation occurs in file editing software (e.g., LibreOffice), first attempt recovery using **hotkey('ctrl+z')**. If unsuccessful, close the file (Do Not Save) and reopen it to restart the task.
        - You should proactively save the file after completing file modification tasks and verify that the save was successful.

        ---
        # 3. INPUT & OUTPUT FORMAT
        You are provided with:
        1. A screenshot of the current time step.
        2. The history of your previous interactions with the UI.
        3. A text reflection generated by a Reflection Agent.
        4. Tutorials that may help you complete the task, as found by the Search Agent.
        --- TUTORIALS START ---
        TUTORIAL_PLACEHOLDER
        --- TUTORIALS END ---
        5. Access to the following class and methods to interact with the UI. You MUST select only one action to execute at a time.
        class Agent:
        """
        )

        for tool_name, tool_cfg in config.get("tools", {}).items():
            # Skip the tool if it is explicitly disabled
            if tool_cfg and tool_cfg.get("enabled") is False:
                continue
            if tool_name in skipped_actions:
                continue
            attr = getattr(agent_class, tool_name, None)

            if callable(attr) and hasattr(attr, "is_agent_action"):
                # Use inspect to get the full function signature
                signature = inspect.signature(attr)
                procedural_memory += textwrap.dedent(f"""
    def {tool_name}{signature}:
    '''{attr.__doc__}'''
    """)

        procedural_memory += textwrap.dedent(
            """
        **Your response should be formatted like this**:
        (Previous action verification)
        Carefully analyze, based on the screenshot, whether the previous action was successful. If the previous action was not successful, provide a reason for the failure.

        (Screenshot Analysis)
        Closely examine and describe the current state of the desktop along with the currently open applications.

        (Next Action)
        Based on the current screenshot and the history of your previous interactions with the UI, decide on the next action in natural language to accomplish the given task.

        (Grounded Action)
        Translate the next action into code using the provided API methods. Format the code like this:
        ```python
        agent.click("The menu button at the top right of the window", 1, "left")
        ```
        """
        )

        return procedural_memory.strip()
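
    # --- Usage sketch (illustrative) ---
    # The expected shape of the YAML tool config, inferred from the lookups
    # above (the file name and tool entries are assumptions):
    #
    #     # tools.yaml
    #     # tools:
    #     #   call_search_agent:
    #     #     enabled: true
    #     #   call_code_agent:
    #     #     enabled: false
    #
    #     prompt = PROCEDURAL_MEMORY.construct_simple_worker_procedural_memory(
    #         Agent, skipped_actions=[], tool_config="tools.yaml", platform="linux"
    #     )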
|
||||
|
||||
REWRITE_GUI_INSTRUCTION = textwrap.dedent(
|
||||
"""
|
||||
You are an expert instruction refiner. Your task is to transform verbose, conversational user requests for GUI tasks into clear, direct, and unambiguous **high-level commands** that capture the user's ultimate goal.
|
||||
|
||||
You will be given both the user's text request and a screenshot of the application's initial state. Your primary goal is to synthesize information from both sources to produce a command that states the final objective with as much specificity and context as possible.
|
||||
|
||||
### The Core Distinction: Goal vs. Procedure
|
||||
|
||||
This is the most important rule. The rewritten command must describe the **WHAT** (the user's final objective). It must **NOT** describe the **HOW** (the specific sequence of clicks, menu openings, or keyboard shortcuts to achieve that objective).
|
||||
|
||||
* **User's Goal:** "I want to change the font for all text boxes to 'Liberation Sans Narrow'."
|
||||
* **Correct (Goal-Oriented) Command:** "For the presentation `note-taking-strategies.pptx` in LibreOffice Impress, change the font for all text boxes to 'Liberation Sans Narrow'."
|
||||
* **Incorrect (Procedural) Command:** "Open the Master Slide view, go to Styles, right-click 'Default', select 'Modify', go to the Font tab, choose 'Liberation Sans Narrow', and click OK."
|
||||
|
||||
Your output should always be the **Correct (Goal-Oriented) Command**.
|
||||
|
||||
### Core Principles:
|
||||
|
||||
1. **Focus on the Objective:** The final command must be a statement of the end goal. Eliminate all procedural steps.
|
||||
|
||||
2. **Eliminate Conversational Filler:** Remove all polite expressions, greetings, questions, and personal anecdotes (e.g., "Please," "Could you," "I need to," "Thank you").
|
||||
|
||||
3. **Enrich with Visual Context:** Analyze the screenshot to add critical context to the goal, making it specific and unambiguous.
|
||||
* **Identify the Operating Context:** State the application name (`LibreOffice Impress`), file name (`document.docx`), or website (`github.com`) visible in the screenshot.
|
||||
* **Specify the Target:** If the user says "delete it" and the screenshot shows a file named `report_v2.pdf` is selected, the command should be "Delete the selected file, `report_v2.pdf`."
|
||||
* **Clarify Ambiguous Parameters:** Use the screenshot to translate vague user intent into specific parameters available in the UI. If the user says "make it cheap" and the UI has a "Sort by: Price - Low to High" option, the command is "Sort the results by 'Price: Low to High'."
|
||||
|
||||
4. **Preserve All Essential Details:** Extract and retain every specific detail related to the *goal* itself from the user's text (e.g., file names like `export.jpg`, values like `512 pixels`, font names like `'Liberation Sans Narrow'`).
|
||||
|
||||
5. **Use Imperative (Command) Language:** Start the command with a direct action verb that describes the overall goal (e.g., "Change," "Sort," "Search," "Export").
|
||||
|
||||
6. **Do Not Invent Unjustified Information:** Do not add details or parameters that cannot be inferred from either the user's text or the screenshot.
|
||||
|
||||
### Examples
|
||||
|
||||
**Example 1:**
|
||||
* **Original Request:** "On next Monday, look up a flight from Mumbai to Stockholm."
|
||||
* **Provided Context:** A screenshot of an airline website showing "Round-trip" selected by default.
|
||||
* **Rewritten Command:** "Search for a one-way flight from Mumbai to Stockholm for next Monday."
|
||||
* **Reasoning:** The user's request implies a "one-way" trip. The rewritten command states this as a parameter of the search goal, rather than instructing the AI to "click the one-way button."
|
||||
|
||||
**Example 2:**
|
||||
* **Original Request:** "Help me update my profile."
|
||||
* **Provided Context:** A screenshot of a user's profile page on `github.com`.
|
||||
* **Rewritten Command:** "On `github.com`, update the user profile."
|
||||
* **Reasoning:** The command states the high-level goal and adds the application context from the screenshot. It does not say "Click the 'Edit Profile' button."
|
||||
|
||||
**Example 3:**
|
||||
* **Original Request:** "Find me some cheap headphones."
|
||||
* **Provided Context:** A screenshot of an e-commerce site's search results page with a "Sort by" dropdown.
|
||||
* **Rewritten Command:** "Sort the search results by 'Price: Low to High'."
|
||||
* **Reasoning:** The user's vague intent ("cheap") is translated into a specific, high-level command using the explicit option visible in the UI.
|
||||
|
||||
Now, apply these principles to the user requests and screenshots I provide. Your output should **only** be the final, goal-oriented command.
|
||||
"""
|
||||
)
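A minimal sketch of how this rewriter might be invoked, assuming a chat-style multimodal client and reusing the same `{"role": ..., "content": [...]}` message shape that appears later in this diff; `encode_image` and the message layout are illustrative assumptions, not the repo's actual API:

```python
import base64

def encode_image(image_bytes: bytes) -> str:
    # Hypothetical helper: base64-encode a screenshot for an image payload.
    return base64.b64encode(image_bytes).decode("utf-8")

def build_rewrite_messages(user_request: str, screenshot_png: bytes) -> list:
    # Pair the refiner system prompt with the raw request plus the initial
    # screenshot, so the rewriter can ground the goal in the visible UI state.
    return [
        {"role": "system", "content": [{"type": "text", "text": REWRITE_GUI_INSTRUCTION}]},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": user_request},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encode_image(screenshot_png)}"},
                },
            ],
        },
    ]
```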
|
||||
|
||||
|
||||
##### Reflection Memory Agent #####
|
||||
REFLECTION_SYSTEM_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are an expert "Memory & Reflection Agent." Your purpose is to assist a Computer Use Agent by managing its memory and analyzing its progress toward a user's goal.
|
||||
You will perform three tasks:
|
||||
1. **Extract Knowledge**: Identify and save new, useful information.
|
||||
2. **Reflect & Recall**: Provide trajectory feedback and recall saved knowledge when needed.
|
||||
3. **Evaluate Milestone**: Determine if the most recent action was a significant "milestone."
|
||||
|
||||
**Inputs**:
|
||||
- user_instruction (Text): The high-level, ultimate goal the agent is trying to achieve (e.g., "Find the phone number and address for 'The French Laundry' and put it in 'contacts.xlsx'").
|
||||
- history (List of Objects): A sequence of past steps. Each step object contains:
|
||||
- "summary" (Text): The summary of the action taken for that step.
|
||||
- "screenshot" (Image, Optional): The screenshot *after* the action. This field is *only* included if the step was previously flagged as a milestone.
|
||||
- latest_agent_output: (Text) The output from the Computer Use Agent on the last step, containing the Agent's screen analysis, thought process, and action.
|
||||
- IMPORTANT: This action has been DONE!
|
||||
- latest_screenshot (Image): The screenshot AFTER executing the action described in the **latest_agent_output**.
|
||||
- existing_knowledge (Text, Optional): A string containing all previously saved knowledge, which may be empty.
|
||||
- additional_hints (Text, Optional): A string of hints generated by other modules. **Treat these as strong indicators!**.
|
||||
|
||||
---
|
||||
**Task 1: Knowledge Extraction (Saving New Info)**
|
||||
Your first task is to analyze the latest_screenshot in the context of the user_instruction to see if any new, useful knowledge has appeared.
|
||||
- **Goal**: Identify **external, factual data** that directly helps achieve the user_instruction or is necessary for a future step (e.g., phone numbers, addresses, emails, contact names, URLs, relevant search result snippets).
|
||||
- **Crucial Rules**: What NOT to Extract. You must filter your findings against the following rules before extracting:
|
||||
- **No GUI Observations**: You must differentiate between "External Knowledge" (data you are seeking) and "GUI Observations" (how the software looks). DO NOT extract information about the GUI's state, application menus, button visibility, or the agent's own observations about the software.
|
||||
- **No Duplicates**: Check the existing_knowledge input. DO NOT extract any information that is already present. Your goal is to find new information only.
|
||||
- **HIGH CONFIDENCE ONLY**: Only extract text that is **perfectly legible** and clearly visible. **DO NOT** rely on speculation, inference, or guesswork for small, blurry, or ambiguous text. If you lack complete certainty, you must omit the information.
|
||||
- Action: If you find **new**, relevant, **external** knowledge, you will prepare it for the knowledge output field.
|
||||
- Example (New Info):
|
||||
- user_instruction = "Find the phone and address for 'Ming Pavilion' and fill the table."
|
||||
- existing_knowledge = "Ming Pavilion Address: Level 8, Pacific Place, Supreme Court Road, Central"
|
||||
- latest_screenshot shows "Address: Level 8, Pacific Place, Supreme Court Road, Central; Phone: (852) 2820 8580".
|
||||
- Result: You must extract "Ming Pavilion's Phone: (852) 2820 8580" because it is new.
|
||||
- Example (Duplicate Info):
|
||||
- user_instruction = "Find the email of 'Tao Yu'."
|
||||
- existing_knowledge = "Tao Yu's email: tao.yu.nlp@gmail.com"
|
||||
- latest_screenshot shows "Contact me: tao.yu.nlp [AT] gmail.com".
|
||||
- Result: You must extract nothing because it is NOT new.
|
||||
|
||||
---
|
||||
**Task 2: Reflection & Knowledge Recall**
|
||||
Then, you must generate a reflection on the **entire history and current state (last_agent_output and last_screenshot)** in the context of the user_instruction. Your reflection must be one of the four cases below.
|
||||
|
||||
You must check the cases in this order: 1, 2, 3, then 4.
|
||||
- Case 1. **Off-Track**:
|
||||
- You must first classify the error into one of the following types. Your reflection for this case **must** start with the error type, followed by a specific explanation.
|
||||
- **Format**: `The trajectory is not going according to plan. [Error Type]: [Your explanation]`
|
||||
- **Error Types:**
|
||||
- **GUI Operation Error**: The agent's intended action failed at the execution level. It usually occurs when `additional_hints` contain "Warning: The last GUI operation is unsuccessful".
|
||||
- *Examples*: CUA intended to click a non-existent element (hallucination), clicked at the wrong coordinates for an existing element (grounding issue), or made a typing error (e.g., typing new text without clearing the old content, or making significant typos).
|
||||
- *Tip*: Do NOT check the action `agent.locate_cursor()`, since it must be correct.
|
||||
- **Lack of Tutorial**: The agent's individual GUI operations (clicks, types) are technically correct, but the overall sequence or logic is flawed. The agent seems not to know *how* to accomplish the task.
|
||||
- *Examples*: The agent is clicking randomly, or appears "stuck" and is stubbornly repeating a fixed set of actions *without* making progress (loop detected).
|
||||
- **Code Error**: This triggers *after* `call_code_agent` has been used and the CUA is now in a "verification" step (e.g., has opened the file that the Code Agent was supposed to modify). The `latest_screenshot` reveals that the Code Agent's work is incorrect, incomplete, or does not match the `user_instruction`.
|
||||
- *Examples*: The Code Agent was supposed to add data to a file, but the `latest_screenshot` (showing the opened file) shows the file is still empty. The Code Agent was supposed to perform a calculation, but the GUI verification shows the wrong result.
|
||||
- **Other Error**: The trajectory is off-track for a reason not covered above. Here are some examples:
|
||||
- CUA is deviating from the goal,
|
||||
- CUA is filling in wrong information that conflicts with knowledge,
|
||||
- Screenshot shows an obvious bug or error (pay attention when editing code or json file)...
|
||||
- **Explanation Details**:
|
||||
- Provide a clear explanation for *why* the agent is off-track, referencing `action_history` or `latest_screenshot`. But DON'T give any advice!
|
||||
- **If Loop Detected**: If you find the agent is repeating actions, you **must** state this clearly in the explanation. (e.g., "...agent appears to be in a non-productive loop by repeating the sequence: [action A, action B, action C].")
|
||||
- **Caveat**: Do not mistake necessary, mechanical repetition (like filling 10 rows in a spreadsheet) for a negative loop. A loop is repetitive action *without progress*.
|
||||
- Case 2. **Task Completed**: **You must have high confidence and sufficient evidence that the high-level `user_instruction` has been successfully and completely fulfilled.** You must verify task completion based on the following:
|
||||
- **Visual Alignment Verification**: **Always verify that the `latest_screenshot` visually and explicitly demonstrates the expected final successful state**. If the action summary suggests the goal is achieved but the **expected screen change is not observed**, the task is **NOT** finished.
|
||||
- **"Outcome over Action" Rule**: You must strictly distinguish between **Action Execution** (e.g., clicking 'Submit', typing text) and **State Change** (e.g., a 'Success' banner appears, page redirects, file's format changes).
|
||||
- **CRITICAL**: The agent clicking a correct button is **NOT** evidence of completion. Buttons can fail, be unresponsive, or trigger errors.
|
||||
- **Requirement**: You must observe the **consequence** of the click in the `latest_screenshot`. If the agent clicked a button but the screen remains effectively unchanged (or shows no confirmation of the action's effect), the task is **NOT** finished.
|
||||
- Case 3. **Task Infeasible**: You are **highly certain** the task cannot be completed. In this case, tell the agent to choose "fail" action. This may be due to:
|
||||
- **Factual Errors**: Such as requesting to install a non-existent software version, or the OS/software lacking a feature necessary to complete the task.
|
||||
- **Missing Prerequisites**: Such as attempting to edit a file that does not exist and cannot be found.
|
||||
- Case 4. **On-Track**: (If Cases 1, 2, and 3 do not apply) The CUA is going according to plan. Now, you must perform a sub-check to see if Knowledge Recall is needed.
|
||||
- **Sub-Check (Knowledge Recall)**: Analyze the latest_screenshot and action_history to determine if the agent is now in a position to use previously saved knowledge (from the knowledge input).
|
||||
- **Triggers for Recall**: The agent has opened the target Excel/spreadsheet, a browser with a search bar, or the action_history clearly shows an intent to "write down" or "fill in" the info.
|
||||
- **Format**: "You are on track. [Summary of past actions]. [ (Optional) Content from existing_knowledge input]"
|
||||
|
||||
Rules for Feedback (Cases 1-4):
|
||||
- **Your output MUST be based on one of the case options above**.
|
||||
- NEVER give a specific future plan or action, even if the CUA has told you its intent! Your job is NOT to give suggestions!
|
||||
- Be very certain for Case 4 (DANGEROUS case).
|
||||
- Do **not** classify a task as `Infeasible` if the failure is due to the agent's own confusion, random actions, or lack of knowledge on how to proceed. That is **`Case 1 (Lack of Tutorial)`**. `Infeasible` means the task is *externally* impossible (e.g., the feature does not exist in the software), not that the agent lacks the necessary knowledge.
|
||||
- Pay attention to the latest summary, especially the **screenshot change** part. It may help you analyze the screen.
|
||||
- When the CUA has just used `call_search_agent` or `call_code_agent`, simply consider it on-track.
|
||||
- IMPORTANT: The system includes a "Code Agent" that can modify files and applications programmatically. When you see:
|
||||
- Files with different content than expected.
|
||||
- Applications being closed and reopened.
|
||||
- Documents with fewer lines or modified content.
|
||||
...these are likely LEGITIMATE results of those agents' work, not errors. Do not classify the trajectory as "Off-Track" just because of these programmatic changes.
|
||||
|
||||
---
|
||||
**Task 3: Milestone Evaluation**
|
||||
After formulating your reflection, you must determine if the latest step qualifies as a "milestone."
|
||||
1. **What IS a "Milestone"?** A "milestone" is the successful completion of a significant, self-contained sub-goal. It represents a major step forward.
|
||||
- Examples of Milestones:
|
||||
- Successfully landing on a key page.
|
||||
- Successfully completing a multi-step form (e.g., submitting the flight search, adding an item to the cart).
|
||||
- Successfully downloading a required file.
|
||||
- Successfully arriving at the final piece of information requested (e.g., the screen now shows the weather in London).
|
||||
|
||||
2. **What is NOT a "Milestone"?** Most successful actions are not milestones. They are just small, incremental steps towards a milestone.
|
||||
- Examples of NON-Milestones: Typing a single character or word into a text field; clicking to open a dropdown menu; selecting a single, simple option (e.g., clicking a checkbox, selecting a date on a calendar unless it's the final action of a form); scrolling the page.
|
||||
|
||||
---
|
||||
**Output Format**: Please format your response as follows. In the (Answer) part, you must output a valid JSON object wrapped in ```json and ```.
|
||||
(Thought)
|
||||
[
|
||||
Your detailed reasoning.
|
||||
Screenshot Analysis: I will first examine and analyze the whole screen VERY carefully.
|
||||
Knowledge Extraction: Did the latest screenshot reveal new, relevant info (like a phone number or address) based on the user instruction? Is that info really new? Check the existing knowledge and decide! If so, what is it?
|
||||
Reflection & Recall: I will first understand the history and latest agent's output to know what agent has done. I will then formulate my reflection based on the rules mentioned in **"Task 2" part**. But I should NOT give any advice about next step.
|
||||
Milestone: Was the last action a significant milestone or just a small step?
|
||||
]
|
||||
|
||||
(Answer)
|
||||
```json
|
||||
{
|
||||
"is_milestone": true / false,
|
||||
"reflection": "(Fill in the reflection here)",
|
||||
"knowledge": "(Fill in any newly extracted knowledge from Task 1. If no new knowledge was found in this step, this MUST be an empty string)"
|
||||
}
|
||||
```
|
||||
|
||||
Here's your input:
|
||||
"""
|
||||
)
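Because the prompt pins down the `(Answer)` block format, a small parser can recover the JSON reliably. This is a hedged sketch (the repo's own extraction lives in its utils), mirroring the last-fenced-block convention used by `parse_code_from_string` later in this diff:

```python
import json
import re

def parse_reflection_answer(response: str) -> dict:
    # Take the last ```json ... ``` block, as mandated by the Output Format above.
    matches = re.findall(r"```json\s*(.*?)```", response, re.DOTALL)
    if not matches:
        # No well-formed answer block: return a neutral default so callers can retry.
        return {"is_milestone": False, "reflection": "", "knowledge": ""}
    answer = json.loads(matches[-1])
    # "knowledge" must be an empty string when nothing new was extracted.
    answer.setdefault("knowledge", "")
    return answer
```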
|
||||
|
||||
|
||||
SUMMARIZE_STEP_SYSTEM_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are an expert in computer usage responsible for analyzing what happened after every step taken by a "Computer Use Agent".
|
||||
|
||||
**Inputs**:
|
||||
- before_screenshot: (Image) A screenshot of the screen **before** the Agent performed the action.
|
||||
- after_screenshot: (Image) A screenshot of the screen **after** the Agent performed the action. This is your ONLY source for judging the outcome.
|
||||
- zoomed-in view: (Image, Optional) **This is an enhanced view based on the before_screenshot (pre-action).**
|
||||
* **Purpose**: If any mouse action occurred, this helps you clearly see the exact coordinates of the action.
|
||||
* **CRITICAL WARNING**: This image reflects the state **before** the action. **NEVER** mistake it for the result of the action. Ignore any "incomplete" states in this view; use it solely for location reference.
|
||||
- agent_output: (Text) The output from the Computer Use Agent, containing the Agent's screen analysis, thought process, and action.
|
||||
|
||||
**Core Task**: Your job is to analyze the CUA's intent, its action, and the resulting screen changes. Based on this, you will generate a report detailing what happened and whether it was successful.
|
||||
|
||||
**Reasoning Guidelines:**
|
||||
1. **Analyze Intent vs. Outcome**: First, understand the CUA's thought process and `Grounded Action` from the agent_output. Next, analyze what the agent intended to do for this SINGLE step (be careful not to confuse the intention of a single step with the overall intention). Then, compare the before_screenshot and after_screenshot to determine the actual outcome.
|
||||
2. **Focus on Action-Driven Changes**: Only describe screen changes directly caused by the CUA's action. Ignore irrelevant changes (e.g., the system clock).
|
||||
3. **Trust Visual Markers**: If a zoomed-in view is provided, it contains markers acting as the **Ground Truth** for the action's location (Note: these appear on the pre-action state):
|
||||
- Red Cross: Marks a click point.
|
||||
- Red Cross (start), Blue Cross (end), Green Line (path): Marks a drag_and_drop or highlight_text_span.
|
||||
4. **Verify Success (Strict Criteria)**: **You must apply strict success criteria to check if there is any GUI operation error.** You must examine the `after_screenshot` very carefully.
|
||||
* **Check Single-Step**: Your duty is just to give a feedback based on the LATEST step of CUA. NOT the whole task or process.
|
||||
* **Substantial Expectation**: Always verify the `after_screenshot` visually. The screen state in the `after_screenshot` must match the **expected outcome** of the operation, not just the physical feedback of the action.
|
||||
|
||||
**Output Fields**:
|
||||
1. Summary: You need to output a comprehensive summary of the CUA's step. It must include:
|
||||
- CUA's Thought: What did the agent think?
|
||||
- CUA's Action: What action did it perform?
|
||||
- Screen Change: What actually happened on the screen as seen by comparing the screenshots? What didn't change?
|
||||
2. Evaluation: An assessment of whether the step was successful. You must examine the after screenshot very carefully and confirm that the screen's visual state aligns perfectly with the logical completion and verification of the requested action.
|
||||
|
||||
**Additional Tips**:
|
||||
- Your role is to record history, not to guide the future. Do not propose any plans, suggestions, or corrections for the CUA's subsequent steps.
|
||||
- **Ambiguity Handling**: For actions such as `highlight_text_span`, `locate_cursor`, or operations involving "Select All", "Underline", etc., where visual changes are subtle or not obvious: if you cannot make a clear visual judgment, **default to evaluating them as 'successful'!!**.
|
||||
|
||||
**Output Format**: Please format your response as follows. In the (Answer) part, you must output a valid JSON object wrapped in ```json and ```.
|
||||
|
||||
(Thoughts)
|
||||
[Your detailed reasoning. First, state the CUA's thought process and intended action. Second, analyze the screenshots (using the zoomed-in view to confirm the action **location**, and the after_screenshot to confirm the **result**) to identify all visual changes and what remains the same. Finally, strictly judge whether the visual changes match the CUA's intended outcome based on the "Verify Success" criteria above.]
|
||||
|
||||
(Answer)
|
||||
```json
|
||||
{
|
||||
"summary": "A summary of the CUA's step. See the rules above.",
|
||||
"evaluation": "fail / successful"
|
||||
}
|
||||
```
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
PHRASE_TO_WORD_COORDS_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are an expert in graphical user interfaces. Your task is to process a phrase of text, and identify the most relevant word on the computer screen.
|
||||
You are provided with a phrase, a table with all the text on the screen, and a screenshot of the computer screen. You will identify the single word id that is best associated with the provided phrase.
|
||||
This single word must be displayed on the computer screenshot, and its location on the screen should align with the provided phrase.
|
||||
Each row in the text table provides 2 pieces of data in the following order. 1st is the unique word id. 2nd is the corresponding word.
|
||||
|
||||
To be successful, it is very important to follow all these rules:
|
||||
1. First, think step by step and generate your reasoning about which word id to click on.
|
||||
2. Then, output the unique word id. Remember, the word id is the 1st number in each row of the text table.
|
||||
3. If there are multiple occurrences of the same word, use the surrounding context in the phrase to choose the correct one. Pay very close attention to punctuation and capitalization.
|
||||
|
||||
"""
|
||||
)
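The prompt expects a two-column text table of word ids and words. A minimal sketch of building that table from OCR tokens; the `ocr_tokens` input shape is an assumption for illustration:

```python
def build_word_table(ocr_tokens: list) -> str:
    # Each row: "<word id>\t<word>", matching the two-column layout the prompt describes.
    return "\n".join(f"{idx}\t{word}" for idx, word in enumerate(ocr_tokens))

# build_word_table(["File", "Edit", "View"]) -> "0\tFile\n1\tEdit\n2\tView"
```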
|
||||
|
||||
@staticmethod
|
||||
def construct_coder_procedural_memory(platform: str = "linux", client_password: str = ""):
|
||||
# 1. Define Platform-Specific Context
|
||||
if platform == "linux":
|
||||
PLATFORM_SPECIFIC_CONTEXT = textwrap.dedent(
|
||||
"""\
|
||||
# 2. Environment & Execution
|
||||
* **Platform:** Linux
|
||||
* **User:** "user"
|
||||
* **Home:** "/home/user"
|
||||
* **Shell:** Bash
|
||||
* **Sudo:** Use `echo '{client_password}' | sudo -S [COMMAND]`
|
||||
* **Packages:** Install missing packages as needed.
|
||||
* **Ignored Errors:** Ignore "sudo: /etc/sudoers.d is world writable".
|
||||
* **Note:** Code execution might not be visible on screen immediately. GUI actions (like reopening files) may be needed to see changes.
|
||||
"""
|
||||
)
|
||||
PLATFORM_SPECIFIC_CONTEXT = PLATFORM_SPECIFIC_CONTEXT.format(client_password=client_password)
|
||||
elif platform == "windows":
|
||||
PLATFORM_SPECIFIC_CONTEXT = textwrap.dedent(
|
||||
"""\
|
||||
# 2. Environment & Execution
|
||||
* **Platform:** Windows
|
||||
* **User:** "Docker"
|
||||
* **Home:** "C:\\Users\\Docker"
|
||||
* **Shell:** PowerShell
|
||||
* **Packages:** Install missing packages as needed.
|
||||
* **Path Separators:** Use backslashes `\\` for file paths.
|
||||
* **Note:** Code execution might not be visible on screen immediately. GUI actions (like reopening files) may be needed to see changes.
|
||||
"""
|
||||
)
|
||||
elif platform == "macos":
|
||||
# Placeholder for macOS (Darwin) specific instructions
|
||||
PLATFORM_SPECIFIC_CONTEXT = textwrap.dedent(
|
||||
"""\
|
||||
# 2. Environment & Execution
|
||||
* **Platform:** MacOS(Darwin)
|
||||
* **User:** "pipiwu"
|
||||
* **Password:** "1234"
|
||||
* **Home:** "/Users/pipiwu"
|
||||
* **Shell:** Bash
|
||||
* **Packages:** Install missing packages as needed.
|
||||
* **Note:** Code execution might not be visible on screen immediately. GUI actions (like reopening files) may be needed to see changes.
|
||||
* **Note:** You have sudo privileges. It is recommended to use sudo when performing Bash actions.
|
||||
"""
|
||||
)
else:
# Defensive fallback: fail fast on an unknown platform instead of
# hitting a NameError when PLATFORM_SPECIFIC_CONTEXT is used below.
raise ValueError(f"Unsupported platform: {platform}")
|
||||
|
||||
# 2. Define Common Instructions (Universal)
|
||||
COMMON_INSTRUCTIONS = textwrap.dedent(
|
||||
"""\
|
||||
You are a code execution agent. Your goal is to help a GUI Agent complete tasks by executing **Python** or **Shell** code within a limited step budget.
|
||||
|
||||
# 1. Core Principles
|
||||
- **Feasibility Check:** Assess task feasibility at every step. Do not attempt impossible tasks.
|
||||
- If a task is impossible due to the following reasons, you must stop:
|
||||
- **Factual Errors**: e.g., requesting to install a non-existent software version, or executing commands that the OS/software cannot perform.
|
||||
- **Missing Critical Prerequisites**: e.g., attempting to edit a file that does not exist and cannot be found. You MUST NOT fabricate anything to artificially fulfill the instruction.
|
||||
- In your (Thought) block, **clearly explain WHY** the task is infeasible.
|
||||
- In your (Answer) block, return FAIL.
|
||||
- **Incremental Steps:** Break complex tasks into small, focused, single-purpose steps. Do not write large, multi-step scripts in one block. Code **does not persist** between steps. Each code block you write MUST be a complete, standalone snippet.
|
||||
|
||||
{platform_context}
|
||||
|
||||
# 3. Core Workflow:
|
||||
1. **Find:** Locate the target file. The screenshot context may show which file is currently open and should be modified.
|
||||
2. **Inspect:** **ALWAYS** read and inspect file contents, data types, and formatting *before* modifying.
|
||||
3. **Modify:**
|
||||
* **Priority:** Modify existing open files IN-PLACE (use screenshot context). Only create new files when explicitly required by the task.
|
||||
* **Strategy:** Perform **COMPLETE OVERWRITES**, not appends. For text files, write the full new content. For .docx/.xlsx, replace all paragraphs/sheets with new content.
|
||||
* **Libraries:** Use appropriate libraries (e.g. `python-docx`, `openpyxl` and so on).
|
||||
* **Preservation:** **PRESERVE** all original formatting, headers (column headers and row headers), styles, file names and directory structure unless explicitly told to change them. The document's visual presentation should remain the same.
|
||||
4. **Verify:** After modifying, inspect the file again to confirm the changes were applied correctly. If verification fails, return to Step 3 and retry the modification.
|
||||
5. **Result Visualization**: At the final step before completing the task (the step before you return DONE), you MUST print out the contents of any files you modified. Use appropriate commands to display the final state of modified files:
|
||||
* For text files (Linux/Mac): `cat filename` or `head -n 50 filename`
|
||||
* For text files (Windows): `Get-Content filename -TotalCount 50` or `type filename`
|
||||
* For Python files: `cat filename.py` (Linux/Mac) or `type filename.py` (Windows)
|
||||
* For any other file type: use appropriate viewing commands.
|
||||
6. **Verification Instructions**: When you complete a task that modifies files, you MUST provide clear verification instructions including specific details about what the GUI agent should check:
|
||||
* Which files were modified and their expected final state (number of lines, key data points, etc.).
|
||||
* How to verify the changes are correct.
|
||||
* Whether the task is complete or if additional GUI actions are needed.
|
||||
|
||||
# 4. Response Format:
|
||||
You MUST respond using exactly this format:
|
||||
|
||||
(Thought)
|
||||
Your step-by-step reasoning about what needs to be done and how to approach the current step. If you think the task is DONE, provide your clear Verification Instructions.
|
||||
|
||||
(Answer)
|
||||
Return EXACTLY ONE of the following options. For all the options, you MUST wrap your answer by ```. The Options are:
|
||||
|
||||
For Python code:
|
||||
```python
|
||||
your_python_code_here
|
||||
```
|
||||
|
||||
For Bash/PowerShell commands:
|
||||
```bash
|
||||
your_shell_commands_here
|
||||
```
|
||||
|
||||
For task completion:
|
||||
```
|
||||
DONE
|
||||
```
|
||||
|
||||
For task failure:
|
||||
```
|
||||
FAIL
|
||||
```
|
||||
|
||||
For impossible tasks (factual errors or missing prerequisites):
|
||||
```
|
||||
INFEASIBLE
|
||||
```
|
||||
"""
|
||||
)
|
||||
|
||||
# 3. Combine and Return
|
||||
CODE_AGENT_PROMPT = COMMON_INSTRUCTIONS.format(platform_context=PLATFORM_SPECIFIC_CONTEXT)
|
||||
|
||||
return CODE_AGENT_PROMPT
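A short usage sketch, assuming these static methods hang off the `PROCEDURAL_MEMORY` class imported later in this diff; the password value is illustrative:

```python
# Build the Linux variant; the client password is substituted into the sudo recipe.
prompt = PROCEDURAL_MEMORY.construct_coder_procedural_memory(
    platform="linux", client_password="secret"
)
assert "echo 'secret' | sudo -S" in prompt
```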
|
||||
|
||||
CODE_SUMMARY_AGENT_PROMPT = textwrap.dedent(
|
||||
"""\
|
||||
You are a code execution summarizer. Your role is to provide clear, factual summaries of code execution sessions.
|
||||
|
||||
Key responsibilities:
|
||||
- Summarize the code logic and approach used at each step
|
||||
- Describe the outputs and results produced by code execution
|
||||
- Explain the progression of the solution approach
|
||||
- Use neutral, objective language without making judgments about success or failure
|
||||
- Focus on what was attempted and what resulted
|
||||
- Keep summaries concise and well-structured
|
||||
|
||||
CRITICAL: Include verification instructions for the GUI agent
|
||||
- If files were modified, provide specific verification guidance:
|
||||
* What files were changed and their expected final state
|
||||
* What the GUI agent should look for when verifying
|
||||
* How to verify the changes are correct
|
||||
* Whether the task appears complete or if additional GUI actions are needed
|
||||
- This helps the GUI agent understand what to expect and verify your work properly
|
||||
|
||||
Always maintain a factual, non-judgmental tone.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def construct_vlm_searcher_procedural_memory(
|
||||
agent_class: type
|
||||
) -> str:
|
||||
"""
|
||||
Dynamically constructs the procedural memory (prompt) for the Searcher Agent.
|
||||
"""
|
||||
# The prompt is updated to focus on contextual alignment.
|
||||
procedural_memory = textwrap.dedent(
|
||||
f"""
|
||||
You are a Searcher Agent, a specialized expert in graphical user interfaces. Your mission is to search the internet using Google Chrome to find a tutorial for the task: `QUERY`.
|
||||
You are working in CURRENT_OS. Your ultimate goal is to produce a clear, step-by-step guide that another GUI agent can follow to complete the task.
|
||||
|
||||
# GUIDELINES
|
||||
|
||||
## Your Role and Goal
|
||||
You are a research assistant. You will be given a "how to" query and an initial screenshot showing the current screen of the main agent you are assisting. Your job is to use the Chrome browser to find the best possible tutorial that is well-aligned with the provided visual context.
|
||||
|
||||
## Leveraging Initial Context
|
||||
1. **Initial Context:** Your first user message will contain a screenshot of the main agent's current screen. This is a key piece of information.
|
||||
2. **Contextual Understanding:** Use this screenshot to understand the main agent's environment (e.g., which application is open, what menu is visible).
|
||||
3. **Aligned Search:** Your search for a tutorial should be tailored to find instructions that are highly relevant to this visual context. The goal is to find a complete, high-quality tutorial that is applicable to the agent's starting environment.
|
||||
|
||||
## Constraints
|
||||
1. **Strictly use Google Chrome.** You must perform all your actions within the Chrome browser window.
|
||||
2. **Be Thorough.** Explore different websites and articles to find the most accurate and comprehensive instructions.
|
||||
3. **Be Cautious.** The information you provide will directly guide another agent. If you are not confident in the accuracy of a step, do not include it.
|
||||
4. **Always rely on verified tutorials.** Use only tutorials that you have personally found and reviewed, rather than relying solely on your internal knowledge.
|
||||
|
||||
## Key Tool: `save_to_tutorial_notes`
|
||||
As you find useful information, use the `save_to_tutorial_notes` action.
|
||||
1. **Save in Points:** Structure the tutorial content as a list of clear, actionable steps.
|
||||
2. **Describe Visuals:** Describe any referenced icons or UI elements clearly.
|
||||
3. **Record URLs:** Always save the URL of the source page.
|
||||
|
||||
## Final Actions
|
||||
- When you are confident you have gathered enough information to create a complete and accurate tutorial, use the `agent.done()` action. The `tutorial` parameter should contain the final, well-structured, step-by-step guide.
|
||||
- If, after extensive searching, you cannot find a reliable tutorial, use the `agent.fail()` action. Provide a hint explaining why the search was unsuccessful.
|
||||
|
||||
**You are provided with**:
|
||||
1. A screenshot of the current time step.
|
||||
2. The history of your previous interactions with the UI.
|
||||
3. Tutorial notes you have already found.
|
||||
--- TUTORIAL NOTES START ---
|
||||
TUTORIAL_PLACEHOLDER
|
||||
--- TUTORIAL NOTES END ---
|
||||
4. Access to the following class and methods to interact with the UI. You must only use these actions.
|
||||
class Agent:
|
||||
"""
|
||||
)
|
||||
|
||||
for tool_name in dir(agent_class):
|
||||
if tool_name.startswith("_"):
|
||||
continue
|
||||
|
||||
attr = getattr(agent_class, tool_name)
|
||||
|
||||
if callable(attr) and hasattr(attr, "is_searcher_agent_action"):
|
||||
signature = inspect.signature(attr)
|
||||
docstring = inspect.getdoc(attr) or "No description available."
|
||||
|
||||
procedural_memory += textwrap.dedent(f"""
|
||||
def {tool_name}{signature}:
|
||||
'''{docstring}'''
|
||||
""")
|
||||
|
||||
procedural_memory += textwrap.dedent(
|
||||
"""
|
||||
# RESPONSE FORMAT
|
||||
Your response must follow this exact format:
|
||||
|
||||
(Previous action verification)
|
||||
Carefully analyze the screenshot to verify if your last action was successful. If it failed, explain why.
|
||||
|
||||
(Screenshot Analysis)
|
||||
Examine the current state of the Chrome browser. Describe the current webpage, any open tabs, and visible UI elements relevant to your search.
|
||||
|
||||
(Next Action)
|
||||
In natural language, decide the next logical step to find the tutorial. This could be refining your search query, clicking a link, scrolling down, or saving a note.
|
||||
|
||||
(Grounded Action)
|
||||
Translate your "Next Action" into a single line of Python code using the `agent` methods provided above.
|
||||
```python
|
||||
agent.type(element_description="the search bar at the top of the Google page", text="how to create a pivot table in excel", enter=True)
|
||||
```
|
||||
|
||||
Note for the grounded action:
|
||||
1. Only perform one action at a time.
|
||||
2. You must use only the available methods provided above. Do not invent new methods.
|
||||
3. Return with `agent.done()` immediately after you have compiled the complete tutorial, or `agent.fail()` if it cannot be completed.
|
||||
4. Prefer hotkeys (`agent.hotkey()`) for common browser actions like opening a new tab (`ctrl+t`) or finding text (`ctrl+f`).
|
||||
5. Generate `agent.fail()` if you are exhaustively stuck and believe the task is impossible.
|
||||
6. Generate `agent.done()` when you believe the task is fully complete and you have a high-quality tutorial.
|
||||
"""
|
||||
)
|
||||
|
||||
return procedural_memory
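The injection loop above keeps only callables carrying an `is_searcher_agent_action` attribute. The actual decorator lives elsewhere in the repo; a minimal sketch of what such a flag implies:

```python
def searcher_agent_action(func):
    # Tag the method so construct_vlm_searcher_procedural_memory will inject
    # its signature and docstring into the Searcher Agent's prompt.
    func.is_searcher_agent_action = True
    return func

class Agent:
    @searcher_agent_action
    def save_to_tutorial_notes(self, note: str, url: str):
        '''Save a tutorial step and the URL of its source page.'''
        ...
```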
|
||||
|
||||
@staticmethod
|
||||
def construct_searcher_eager_mode_procedural_memory(
|
||||
agent_class: type
|
||||
):
|
||||
"""
|
||||
Constructs the procedural memory for a Searcher Agent in "Eager Mode" (final attempt).
|
||||
|
||||
This prompt is designed for the scenario where the agent has exhausted its step budget.
|
||||
It restricts the agent to only two possible actions: `done()` or `fail()`, forcing a final,
|
||||
decisive judgment based on the information gathered so far.
|
||||
"""
|
||||
# 1. Set the specific "last chance" introductory text.
|
||||
# This combines the urgency of the planner's eager mode with the Searcher's specific mission.
|
||||
procedural_memory = textwrap.dedent(
|
||||
f"""
|
||||
You are a Searcher Agent, a specialized expert in graphical user interfaces. Your operational budget is now EXHAUSTED.
|
||||
This is your FINAL opportunity to act. You must make a definitive judgment on the task: `QUERY`.
|
||||
You are working in CURRENT_OS.
|
||||
|
||||
# GUIDELINES
|
||||
|
||||
## Final Judgment Mode
|
||||
1. **Analyze Your Notes:** Carefully review all the information you have gathered using `save_to_tutorial_notes`.
|
||||
2. **Make a Final Decision:** Based on your notes, decide if you have enough high-quality information to construct a complete and reliable step-by-step tutorial.
|
||||
3. **Choose One of Two Actions:** You can ONLY use `agent.done()` or `agent.fail()`. No other actions are permitted.
|
||||
|
||||
- **If you choose `agent.done()`:** You MUST provide the complete, well-structured tutorial in the `tutorial` parameter. Compile all your useful notes into a final guide. Do NOT use `done` unless you are highly confident in the tutorial's accuracy and completeness.
|
||||
- **If you choose `agent.fail()`:** Use this if you could not find enough information, or if the information you found is contradictory, unreliable, or incomplete. Provide a reason in the `hint` parameter.
|
||||
|
||||
**You are provided with**:
|
||||
1. A screenshot of the current time step.
|
||||
2. The history of your previous interactions with the UI.
|
||||
3. Tutorial notes you have already found.
|
||||
--- TUTORIAL NOTES START ---
|
||||
TUTORIAL_PLACEHOLDER
|
||||
--- TUTORIAL NOTES END ---
|
||||
4. Access to the following class and methods to interact with the UI. You must only use these two actions.
|
||||
class Agent:
|
||||
"""
|
||||
)
|
||||
|
||||
# 2. Strictly inject only the 'done' and 'fail' methods.
|
||||
# This logic is adapted from the planner's eager mode constructor.
|
||||
eager_tools = ["done", "fail"]
|
||||
for tool_name in eager_tools:
|
||||
attr = getattr(agent_class, tool_name, None)
|
||||
|
||||
# We check for 'is_searcher_agent_action' to be consistent with the SearcherAgent's decorators.
|
||||
if attr and callable(attr) and hasattr(attr, "is_searcher_agent_action"):
|
||||
signature = inspect.signature(attr)
|
||||
docstring = inspect.getdoc(attr) or "No description available."
|
||||
procedural_memory += textwrap.dedent(f"""
|
||||
def {tool_name}{signature}:
|
||||
'''{docstring}'''
|
||||
""")
|
||||
|
||||
# 3. Provide the specific response format for this final decision.
|
||||
procedural_memory += textwrap.dedent(
|
||||
"""
|
||||
# RESPONSE FORMAT
|
||||
Your response must follow this exact format:
|
||||
|
||||
(Final Analysis and Tutorial Compilation)
|
||||
Review your collected notes and the final screenshot. State whether you have sufficient information to create a definitive tutorial. Summarize your reasoning.
|
||||
|
||||
(Final Decision)
|
||||
In natural language, declare your final choice. For example: "The search is successful, and I have compiled a complete tutorial." or "The search has failed because no reliable sources were found for this specific software version."
|
||||
|
||||
(Grounded Action)
|
||||
Translate your final decision into a single line of Python code using the `agent` methods provided above.
|
||||
**Example**:
|
||||
```python
|
||||
agent.done(tutorial="xxxx")
|
||||
```
|
||||
```python
|
||||
agent.fail(hint="xxxx")
|
||||
```
|
||||
**CRITICAL**: You MUST choose one of the following two actions. No other actions are allowed.
|
||||
"""
|
||||
)
|
||||
|
||||
return procedural_memory.strip()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def construct_grounder_procedural_memory(model_name: str):
|
||||
system_prompt, user_message = None, f"Query:REF_EXPR\nOutput only the coordinate of one point in your response.\n"
|
||||
if "scalecua" in model_name.lower():
|
||||
user_message = "REF_EXPR"
|
||||
system_prompt = textwrap.dedent(
|
||||
'''
|
||||
You are an autonomous GUI agent capable of operating on desktops, mobile devices, and web browsers. Your primary function is to analyze screen captures and perform appropriate UI actions to complete assigned tasks.
|
||||
|
||||
## Action Space
|
||||
def click(
|
||||
x: float | None = None,
|
||||
y: float | None = None,
|
||||
clicks: int = 1,
|
||||
button: str = "left",
|
||||
) -> None:
|
||||
"""Clicks on the screen at the specified coordinates. The `x` and `y` parameter specify where the mouse event occurs. If not provided, the current mouse position is used. The `clicks` parameter specifies how many times to click, and the `button` parameter specifies which mouse button to use ('left', 'right', or 'middle')."""
|
||||
pass
|
||||
|
||||
def doubleClick(
|
||||
x: float | None = None,
|
||||
y: float | None = None,
|
||||
button: str = "left",
|
||||
) -> None:
|
||||
"""Performs a double click. This is a wrapper function for click(x, y, 2, 'left')."""
|
||||
pass
|
||||
|
||||
def rightClick(x: float | None = None, y: float | None = None) -> None:
|
||||
"""Performs a right mouse button click. This is a wrapper function for click(x, y, 1, 'right')."""
|
||||
pass
|
||||
|
||||
def moveTo(x: float, y: float) -> None:
|
||||
"""Move the mouse to the specified coordinates."""
|
||||
pass
|
||||
|
||||
def dragTo(
|
||||
x: float | None = None, y: float | None = None, button: str = "left"
|
||||
) -> None:
|
||||
"""Performs a drag-to action with optional `x` and `y` coordinates and button."""
|
||||
pass
|
||||
|
||||
def swipe(
|
||||
from_coord: tuple[float, float] | None = None,
|
||||
to_coord: tuple[float, float] | None = None,
|
||||
direction: str = "up",
|
||||
amount: float = 0.5,
|
||||
) -> None:
|
||||
"""Performs a swipe action on the screen. The `from_coord` and `to_coord` specify the starting and ending coordinates of the swipe. If `to_coord` is not provided, the `direction` and `amount` parameters are used to determine the swipe direction and distance. The `direction` can be 'up', 'down', 'left', or 'right', and the `amount` specifies how far to swipe relative to the screen size (0 to 1)."""
|
||||
pass
|
||||
|
||||
def long_press(x: float, y: float, duration: int = 1) -> None:
|
||||
"""Long press on the screen at the specified coordinates. The `duration` specifies how long to hold the press in seconds."""
|
||||
pass
|
||||
|
||||
## Input Specification
|
||||
- Screenshot of the current screen + task description
|
||||
|
||||
## Output Format
|
||||
<action>
|
||||
[A set of executable action command]
|
||||
</action>
|
||||
|
||||
## Note
|
||||
- Avoid action(s) that would lead to invalid states.
|
||||
- The generated action(s) must exist within the defined action space.
|
||||
- The generated action(s) should be enclosed within <action></action> tags.'''
|
||||
)
|
||||
return system_prompt, user_message
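A hedged sketch of how a caller presumably fills the `REF_EXPR` placeholder before sending the grounding request; the model name is illustrative:

```python
system_prompt, user_message = PROCEDURAL_MEMORY.construct_grounder_procedural_memory(
    model_name="some-grounder"  # hypothetical model name
)
# Substitute the referring expression into the placeholder.
query = user_message.replace("REF_EXPR", "the blue Submit button at the bottom of the form")
# For non-ScaleCUA models, system_prompt is None and the query asks for a single point.
```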
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
description: "The config of all tools"
|
||||
|
||||
tools:
|
||||
click:
|
||||
enabled: true
|
||||
|
||||
type:
|
||||
enabled: true
|
||||
|
||||
scroll:
|
||||
enabled: true
|
||||
|
||||
drag_and_drop:
|
||||
enabled: true
|
||||
|
||||
highlight_text_span:
|
||||
enabled: true
|
||||
|
||||
locate_cursor:
|
||||
enabled: true
|
||||
|
||||
call_code_agent:
|
||||
enabled: true
|
||||
|
||||
call_search_agent:
|
||||
enabled: true
|
||||
|
||||
hotkey:
|
||||
enabled: true
|
||||
|
||||
hold_and_press:
|
||||
enabled: true
|
||||
|
||||
wait:
|
||||
enabled: true
|
||||
|
||||
done:
|
||||
enabled: true
|
||||
|
||||
fail:
|
||||
enabled: true
|
||||
|
||||
open:
|
||||
enabled: true
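A minimal sketch of consuming this config with PyYAML (already a dependency of the formatting-checks module below); the file path is an assumption:

```python
import yaml

with open("tools_config.yaml", "r", encoding="utf-8") as f:  # hypothetical path
    config = yaml.safe_load(f)

# Collect the names of all enabled tools.
enabled_tools = [name for name, opts in config["tools"].items() if opts.get("enabled")]
```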
|
||||
|
||||
|
|
@ -0,0 +1,448 @@
|
|||
import json
|
||||
import re
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Tuple, Dict, List, Union
|
||||
import io
|
||||
import os
|
||||
from PIL import Image, ImageDraw
|
||||
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||
from mm_agents.os_symphony.utils.process_context import get_current_result_dir
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
def create_pyautogui_code(agent, code: str, obs: Dict) -> Tuple[str, dict | None]:
|
||||
"""
|
||||
Attempts to evaluate the code into a pyautogui code snippet with grounded actions using the observation screenshot.
|
||||
|
||||
Args:
|
||||
agent (ACI): The grounding agent to use for evaluation.
|
||||
code (str): The code string to evaluate.
|
||||
obs (Dict): The current observation containing the screenshot.
|
||||
|
||||
Returns:
|
||||
exec_code (str): The pyautogui code to execute the grounded action.
|
||||
coordinate (List): The coordinates of the action, a flattened list such as [x1, y1, x2, y2, ...], since a single action may contain more than one coordinate.
|
||||
Modified by Yang.
|
||||
Raises:
|
||||
Exception: If there is an error in evaluating the code.
|
||||
"""
|
||||
agent.assign_screenshot(obs) # Necessary for grounding
|
||||
response = eval(code)
|
||||
if isinstance(response, tuple):
|
||||
return response
|
||||
elif isinstance(response, str):
|
||||
return response, None
|
||||
else:
|
||||
return "", None
|
||||
|
||||
|
||||
def draw_coordinates(image_bytes: bytes, coordinates: List[Union[int, float]], save_path: str):
|
||||
"""
|
||||
Draw coordinates on the given image and save it to a new file.
|
||||
|
||||
This function receives an image as a byte stream, a list of coordinates in the format [x1, y1, x2, y2, ...],
|
||||
and draws a red 'X' at each (x, y) coordinate point. The resulting image is then saved to the specified path.
|
||||
|
||||
Args:
|
||||
- image_bytes (bytes): The raw byte data of the image (e.g., read from a PNG or JPEG file).
|
||||
- coordinates (List[Union[int, float]]): A flattened list of coordinates, must contain an even number of elements. For example: [x1, y1, x2, y2].
|
||||
- save_path (str): The path where the new image with markings will be saved.
|
||||
"""
|
||||
try:
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
logger.warning(f"draw_coordinates: failed to open image: {e}")
return
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
cross_size = 15
|
||||
cross_color = "red"
|
||||
cross_width = 3
|
||||
|
||||
for i in range(0, len(coordinates) - 1, 2):
|
||||
x, y = coordinates[i], coordinates[i+1]
|
||||
|
||||
line1_start = (x - cross_size, y - cross_size)
|
||||
line1_end = (x + cross_size, y + cross_size)
|
||||
|
||||
line2_start = (x + cross_size, y - cross_size)
|
||||
line2_end = (x - cross_size, y + cross_size)
|
||||
|
||||
draw.line([line1_start, line1_end], fill=cross_color, width=cross_width)
|
||||
draw.line([line2_start, line2_end], fill=cross_color, width=cross_width)
|
||||
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
image.save(save_path)
|
||||
|
||||
|
||||
def parse_action_from_string(string):
|
||||
'''
|
||||
Parse everything from the "(Next Action)" marker onward, including the marker itself. If the marker is not found, return the whole string.
|
||||
'''
|
||||
marker = "(Next Action)"
|
||||
|
||||
start_index = string.find(marker)
|
||||
|
||||
if start_index != -1:
|
||||
return string[start_index:]
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def call_llm_safe(
|
||||
agent, temperature: float = 0.0, use_thinking: bool = False, **kwargs
|
||||
) -> str:
|
||||
|
||||
try:
|
||||
example_result_dir = get_current_result_dir()
|
||||
except Exception:
|
||||
example_result_dir = "logs/tokens"
|
||||
# Retry if fails
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
while attempt < max_retries:
|
||||
try:
|
||||
response = agent.get_response(
|
||||
temperature=temperature, use_thinking=use_thinking, **kwargs
|
||||
)
|
||||
assert response is not None, "Response from agent should not be None"
|
||||
# print("Response success!")
|
||||
break # If successful, break out of the loop
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
print(f"{agent.engine} Attempt {attempt} failed: {e}")
|
||||
if attempt == max_retries:
|
||||
print("Max retries reached. Handling failure.")
|
||||
time.sleep(1.0)
|
||||
# record token cost
|
||||
if isinstance(response, tuple):
|
||||
response, usage = response
|
||||
agent_name = agent.agent_name
|
||||
os.makedirs(example_result_dir, exist_ok=True)  # the fallback dir may not exist yet
with open(os.path.join(example_result_dir, "token.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps({
|
||||
"agent_name": agent_name,
|
||||
"completion_tokens": usage.completion_tokens,
|
||||
"prompt_tokens": usage.prompt_tokens,
|
||||
"total_tokens": usage.total_tokens
|
||||
}))
|
||||
f.write("\n")
|
||||
|
||||
return response if response is not None else ""
|
||||
|
||||
|
||||
def call_func_safe(
|
||||
func, **kwargs
|
||||
) -> str:
|
||||
# Retry if fails
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
while attempt < max_retries:
|
||||
try:
|
||||
response = func(**kwargs)
|
||||
break
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
print(f"Attempt {attempt} failed: {e}")
|
||||
if attempt == max_retries:
|
||||
print("Max retries reached. Handling failure.")
|
||||
time.sleep(1.0)
|
||||
|
||||
return response if response is not None else ""
|
||||
|
||||
|
||||
def extract_coords_from_action_dict(action_dict: Dict | None) -> List:
|
||||
coords = []
|
||||
coords_num = 0
|
||||
if action_dict:
|
||||
for k, v in action_dict["args"].items():
|
||||
if (k == "x" and v) or (k == "y" and v) or (k == "x1" and v) or (k == "x2" and v) or (k == "y1" and v) or (k == "y2" and v):
|
||||
coords_num += 1
|
||||
if coords_num == 2:
|
||||
coords.append(action_dict["args"]["x"])
|
||||
coords.append(action_dict["args"]["y"])
|
||||
if coords_num == 4:
|
||||
coords.append(action_dict["args"]["x1"])
|
||||
coords.append(action_dict["args"]["y1"])
|
||||
coords.append(action_dict["args"]["x2"])
|
||||
coords.append(action_dict["args"]["y2"])
|
||||
return coords
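A quick worked example of the extraction above: two-coordinate actions yield `[x, y]`, four-coordinate actions yield `[x1, y1, x2, y2]`:

```python
drag = {"args": {"x1": 100, "y1": 200, "x2": 300, "y2": 400}}
assert extract_coords_from_action_dict(drag) == [100, 200, 300, 400]

click = {"args": {"x": 50, "y": 60}}
assert extract_coords_from_action_dict(click) == [50, 60]
```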
|
||||
|
||||
|
||||
def call_llm_formatted(generator, format_checkers, **kwargs):
|
||||
"""
|
||||
Calls the generator agent's LLM and ensures correct formatting.
|
||||
|
||||
Args:
|
||||
generator (ACI): The generator agent to call.
|
||||
format_checkers (List[Callable]): Functions that each take the response and return a tuple of (success, feedback).
|
||||
**kwargs: Additional keyword arguments for the LLM call.
|
||||
|
||||
Returns:
|
||||
response (str): The formatted response from the generator agent.
|
||||
"""
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
if kwargs.get("messages") is None:
|
||||
messages = (
|
||||
generator.messages.copy()
|
||||
) # Copy messages to avoid modifying the original
|
||||
else:
|
||||
messages = kwargs["messages"]
|
||||
del kwargs["messages"] # Remove messages from kwargs to avoid passing it twice
|
||||
while attempt < max_retries:
|
||||
response = call_llm_safe(generator, messages=messages, **kwargs)
|
||||
# Prepare feedback messages for incorrect formatting
|
||||
feedback_msgs = []
|
||||
for format_checker in format_checkers:
|
||||
success, feedback = format_checker(response)
|
||||
if not success:
|
||||
feedback_msgs.append(feedback)
|
||||
if not feedback_msgs:
|
||||
# logger.info(f"Response formatted correctly on attempt {attempt} for {generator.engine.model}")
|
||||
break
|
||||
logger.error(
|
||||
f"Response formatting error on attempt {attempt} for {generator.engine.model}. Response: {response} {', '.join(feedback_msgs)}"
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": response}],
|
||||
}
|
||||
)
|
||||
logger.info(f"Bad response: {response}")
|
||||
delimiter = "\n- "
|
||||
formatting_feedback = f"- {delimiter.join(feedback_msgs)}"
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": PROCEDURAL_MEMORY.FORMATTING_FEEDBACK_PROMPT.replace(
|
||||
"FORMATTING_FEEDBACK", formatting_feedback
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
logger.info("Feedback:\n%s", formatting_feedback)
|
||||
|
||||
attempt += 1
|
||||
if attempt == max_retries:
|
||||
logger.error(
|
||||
"Max retries reached when formatting response. Handling failure."
|
||||
)
|
||||
time.sleep(1.0)
|
||||
return response
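A usage sketch pairing this retry loop with the `SINGLE_ACTION_FORMATTER` check defined in the formatting-checks module at the end of this diff (the import path is assumed):

```python
from mm_agents.os_symphony.utils.formatting_checks import SINGLE_ACTION_FORMATTER  # assumed path

# Re-prompts up to max_retries times until the response contains exactly one agent action.
response = call_llm_formatted(generator, [SINGLE_ACTION_FORMATTER], temperature=0.0)
```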
|
||||
|
||||
|
||||
def split_thinking_response(full_response: str) -> Tuple[str, str]:
|
||||
try:
|
||||
# Extract thoughts section
|
||||
thoughts = full_response.split("<thoughts>")[-1].split("</thoughts>")[0].strip()
|
||||
|
||||
# Extract answer section
|
||||
answer = full_response.split("<answer>")[-1].split("</answer>")[0].strip()
|
||||
|
||||
return answer, thoughts
|
||||
except Exception as e:
|
||||
return full_response, ""
|
||||
|
||||
|
||||
def parse_code_from_string(input_string):
|
||||
"""Parses a string to extract each line of code enclosed in triple backticks (```)
|
||||
|
||||
Args:
|
||||
input_string (str): The input string containing code snippets.
|
||||
|
||||
Returns:
|
||||
str: The last code snippet found in the input string, or an empty string if no code is found.
|
||||
"""
|
||||
input_string = input_string.strip()
|
||||
|
||||
# This regular expression will match both ```code``` and ```python code```
|
||||
# and capture the `code` part. It uses a non-greedy match for the content inside.
|
||||
pattern = r"```(?:\w+\s+)?(.*?)```"
|
||||
# print(f'[parse_code_from_string].input_string: {input_string}')
|
||||
# Find all non-overlapping matches in the string
|
||||
matches = re.findall(pattern, input_string, re.DOTALL)
|
||||
if len(matches) == 0:
|
||||
# return []
|
||||
return ""
|
||||
relevant_code = matches[
|
||||
-1
|
||||
] # We only care about the last match given it is the grounded action
|
||||
# print(f'[parse_code_from_string].relevant_code: {relevant_code}')
|
||||
return relevant_code
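A quick example of the last-match behavior: only the final fenced block is returned, since the grounded action always comes last in the response format:

```python
text = (
    "(Grounded Action)\n"
    "```python\n"
    'agent.click("the menu button", 1, "left")\n'
    "```"
)
assert parse_code_from_string(text) == 'agent.click("the menu button", 1, "left")\n'
```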
|
||||
|
||||
|
||||
def extract_agent_functions(code):
|
||||
"""
|
||||
Extracts all agent function names from the given code.
|
||||
|
||||
Args:
|
||||
code (str): The code string to search.
|
||||
|
||||
Returns:
|
||||
list: A list of strings like ['agent.click', 'agent.type'].
|
||||
"""
|
||||
pattern = r"agent\.\w+"
|
||||
|
||||
return re.findall(pattern, code)
|
||||
|
||||
|
||||
def compress_image(image_bytes: bytes = None, image: Image.Image = None) -> bytes:
|
||||
"""Compresses an image represented as bytes.
|
||||
|
||||
Compression involves re-encoding the image in WebP format (no resizing is performed).
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): The image data to compress.
|
||||
|
||||
Returns:
|
||||
bytes: The compressed image data.
|
||||
"""
|
||||
if not image:
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
output = BytesIO()
|
||||
image.save(output, format="WEBP")
|
||||
compressed_image_bytes = output.getvalue()
|
||||
return compressed_image_bytes
|
||||
|
||||
import math
|
||||
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 100 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
|
||||
def smart_resize(
|
||||
height: int,
|
||||
width: int,
|
||||
factor: int = IMAGE_FACTOR,
|
||||
min_pixels: int = MIN_PIXELS,
|
||||
max_pixels: int = MAX_PIXELS,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
min_pixels = MIN_PIXELS if not min_pixels else min_pixels
|
||||
max_pixels = MAX_PIXELS if not max_pixels else max_pixels
|
||||
if max(height, width) / min(height, width) > MAX_RATIO:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
return h_bar, w_bar
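A worked example at the defaults: each side of a 1920x1080 screenshot is rounded to the nearest multiple of 28, and the resulting pixel count already sits inside [MIN_PIXELS, MAX_PIXELS], so neither rescaling branch fires:

```python
# 1080 -> round(1080 / 28) * 28 = 39 * 28 = 1092
# 1920 -> round(1920 / 28) * 28 = 69 * 28 = 1932
# 1092 * 1932 = 2,109,744 pixels, within [78,400, 12,845,056]
assert smart_resize(1080, 1920) == (1092, 1932)
```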


def enhance_observation(image_data: bytes, coordinates: List, expansion_pixels: int = 400, draw=True) -> Tuple[bytes, int, int, int, int]:
    """
    According to the given coordinates, draw markers on the screenshot and crop a "focused" area.

    Returns:
        Tuple[bytes, int, int, int, int]:
            - new_image_data (bytes): Data of the cropped image
            - crop_left (int): X-axis offset
            - crop_top (int): Y-axis offset
            - new_width (int): Width of the cropped image
            - new_height (int): Height of the cropped image
    """
    image = Image.open(io.BytesIO(image_data)).convert("RGBA")
    draw_ctx = ImageDraw.Draw(image)

    img_width, img_height = image.size

    X_MARKER_SIZE = 40
    X_MARKER_WIDTH = 5

    def _draw_x(draw_context, center_x, center_y, size=X_MARKER_SIZE, color="red", width=X_MARKER_WIDTH):
        half_size = size // 2
        draw_context.line((center_x - half_size, center_y - half_size, center_x + half_size, center_y + half_size), fill=color, width=width)
        draw_context.line((center_x - half_size, center_y + half_size, center_x + half_size, center_y - half_size), fill=color, width=width)

    crop_left, crop_top, crop_right, crop_bottom = 0, 0, img_width, img_height

    if len(coordinates) == 2:
        x, y = coordinates[0], coordinates[1]
        if draw:
            _draw_x(draw_ctx, x, y)

        crop_left = x - expansion_pixels
        crop_top = y - expansion_pixels
        crop_right = x + expansion_pixels
        crop_bottom = y + expansion_pixels

    elif len(coordinates) >= 4:
        x1, y1 = coordinates[0], coordinates[1]
        x2, y2 = coordinates[2], coordinates[3]

        if draw:
            _draw_x(draw_ctx, x1, y1, color="red")
            _draw_x(draw_ctx, x2, y2, color="blue")
            draw_ctx.line((x1, y1, x2, y2), fill="green", width=5)

        box_left = min(x1, x2)
        box_top = min(y1, y2)
        box_right = max(x1, x2)
        box_bottom = max(y1, y2)

        crop_left = box_left - expansion_pixels
        crop_top = box_top - expansion_pixels
        crop_right = box_right + expansion_pixels
        crop_bottom = box_bottom + expansion_pixels

    # check boundary
    crop_left = max(0, int(crop_left))
    crop_top = max(0, int(crop_top))
    crop_right = min(img_width, int(crop_right))
    crop_bottom = min(img_height, int(crop_bottom))

    crop_box = (crop_left, crop_top, crop_right, crop_bottom)
    cropped_image = image.crop(crop_box)

    new_width, new_height = cropped_image.size

    buffered = io.BytesIO()
    cropped_image.save(buffered, format="PNG")

    return buffered.getvalue(), crop_left, crop_top, new_width, new_height
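
# Editor's illustrative sketch (not part of the original diff): cropping
# around a point near the top-left corner is clamped to the image bounds,
# so the window is smaller than 2 * expansion_pixels on the clipped sides.
def _demo_enhance_observation():
    canvas = Image.new("RGBA", (800, 600))
    buf = io.BytesIO()
    canvas.save(buf, format="PNG")
    _, left, top, w, h = enhance_observation(buf.getvalue(), [100, 100], expansion_pixels=200, draw=False)
    assert (left, top) == (0, 0)  # clamped at the corner
    assert (w, h) == (300, 300)   # 100 + 200 on each axis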


@ -0,0 +1,106 @@
"""This file contains various formatting checks used to reprompt an agent for correctly formatted responses."""
|
||||
from typing import List
|
||||
import json
|
||||
import yaml
|
||||
import re
|
||||
from mm_agents.os_symphony.utils.common_utils import (
|
||||
extract_agent_functions,
|
||||
parse_code_from_string,
|
||||
split_thinking_response,
|
||||
)
|
||||
|
||||
|
||||
single_action_check = (
|
||||
lambda response: len(extract_agent_functions(parse_code_from_string(response))) == 1
|
||||
)
|
||||
single_action_error_msg = (
|
||||
"Incorrect code: There must be a single agent action in the code response."
|
||||
)
|
||||
SINGLE_ACTION_FORMATTER = lambda response: (
|
||||
single_action_check(response),
|
||||
single_action_error_msg,
|
||||
)
|
||||
|
||||
|
||||
def code_valid_check(tool_config, response):
|
||||
code = parse_code_from_string(response)
|
||||
print(f'[code_valid_check] parsed code is: {code}')
|
||||
|
||||
# check if the action is pre-defined
|
||||
with open(tool_config, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
valid_methods = set(config['tools'].keys())
|
||||
|
||||
pattern = r"^agent\.(\w+)\(.*\)$"
|
||||
|
||||
match = re.match(pattern, code.strip(), re.DOTALL)
|
||||
|
||||
if match:
|
||||
method_name = match.group(1)
|
||||
print(f'[code_valid_check]: method is {method_name}')
|
||||
if method_name in valid_methods:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
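
# Editor's illustrative sketch (not part of the original diff): the validity
# test above hinges on one regex over the parsed code; exercised standalone.
def _demo_agent_call_pattern():
    pattern = r"^agent\.(\w+)\(.*\)$"
    assert re.match(pattern, "agent.click(x=1, y=2)", re.DOTALL).group(1) == "click"
    assert re.match(pattern, "print('hi')", re.DOTALL) is None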


code_valid_error_msg = "Incorrect code: The agent action must be a SINGLE and VALID function and use valid parameters from the docstring list."
CODE_VALID_FORMATTER = lambda tool_config, response: (
    code_valid_check(tool_config, response),
    code_valid_error_msg,
)

thoughts_answer_tag_check = lambda response: split_thinking_response(response)[1] != ""
thoughts_answer_tag_error_msg = "Incorrect response: The response must contain both <thoughts>...</thoughts> and <answer>...</answer> tags."
THOUGHTS_ANSWER_TAG_FORMATTER = lambda response: (
    thoughts_answer_tag_check(response),
    thoughts_answer_tag_error_msg,
)

integer_answer_check = (
    lambda response: split_thinking_response(response)[0].strip().isdigit()
)
integer_answer_error_msg = (
    "Incorrect response: The <answer>...</answer> tag must contain a single integer."
)
INTEGER_ANSWER_FORMATTER = lambda response: (
    integer_answer_check(response),
    integer_answer_error_msg,
)


def json_answer_check(response: str, required_fields: List[str]) -> bool:
    """
    A check function that only returns True/False.
    """
    try:
        answer_str = parse_code_from_string(response)

        if len(answer_str) == 0:
            return False

        data = json.loads(answer_str)

        if not isinstance(data, dict):
            return False

        if set(required_fields) - set(data.keys()):
            return False

        return True

    except Exception:
        return False
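
# Editor's illustrative sketch (not part of the original diff): the
# required-field test above reduces to a set difference, shown standalone here.
def _demo_required_fields():
    data = {"answer": 1, "reason": "ok"}
    assert not set(["answer"]) - set(data.keys())  # all required keys present
    assert set(["missing"]) - set(data.keys())     # a missing key fails the check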


json_answer_error_msg = (
    "Incorrect response: The (Answer) part must contain a valid JSON object that includes ALL required keys and needs to be wrapped in ```json and ```"
)


JSON_ANSWER_FORMATTER = lambda response, required_fields: (
    json_answer_check(response, required_fields),
    json_answer_error_msg,
)


@ -0,0 +1,216 @@
import io
import math
import numpy as np
from typing import List, Tuple, Dict, Any, Optional

from PIL import Image, ImageDraw, ImageFont
from rapidfuzz import fuzz
import logging
from mm_agents.os_symphony.agents.memoryer_agent import StepBehavior

logger = logging.getLogger("desktopenv.loop_detection")


def _are_actions_similar(
    action1: Dict[str, Any],
    action2: Dict[str, Any],
    image_width: int,
    image_height: int,
    relative_coord_threshold: float,
    fuzzy_text_threshold: float,
) -> bool:
    """
    [Internal Auxiliary] Determine if two actions are similar based on detailed rules.

    Args:
        action1: The first action.
        action2: The second action.
        image_width: The width of the screenshot.
        image_height: The height of the screenshot.
        relative_coord_threshold: A relative distance threshold for coordinate comparison.
        fuzzy_text_threshold: A similarity threshold (0-100) for fuzzy text matching.

    Returns:
        True if the actions are similar, otherwise False.
    """
    # actions must be of the same type
    if action1.get("function") != action2.get("function"):
        return False

    func = action1.get("function")
    args1 = action1.get("args", {})
    args2 = action2.get("args", {})

    diagonal = math.sqrt(image_width**2 + image_height**2)
    abs_coord_thresh = relative_coord_threshold * diagonal

    def are_coords_close(x1, y1, x2, y2):
        if None in [x1, y1, x2, y2]:
            return False
        distance = math.sqrt((x1 - x2)**2 + (y1 - y2)**2)
        return distance < abs_coord_thresh

    if func == "click":
        return (
            are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
            args1.get("button") == args2.get("button") and
            args1.get("clicks") == args2.get("clicks")
        )

    elif func == "open":
        return args1.get("name") == args2.get("name")

    elif func == "type":
        if args1.get("x") and args1.get("y") and args2.get("x") and args2.get("y"):
            return (
                are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
                args1.get("text") == args2.get("text")
            )
        else:
            return args1.get("text") == args2.get("text")

    elif func == "drag":
        return (
            are_coords_close(args1.get("x1"), args1.get("y1"), args2.get("x1"), args2.get("y1")) and
            are_coords_close(args1.get("x2"), args1.get("y2"), args2.get("x2"), args2.get("y2"))
        )

    elif func == "set_cell_values":
        return args1.get("text") == args2.get("text")

    elif func == "scroll":
        clicks1 = args1.get("clicks", 0)
        clicks2 = args2.get("clicks", 0)
        if (clicks1 == 0 and clicks2 != 0) or (clicks1 != 0 and clicks2 == 0):
            same_direction = False
        else:
            same_direction = math.copysign(1, clicks1) == math.copysign(1, clicks2)

        return (
            are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
            same_direction and
            args1.get("shift") == args2.get("shift")
        )

    elif func == "key":
        return args1.get("keys") == args2.get("keys")

    elif func == "wait":
        return True

    elif func in ["call_code_agent", "call_search_agent"]:
        query1 = args1.get("query", "")
        query2 = args2.get("query", "")
        # use Levenshtein distance to compute fuzzy similarity
        query_similarity = fuzz.token_set_ratio(query1, query2)
        # print(f'query_sim: {query_similarity}')
        return (
            query_similarity >= fuzzy_text_threshold and
            args1.get("result") == args2.get("result")
        )

    else:
        return False
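
# Editor's illustrative sketch (not part of the original diff): two clicks
# about 9.4px apart on a 1920x1080 screen fall well under the 2% relative
# threshold (about 44px of the screen diagonal), so they compare as similar.
def _demo_actions_similar():
    a1 = {"function": "click", "args": {"x": 100, "y": 100, "button": "left", "clicks": 1}}
    a2 = {"function": "click", "args": {"x": 105, "y": 108, "button": "left", "clicks": 1}}
    assert _are_actions_similar(a1, a2, 1920, 1080, 0.02, 85.0)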


def _are_steps_similar_optimized(
    step1: StepBehavior,
    step2: StepBehavior,
    idx1: int,
    idx2: int,
    full_trajectory: List[StepBehavior],
    phash_threshold: int,
    ssim_threshold: float,
    # parameters required for the action comparison
    image_width: int,
    image_height: int,
    relative_coord_threshold: float,
    fuzzy_text_threshold: float,
) -> bool:
    """
    [Internal Auxiliary] Use pre-calculated data to quickly determine whether two steps are similar.
    """
    if step1.phash is None or step2.phash is None:
        return False

    if (step1.phash - step2.phash) > phash_threshold:
        return False

    later_step_idx = max(idx1, idx2)
    earlier_step_idx = min(idx1, idx2)

    ssim_score = full_trajectory[later_step_idx].ssim_list[earlier_step_idx]

    if ssim_score < ssim_threshold:
        return False

    if not _are_actions_similar(
        step1.action_dict, step2.action_dict,
        image_width, image_height, relative_coord_threshold, fuzzy_text_threshold
    ):
        return False

    return True


def detect_loop(
    full_trajectory: List[StepBehavior],
    image_width: int = 1920,
    image_height: int = 1080,
    N: int = 3,
    phash_threshold: int = 1,
    ssim_threshold: float = 0.99,
    relative_coord_threshold: float = 0.02,
    fuzzy_text_threshold: float = 85.0,
) -> Tuple[bool, Optional[Dict[str, List[int]]]]:
    """
    Efficiently detect the presence of looping patterns based on precomputed data.

    Args:
        full_trajectory (List[StepBehavior]): Full history including the current step.
        image_width (int): Width of the screenshot.
        image_height (int): Height of the screenshot.
        N (int): Number of steps in the candidate loop (sequence length).
        phash_threshold (int): Hamming distance threshold for pHash similarity. Recommended: 0-2.
        ssim_threshold (float): SSIM similarity threshold for image comparison. Recommended: 0.95-0.99.
        relative_coord_threshold (float): Relative threshold for coordinate similarity. Recommended: 0.01-0.05.
        fuzzy_text_threshold (float): Fuzzy text matching similarity threshold (0-100) for agent queries.

    Returns:
        A tuple (is_loop_detected, loop_info):
            - is_loop_detected (bool): Whether a loop is detected.
            - loop_info (Dict | None): If a loop is detected, contains the indices of the two matching sequences.
    """
    L = len(full_trajectory)

    if not isinstance(N, int) or N <= 0 or L < 2 * N:
        return False, None

    max_start_index = L - 2 * N
    for i in range(max_start_index, -1, -1):
        is_potential_match = True

        for j in range(N):
            idx_prev = i + j
            idx_curr = (L - N) + j

            step_prev = full_trajectory[idx_prev]
            step_curr = full_trajectory[idx_curr]

            if not _are_steps_similar_optimized(
                step_prev, step_curr, idx_prev, idx_curr, full_trajectory,
                phash_threshold, ssim_threshold,
                image_width, image_height, relative_coord_threshold, fuzzy_text_threshold
            ):
                is_potential_match = False
                break

        if is_potential_match:
            previous_sequence_indices = list(range(i, i + N))
            loop_info = {
                "match_sequence_indices": previous_sequence_indices
            }
            return True, loop_info

    return False, None
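
# Editor's illustrative sketch (not part of the original diff): detect_loop
# compares the latest N steps against every earlier N-step window, scanning
# backwards; this reproduces the index arithmetic for L=8, N=3.
def _demo_detect_loop_windows(L: int = 8, N: int = 3):
    current = list(range(L - N, L))  # indices of the latest N steps
    candidates = [list(range(i, i + N)) for i in range(L - 2 * N, -1, -1)]
    assert current == [5, 6, 7]
    assert candidates == [[2, 3, 4], [1, 2, 3], [0, 1, 2]]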


@ -0,0 +1,30 @@
# process_context.py
# This module provides an independent context storage for each process.

from multiprocessing import current_process

# We will store process-specific contexts here.
# Since each process has its own separate memory space, when accessing this variable,
# each process accesses its own copy, without conflicting with others.
_context_storage = {}

def set_context(key, value):
    """Set a value in the context of the current process."""
    _context_storage[key] = value
    # print(f"[{current_process().name}] Set context: {key} = {value}")  # For debugging

def get_context(key, default=None):
    """Retrieve a value from the context of the current process."""
    value = _context_storage.get(key, default)
    # print(f"[{current_process().name}] Get context: {key} -> {value}")  # For debugging
    if value is None and default is None:
        raise NameError(f"'{key}' not found in the current process context. Ensure it is set at the process entry point.")
    return value

# For convenience, we can create a specialized getter for result_dir
def get_current_result_dir():
    """Get the result_dir specific to the current process."""
    return get_context('current_result_dir')

def set_current_result_dir(example_result_dir):
    set_context("current_result_dir", example_result_dir)

File diff suppressed because one or more lines are too long

@ -69,4 +69,6 @@ alibabacloud_ecs20140526
alibabacloud_tea_openapi
alibabacloud_tea_util
json_minify
json_repair
volcengine-python-sdk[ark]
ui-tars>=0.4.2.2

@ -0,0 +1,603 @@
"""
|
||||
Script to run EvoCUA native agent model on OSWorld tasks.
|
||||
|
||||
export AWS_ACCESS_KEY_ID="xx"
|
||||
export AWS_SECRET_ACCESS_KEY="xx"
|
||||
export AWS_REGION="xx"
|
||||
export AWS_SECURITY_GROUP_ID="xx"
|
||||
export AWS_SUBNET_ID="xx"
|
||||
export OPENAI_API_KEY="xxxx"
|
||||
export OPENAI_BASE_URL="xxxx"
|
||||
|
||||
Example Usage (S2):
|
||||
python3 run_multienv_evocua.py \
|
||||
--headless \
|
||||
--provider_name aws \
|
||||
--observation_type screenshot \
|
||||
--model EvoCUA-S2 \
|
||||
--result_dir ./evocua_s2 \
|
||||
--test_all_meta_path evaluation_examples/test_nogdrive.json \
|
||||
--max_steps 50 \
|
||||
--num_envs 30 \
|
||||
--temperature 0.01 \
|
||||
--max_history_turns 4 \
|
||||
--coordinate_type relative \
|
||||
--resize_factor 32 \
|
||||
--prompt_style S2
|
||||
|
||||
|
||||
Example Usage (S1):
|
||||
python3 run_multienv_evocua.py \
|
||||
--headless \
|
||||
--provider_name aws \
|
||||
--observation_type screenshot \
|
||||
--model EvoCUA-S1 \
|
||||
--result_dir ./evocua_s1 \
|
||||
--test_all_meta_path evaluation_examples/test_nogdrive.json \
|
||||
--max_steps 50 \
|
||||
--num_envs 30 \
|
||||
--max_history_turns 3 \
|
||||
--coordinate_type qwen25 \
|
||||
--max_tokens 10240 \
|
||||
--resize_factor 28 \
|
||||
--prompt_style S1
|
||||
"""

from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List
from multiprocessing import Process, Manager, Queue
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.evocua.evocua_agent import EvoCUAAgent

# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False


# Thread-local storage for task context (works per-process in multiprocessing)
import threading
_task_context = threading.local()

def get_task_context():
    """Get current task context from thread-local storage."""
    return getattr(_task_context, 'context', {'domain': None, 'example_id': None})

def set_task_context(domain: str, example_id: str):
    """Set current task context in thread-local storage."""
    _task_context.context = {'domain': domain, 'example_id': example_id}

def clear_task_context():
    """Clear current task context."""
    if hasattr(_task_context, 'context'):
        delattr(_task_context, 'context')

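# Editor's illustrative sketch (not part of the original diff): the context
# round-trips within one process and falls back to None fields after clearing.
def _demo_task_context():
    set_task_context("chrome", "abc-123")
    assert get_task_context() == {"domain": "chrome", "example_id": "abc-123"}
    clear_task_context()
    assert get_task_context()["domain"] is None
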
class TaskContextFilter(logging.Filter):
    """Filter to add domain and example_id to log records."""
    def filter(self, record):
        ctx = get_task_context()
        domain = ctx.get('domain')
        example_id = ctx.get('example_id')
        if domain and example_id:
            record.domain = domain
            record.example_id = example_id
            # Add prefix to message
            if hasattr(record, 'msg') and isinstance(record.msg, str):
                if not record.msg.startswith(f"[{domain}/{example_id}]"):
                    record.msg = f"[{domain}/{example_id}] {record.msg}"
        else:
            record.domain = domain or "N/A"
            record.example_id = example_id or "N/A"
        return True


# load the environment variables from the .env file
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()

def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation with EvoCUAAgent"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run on a headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=5.0)
    parser.add_argument("--max_steps", type=int, default=50)

    # evaluation config
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--model", type=str, default="evocua", help="Model name.")
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=32768)
    parser.add_argument("--stop_token", type=str, default=None)

    parser.add_argument("--prompt_style", type=str, default="S2", choices=["S1", "S2"], help="Prompt style: 'S1' (structured reasoning) or 'S2' (tool calling)")
    parser.add_argument("--history_type", type=str, default="action_history", help="[S1] History type")
    parser.add_argument("--coordinate_type", type=str, default="relative", help="Coordinate type: relative, absolute, qwen25")
    parser.add_argument("--password", type=str, default="osworld-public-evaluation", help="VM password")

    # unified history parameter
    parser.add_argument("--max_history_turns", type=int, default=3, help="Number of history turns to include")
    parser.add_argument("--resize_factor", type=int, default=32, help="Image resize factor (S1: 28, S2: 32)")

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help="Set the logging level")
    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    parser.add_argument(
        "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
    )
    parser.add_argument(
        "--client_password", type=str, default="", help="Client password"
    )
    parser.add_argument(
        "--screen_width", type=int, default=1920, help="Screen width"
    )
    parser.add_argument(
        "--screen_height", type=int, default=1080, help="Screen height"
    )

    args = parser.parse_args()
    return args

args = config()

logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

os.makedirs("logs", exist_ok=True)  # editor's fix: ensure the log directory exists before attaching file handlers
file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)

# Add task context filter to all handlers
task_filter = TaskContextFilter()
file_handler.addFilter(task_filter)
debug_handler.addFilter(task_filter)
stdout_handler.addFilter(task_filter)

stdout_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)

logger = logging.getLogger("desktopenv.experiment")


def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))
    return all_tasks


def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
    active_environments = []
    env = None
    try:
        REGION = args.region
        screen_size = (args.screen_width, args.screen_height)

        # Determine snapshot based on provider
        snapshot_name = "init_state"
        if args.provider_name == "aws":
            from desktop_env.providers.aws.manager import IMAGE_ID_MAP
            ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION].get((1920, 1080)))
            snapshot_name = ami_id

        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=args.action_space,
            provider_name=args.provider_name,
            region=REGION,
            snapshot_name=snapshot_name,
            screen_size=screen_size,
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            client_password=args.client_password
        )

        active_environments.append(env)

        logger.info(f"Process {current_process().name} started.")
        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                break
            domain, example_id = item
            set_task_context(domain, example_id)
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)

                # Initialize EvoCUAAgent
                agent = EvoCUAAgent(
                    model=args.model,
                    max_tokens=args.max_tokens,
                    top_p=args.top_p,
                    temperature=args.temperature,
                    action_space=args.action_space,
                    observation_type=args.observation_type,
                    max_steps=args.max_steps,
                    prompt_style=args.prompt_style,
                    max_history_turns=args.max_history_turns,
                    screen_size=(args.screen_width, args.screen_height),
                    coordinate_type=args.coordinate_type,
                    password=args.password,
                    resize_factor=args.resize_factor,
                )

                try:
                    lib_run_single.run_single_example_evocua(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        example["instruction"],
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback
                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())

                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")

                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
                        f.write("\n")
            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback
                logger.error(traceback.format_exc())
            finally:
                clear_task_context()
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
        except Exception as e:
            logger.error(f"{current_process().name} error during environment cleanup: {e}")


def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shut down environments."""
    global is_terminating, active_environments, processes

    if is_terminating:
        return

    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")

    time.sleep(1)

    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                import signal as sig
                os.kill(p.pid, sig.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")
    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"EnvProcess-{i+1}"
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
        try:
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"EnvProcess-Restart-{idx+1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        scores = list(shared_scores)
        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        try:
                            # read the score; malformed files count as 0.0
                            with open(os.path.join(example_path, "result.txt"), "r") as f:
                                all_result.append(float(f.read()))
                        except Exception:
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Register signal handlers
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
args = config()
|
||||
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt.")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
|
||||
signal_handler(signal.SIGTERM, None)
|
||||
finally:
|
||||
logger.info("Main process final cleanup...")
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info("Closing environment in final cleanup...")
|
||||
env.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during final environment cleanup: {e}")
|
||||
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error terminating process: {e}")
|
||||
|
||||
time.sleep(1)
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
os.kill(p.pid, signal.SIGKILL)
|
||||
except Exception as e:
|
||||
logger.error(f"Error force killing process: {e}")
|
||||
|
|

@ -0,0 +1,907 @@
"""
|
||||
OS-Symphony Official Evaluation Script
|
||||
|
||||
This script serves as the official evaluation entry point for OS-Symphony.
|
||||
It handles the setup of the desktop environment, agent initialization, and
|
||||
execution of evaluation tasks.
|
||||
|
||||
For detailed evaluation metrics, configuration options, and usage instructions,
|
||||
please refer to the official repository:
|
||||
https://github.com/OS-Copilot/OS-Symphony
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import signal
|
||||
import time
|
||||
from multiprocessing import Process, Manager, current_process, Queue
|
||||
|
||||
from mm_agents.os_symphony.agents.os_symphony import OSSymphony
|
||||
from mm_agents.os_symphony.agents.os_aci import OSACI
|
||||
import lib_run_single
|
||||
# Modify desktop_env, add a new function 'start'
|
||||
from desktop_env.desktop_env_os_symphony import DesktopEnv as OSWorldDesktopEnv
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# only for WAA
|
||||
def prepare_worker_vm_paths(base_golden_path: str, worker_idx: int):
|
||||
# remove the '/' at the end
|
||||
base_golden_path = base_golden_path.rstrip(os.sep)
|
||||
|
||||
# get parent directory (like /nvme/yangbowen/vm_stroage/waa)
|
||||
parent_dir = os.path.dirname(base_golden_path)
|
||||
|
||||
# define the path of this worker
|
||||
worker_storage_path = os.path.join(parent_dir, f"storage_{worker_idx}")
|
||||
worker_backup_path = os.path.join(parent_dir, f"storage_{worker_idx}_backup")
|
||||
|
||||
return worker_storage_path, worker_backup_path
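
# Editor's illustrative sketch (not part of the original diff; assumes a POSIX
# path separator): worker 2 of a golden image under /data/waa gets sibling
# storage_2 and storage_2_backup directories.
def _demo_worker_paths():
    storage, backup = prepare_worker_vm_paths("/data/waa/golden/", 2)
    assert storage == "/data/waa/storage_2"
    assert backup == "/data/waa/storage_2_backup"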


# only for WAA
def initialize_worker_files(golden_path: str, worker_backup_path: str, worker_storage_path: str):
    """
    Initialize a worker. If its backup doesn't exist, replicate it from the golden path.
    """
    if not os.path.exists(golden_path):
        raise FileNotFoundError(f"Golden VM path not found: {golden_path}")

    if not os.path.exists(worker_backup_path):
        logger.info(f"Initializing backup for worker from {golden_path} to {worker_backup_path} ...")
        try:
            os.makedirs(os.path.dirname(worker_backup_path), exist_ok=True)

            if os.path.isdir(golden_path):
                subprocess.check_call(['cp', '-r', '--sparse=always', golden_path, worker_backup_path])
            else:
                subprocess.check_call(['cp', '--sparse=always', golden_path, worker_backup_path])

            logger.info(f"Backup initialization complete for {worker_backup_path}")
        except subprocess.CalledProcessError as e:
            logger.error(f"Failed to copy golden image to backup using cp: {e}")
            raise e
    else:
        logger.info(f"Worker backup already exists at {worker_backup_path}, skipping copy.")

    if not os.path.exists(worker_storage_path):
        os.makedirs(worker_storage_path, exist_ok=True)


logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)

# Set up stdout handler
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.INFO)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(stdout_handler)

# Set up file handler
# file_handler = logging.FileHandler(filename="log.txt")
# file_handler.setLevel(logging.ERROR)
# file_handler.setFormatter(formatter)
# file_handler.addFilter(logging.Filter("desktopenv"))
# logger.addHandler(file_handler)

# Logger configs
logger = logging.getLogger("desktopenv.experiment")


# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False


def distribute_tasks(test_all_meta: dict) -> list:
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))
    return all_tasks


def process_signal_handler(signum, frame, env_idx):
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
    local_vars = frame.f_locals
    active_environments = local_vars.get("active_environments", [])
    for env in active_environments:
        if env is not None:
            try:
                logger.info(f"Process {env_idx + 1} closing environment...")
                env.close()
                logger.info(f"Process {env_idx + 1} environment closed successfully")
            except Exception as e:
                logger.error(f"Process {env_idx + 1} error closing environment: {e}")
    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)


def run_env_tasks(
    task_queue: Queue,
    args: argparse.Namespace,
    shared_scores: list,
    engine_params_for_orchestrator,
    engine_params_for_grounder,
    engine_params_for_coder,
    engine_params_for_memoryer,
    engine_params_for_searcher,
    worker_id: int,
):
    active_environments = []
    env = None
    search_env = None
    try:
        # Use IMAGE_ID_MAP for the AWS provider to get snapshot_name
        snapshot_name = None
        region = getattr(args, "region", "us-east-1")
        platform = 'linux'
        screen_size = (args.screen_width, args.screen_height)

        if "osworld" in args.benchmark:
            if args.provider_name == "aws":
                from desktop_env.providers.aws.manager import IMAGE_ID_MAP
                ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
                env = OSWorldDesktopEnv(
                    path_to_vm=args.path_to_vm,
                    action_space=args.action_space,
                    provider_name=args.provider_name,
                    region=region,
                    snapshot_name=ami_id,
                    screen_size=screen_size,
                    headless=args.headless,
                    os_type="Ubuntu",
                    require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
                )
            elif args.provider_name == "docker":
                env = OSWorldDesktopEnv(
                    path_to_vm=args.path_to_vm,
                    action_space=args.action_space,
                    provider_name=args.provider_name,
                    region=region,
                    snapshot_name=snapshot_name,
                    screen_size=screen_size,
                    headless=args.headless,
                    os_type="Ubuntu",
                    require_a11y_tree=args.observation_type
                    in ["a11y_tree", "screenshot_a11y_tree", "som"],
                    enable_proxy=True,
                    client_password=getattr(args, "client_password", "")
                )
            else:
                raise Exception("Other providers are not supported!")

        env.start()

        if args.provider_name == "aws":
            from desktop_env.providers.aws.manager import IMAGE_ID_MAP
            ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
            search_env = OSWorldDesktopEnv(
                path_to_vm=args.path_to_vm,
                action_space=args.action_space,
                provider_name=args.provider_name,
                region=region,
                snapshot_name=ami_id,
                screen_size=screen_size,
                headless=args.headless,
                os_type="Ubuntu",
                require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
            )
        elif args.provider_name == "docker":
            search_env = OSWorldDesktopEnv(
                path_to_vm=args.path_to_vm,
                action_space=args.action_space,
                provider_name=args.provider_name,
                region=region,
                snapshot_name=snapshot_name,
                screen_size=screen_size,
                headless=args.headless,
                os_type="Ubuntu",
                require_a11y_tree=args.observation_type
                in ["a11y_tree", "screenshot_a11y_tree", "som"],
                enable_proxy=True,
                client_password=getattr(args, "client_password", "")
            )
        else:
            raise Exception("Other providers are not supported!")

        engine_params_for_ocr = copy.deepcopy(engine_params_for_orchestrator)
        engine_params_for_ocr["agent_name"] = "ocr"
        os_aci = OSACI(
            env=env,
            search_env=search_env,
            platform=platform,
            client_password=args.client_password,
            engine_params_for_ocr=engine_params_for_ocr,
            engine_params_for_grounder=engine_params_for_grounder,
            engine_params_for_coder=engine_params_for_coder,
            engine_params_for_searcher=engine_params_for_searcher,
            screen_width=args.screen_width,
            screen_height=args.screen_height,
        )
        agent = OSSymphony(
            engine_params_for_orchestrator,
            engine_params_for_memoryer,
            os_aci,
            platform=platform,
            client_password=args.client_password,
            max_trajectory_length=args.max_trajectory_length,
            enable_reflection=args.enable_reflection,
        )

        active_environments.append(env)
        active_environments.append(search_env)
        logger.info(f"Process {current_process().name} started.")
        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                break
            domain, example_id = item
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)

                if args.enable_rewrite_instruction and "rewritten_instruction" in example:
                    instruction = example["rewritten_instruction"]
                else:
                    instruction = example["instruction"]

                example_result_dir = os.path.join(
                    args.result_dir,
                    domain,
                    example_id
                )
                os.makedirs(example_result_dir, exist_ok=True)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {instruction}")
                try:
                    lib_run_single.run_single_example_os_symphony(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        instruction,
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback

                    logger.error(
                        f"Exception in {current_process().name} {domain}/{example_id}: {e}"
                    )
                    logger.error(traceback.format_exc())

                    with open(os.path.join(os.path.dirname(example_result_dir), "error.jsonl"), "a") as f:
                        f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
                        f.write("\n")

            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback

                logger.error(traceback.format_exc())
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback

        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
            if search_env:
                search_env.close()
                logger.info(f"{current_process().name} searcher environment closed successfully")
        except Exception as e:
            logger.error(
                f"{current_process().name} error during environment cleanup: {e}"
            )


# exit function
def signal_handler(signum, frame):
    global is_terminating, active_environments, processes
    if is_terminating:
        return
    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")
    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")
    time.sleep(1)
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                import signal as sig

                os.kill(p.pid, sig.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")
    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)


def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation on the benchmark"
|
||||
)
|
||||
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--provider_name",
|
||||
type=str,
|
||||
default="vmware",
|
||||
help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action_space", type=str, default="pyautogui", help="Action type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--observation_type",
|
||||
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
default="screenshot",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_envs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of environments to run in parallel",
|
||||
)
|
||||
parser.add_argument("--screen_width", type=int, default=1920, help="Main environment's width")
|
||||
parser.add_argument("--screen_height", type=int, default=1080, help="Main environment's height")
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=1.0)
|
||||
parser.add_argument("--max_steps", type=int, default=15)
|
||||
|
||||
# Benchmark
|
||||
parser.add_argument("--benchmark", type=str, default="osworld", help="osworld / waa / macos")
|
||||
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
parser.add_argument(
|
||||
"--test_all_meta_path", type=str, default="evaluation_examples/osworld/test_all.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_config_base_dir", type=str, default="evaluation_examples"
|
||||
)
|
||||
parser.add_argument("--result_dir", type=str, default="./results")
|
||||
|
||||
parser.add_argument(
|
||||
"--region", type=str, default="us-east-1", help="AWS region for the VM for OSWorld."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--client_password", type=str, default="password", help="Client password for OSWorld. Aws is 'osworld-public-evaluation', other is 'password'"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--proxy", type=str, default="http://10.1.8.5:23128", help="Important! Proxy setting, format should be http://<ip>:<port>, if no-use, set it empty"
|
||||
)
|
||||
|
||||
# agent config
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=8)
|
||||
parser.add_argument("--enable_reflection", action="store_true", default=False)
|
||||
parser.add_argument("--enable_rewrite_instruction", action="store_true", default=False)
|
||||
parser.add_argument(
|
||||
"--tool_config",
|
||||
type=str,
|
||||
help="The path of tool config yaml"
|
||||
)
|
||||
|
||||
# generator-agent config
|
||||
parser.add_argument("--orchestrator_provider", type=str, default="openai")
|
||||
parser.add_argument("--orchestrator_model", type=str, default="gpt-5")
|
||||
parser.add_argument(
|
||||
"--orchestrator_url",
|
||||
type=str,
|
||||
default="",
|
||||
help="The URL of the main orchestrator model API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--orchestrator_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the main orchestrator model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--orchestrator_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix the orchestrator model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
parser.add_argument("--orchestrator_keep_first_image", action="store_true", default=False, help="Whether keep the first image(first state) in the orchestrator agent")
|
||||
|
||||
# code-agent config
|
||||
parser.add_argument("--coder_provider", type=str, default="openai")
|
||||
parser.add_argument("--coder_model", type=str, default="gpt-4o")
|
||||
parser.add_argument(
|
||||
"--coder_url",
|
||||
type=str,
|
||||
default="",
|
||||
help="The URL of the coder model API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--coder_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the coder model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--coder_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix the coder model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--coder_budget",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Max inner loop steps of coder agent",
|
||||
)
|
||||
|
||||
# reflection-memory agent config
|
||||
parser.add_argument("--memoryer_provider", type=str, default="openai")
|
||||
parser.add_argument("--memoryer_model", type=str, default="gpt-4o")
|
||||
parser.add_argument(
|
||||
"--memoryer_url",
|
||||
type=str,
|
||||
default="",
|
||||
help="The URL of the memoryer model API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--memoryer_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the memoryer model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--memoryer_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix the memoryer model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--memoryer_max_images",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Max images of memoryer model"
|
||||
)
|
||||
|
||||
# search model config
|
||||
parser.add_argument("--searcher_provider", type=str, default="openai")
|
||||
parser.add_argument("--searcher_model", type=str, default="gpt-4o")
|
||||
parser.add_argument(
|
||||
"--searcher_url",
|
||||
type=str,
|
||||
default="",
|
||||
help="The URL of the searcher model API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--searcher_api_key",
|
||||
type=str,
|
||||
default="",
|
||||
help="The API key of the searcher model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--searcher_temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature to fix searcher model at (e.g. o3 can only be run with 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--searcher_type",
|
||||
type=str,
|
||||
default="vlm",
|
||||
help="Type of search agent, vlm/llm(all in search action), default is vlm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--searcher_engine",
|
||||
type=str,
|
||||
default="google",
|
||||
help="Type of search engine, google / duckduckgo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--searcher_budget",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Max inner loop steps of search agent",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--searcher_screen_width",
|
||||
type=int,
|
||||
default=1920,
|
||||
help="Search enviroment's width",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--searcher_screen_height",
|
||||
type=int,
|
||||
default=1080,
|
||||
help="Search enviroment's height",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--searcher_path_to_vm",
|
||||
type=str,
|
||||
default="/nvme/yangbowen/vm_stroage/osworld/Ubuntu.qcow2",
|
||||
help="Searcher Env VM's path (OSWorld'VM Path)",
|
||||
)
|
||||
|
||||
    # grounding model config (temperature is hardcoded to 0)
    parser.add_argument(
        "--grounder_provider",
        type=str,
        required=True,
        help="The provider for the grounder model",
    )
    parser.add_argument(
        "--grounder_url", type=str, required=True, help="The URL of the grounder model"
    )
    parser.add_argument(
        "--grounder_api_key",
        type=str,
        default="",
        help="The API key of the grounder model.",
    )
    parser.add_argument(
        "--grounder_model",
        type=str,
        required=True,
        help="The model name for the grounder model",
    )
    parser.add_argument(
        "--grounding_width",
        type=int,
        required=True,
        help="Width of the screenshot image after processor rescaling",
    )
    parser.add_argument(
        "--grounding_height",
        type=int,
        required=True,
        help="Height of the screenshot image after processor rescaling",
    )
    parser.add_argument(
        "--grounding_smart_resize",
        action="store_true",
        default=False,
        help="UI-TARS-1.5 and ScaleCUA need smart resize; if set, grounding_width and grounding_height are ignored.",
    )
    parser.add_argument(
        "--grounder_zoom_in_time",
        type=int,
        default=1,
        help="Number of zoom-in rounds for the grounder agent, aiming to enhance grounding ability.",
    )

    parser.add_argument(
        "--exp_name",
        type=str,
        default="",
        help="Experiment name",
    )

    args = parser.parse_args()

    return args


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")

    engine_params_for_orchestrator = {
        "engine_type": args.orchestrator_provider,
        "model": args.orchestrator_model,
        "base_url": getattr(args, "orchestrator_url", ""),
        "api_key": getattr(args, "orchestrator_api_key", ""),
        "temperature": getattr(args, "orchestrator_temperature", None),
        "tool_config": args.tool_config,
        "keep_first_image": args.orchestrator_keep_first_image,
        "agent_name": "orchestrator",
    }

    engine_params_for_grounder = {
        "engine_type": args.grounder_provider,
        "model": args.grounder_model,
        "base_url": getattr(args, "grounder_url", ""),
        "api_key": getattr(args, "grounder_api_key", ""),
        "grounding_width": args.grounding_width,
        "grounding_height": args.grounding_height,
        "grounding_smart_resize": args.grounding_smart_resize,
        "grounder_zoom_in_time": args.grounder_zoom_in_time,
        "agent_name": "grounder",
    }

    engine_params_for_coder = {
        "engine_type": args.coder_provider,
        "model": args.coder_model,
        "base_url": getattr(args, "coder_url", ""),
        "api_key": getattr(args, "coder_api_key", ""),
        "temperature": getattr(args, "coder_temperature", None),
        "budget": args.coder_budget,
        "agent_name": "coder",
    }

    engine_params_for_memoryer = {
        "engine_type": args.memoryer_provider,
        "model": args.memoryer_model,
        "base_url": getattr(args, "memoryer_url", ""),
        "api_key": getattr(args, "memoryer_api_key", ""),
        "temperature": getattr(args, "memoryer_temperature", None),
        "max_images": args.memoryer_max_images,
        "agent_name": "memoryer",
    }

    engine_params_for_searcher = {
        "engine_type": args.searcher_provider,
        "model": args.searcher_model,
        "base_url": getattr(args, "searcher_url", ""),
        "api_key": getattr(args, "searcher_api_key", ""),
        "temperature": getattr(args, "searcher_temperature", None),
        "budget": args.searcher_budget,
        "type": args.searcher_type,
        "engine": args.searcher_engine,
        "agent_name": "searcher",
    }

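    # All five sub-agents share the same engine-param schema (engine_type /
    # model / base_url / api_key plus per-agent extras). A hypothetical
    # consumer would unpack one of these dicts into its model client, e.g.
    #   client = build_engine(**engine_params_for_coder)  # build_engine is illustrative, not a helper defined here
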
    # --- Initialize Worker Paths ---
    num_envs = args.num_envs
    # Only needed for WindowsAgentArena (waa): clone per-worker VM storage.
    if args.benchmark == "waa":
        logger.info(f"[WindowsAgentArena] Initializing storage for {num_envs} workers from golden image: {args.path_to_vm}")
        for i in range(num_envs):
            s_path, b_path = prepare_worker_vm_paths(args.path_to_vm, i)
            initialize_worker_files(args.path_to_vm, b_path, s_path)

    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        processes = []
        for worker_id in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(
                    task_queue,
                    args,
                    shared_scores,
                    engine_params_for_orchestrator,
                    engine_params_for_grounder,
                    engine_params_for_coder,
                    engine_params_for_memoryer,
                    engine_params_for_searcher,
                    worker_id,
                ),
                name=f"EnvProcess-{worker_id+1}",
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
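        # Watchdog rationale: workers are daemon processes, so they cannot
        # outlive the main process; the loop below restarts any worker that
        # dies early and treats an empty task queue as the completion signal.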
        try:
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(
                                task_queue,
                                args,
                                shared_scores,
                                engine_params_for_orchestrator,
                                engine_params_for_grounder,
                                engine_params_for_coder,
                                engine_params_for_memoryer,
                                engine_params_for_searcher,
                                idx,
                            ),
                            name=f"EnvProcess-Restart-{idx+1}",
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(
                            f"Restarted process {new_p.name} with PID {new_p.pid}"
                        )
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info(
                "Main process received KeyboardInterrupt. Initiating graceful shutdown..."
            )
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error while waiting for processes: {e}", exc_info=True
            )
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        scores = list(shared_scores)
        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")

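
# Resume semantics: an example directory counts as finished only if it
# contains result.txt; otherwise its partial outputs are wiped so the
# task is scheduled again on the next run.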
def get_unfinished(target_dir, total_file_json):
    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # remove all partial files under example_id
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(target_dir, total_file_json: dict):
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    # list of results for all tasks
    all_result = []

    for domain, example_id_list in total_file_json.items():
        for example_id in example_id_list:
            example_path = os.path.join(target_dir, domain, example_id)
            if os.path.isdir(example_path):
                if "result.txt" in os.listdir(example_path):
                    # read the recorded score; fall back to 0.0 on parse errors
                    try:
                        with open(
                            os.path.join(example_path, "result.txt"), "r"
                        ) as f:
                            all_result.append(float(f.read()))
                    except Exception:
                        all_result.append(0.0)
                else:
                    all_result.append(0.0)
            else:
                all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


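# Result-dir layout: results/<exp_name>/... when --exp_name is given,
# otherwise the default results/<action_space>/<observation_type>/<model>/...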
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
args = config()
|
||||
|
||||
if args.exp_name != "":
|
||||
args.result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.exp_name
|
||||
)
|
||||
else:
|
||||
args.result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model
|
||||
)
|
||||
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
"args.json"
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
|
||||
logger.info(f"====================\nExperiment on {args.benchmark} is started\n====================")
|
||||
test_file_list = get_unfinished(
|
||||
target_dir=args.result_dir,
|
||||
total_file_json=test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
target_dir=args.result_dir,
|
||||
total_file_json=test_all_meta
|
||||
)
|
||||
test(
|
||||
args,
|
||||
test_file_list
|
||||
)
|
||||
logger.info(f"====================\nExperiment on {args.benchmark} is ended\n====================")
|
||||
|
||||
logger.info(f"====================\nExperiment {args.exp_name} is totally ended!\n====================")
|
||||
|
|
@ -0,0 +1,540 @@
from __future__ import annotations

import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List, Dict
from multiprocessing import Process, Manager, Queue  # Queue is referenced in type annotations
from multiprocessing import current_process

import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.seed_agent import SeedAgent


# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False

# Load the environment variables from a .env file, if present
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()

# Logger Configs {{{ #
def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run on a headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=5.0)
    parser.add_argument("--max_steps", type=int, default=100)

    # evaluation config
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--model", type=str, default="doubao-1-5-thinking-vision-pro-250428")
    parser.add_argument("--model_type", type=str, default="doubao", choices=["doubao", "qwen25"])
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.7)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--use_thinking", action="store_true", default=False)

    parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
    parser.add_argument("--history_n", type=int, default=5, help="The max number of images in the history.")
    parser.add_argument("--language", type=str, default="Chinese", help="Language for the agent.")

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help="Set the logging level")
    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    parser.add_argument(
        "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
    )
    parser.add_argument(
        "--client_password", type=str, default="", help="Client password"
    )
    parser.add_argument(
        "--screen_width", type=int, default=1920, help="Screen width"
    )
    parser.add_argument(
        "--screen_height", type=int, default=1080, help="Screen height"
    )
    args = parser.parse_args()
    return args

args = config()  # Get command line arguments first

logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

os.makedirs("logs", exist_ok=True)  # ensure the log directory exists before attaching file handlers
file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")

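# Logging design: the stdout handler is filtered to the "desktopenv" logger
# namespace so console output stays readable, while both file handlers hang
# off the root logger and capture everything at INFO and DEBUG respectively.
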
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    # Flatten {domain: [example_ids]} into a list of (domain, example_id) pairs,
    # e.g. {"chrome": ["a", "b"]} -> [("chrome", "a"), ("chrome", "b")]
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))
    return all_tasks


def process_signal_handler(signum, frame, env_idx):
    """Signal handler for child processes to gracefully shut down their environments."""
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")

    # Get the active_environments from the caller's frame
    local_vars = frame.f_locals
    active_environments = local_vars.get('active_environments', [])

    # Close environments in the current process context
    for env in active_environments:
        if env is not None:
            try:
                logger.info(f"Process {env_idx + 1} closing environment...")
                env.close()
                logger.info(f"Process {env_idx + 1} environment closed successfully")
            except Exception as e:
                logger.error(f"Process {env_idx + 1} error closing environment: {e}")

    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)


def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
    active_environments = []
    env = None
    try:
        from desktop_env.providers.aws.manager import IMAGE_ID_MAP
        REGION = args.region
        screen_size = (args.screen_width, args.screen_height)
        ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=args.action_space,
            provider_name=args.provider_name,
            region=REGION,
            snapshot_name=ami_id,
            screen_size=screen_size,
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            client_password=args.client_password
        )
        active_environments.append(env)
        agent = SeedAgent(
            model=args.model,
            model_type=args.model_type,
            max_tokens=args.max_tokens,
            top_p=args.top_p,
            temperature=args.temperature,
            max_trajectory_length=args.max_trajectory_length,
            history_n=args.history_n,
            use_thinking=args.use_thinking,
        )
        logger.info(f"Process {current_process().name} started.")
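        # Worker loop: the 5-second get() timeout doubles as the exit
        # condition; once the shared queue stays empty, the worker falls
        # through to the cleanup block in `finally`.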
        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                break
            domain, example_id = item
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)
                try:
                    lib_run_single.run_single_example(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        example["instruction"],
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback
                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())
                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")
                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(
                            json.dumps(
                                {"Error": f"{domain}/{example_id} - {e}"}
                            )
                        )
                        f.write("\n")
            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback
                logger.error(traceback.format_exc())
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
        except Exception as e:
            logger.error(f"{current_process().name} error during environment cleanup: {e}")

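
# Two-phase shutdown: close environments first, terminate() the children,
# then SIGKILL anything still alive after a short grace period.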
def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shut down environments."""
    global is_terminating, active_environments, processes

    # Avoid duplicate handling
    if is_terminating:
        return

    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    # Close all registered environments in the main process
    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    # Send termination signal to all child processes first
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")

    # Allow a short time for processes to handle their own cleanup
    time.sleep(1)

    # Forcefully terminate any processes that didn't exit
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                import signal as sig
                os.kill(p.pid, sig.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")
    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"EnvProcess-{i+1}"
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
        try:
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"EnvProcess-Restart-{idx+1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        scores = list(shared_scores)
        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # remove all partial files under example_id
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        # read the recorded score; fall back to 0.0 on parse errors
                        try:
                            with open(
                                os.path.join(example_path, "result.txt"), "r"
                            ) as f:
                                all_result.append(float(f.read()))
                        except Exception:
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Register signal handlers for graceful termination
|
||||
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
|
||||
|
||||
try:
|
||||
args = config()
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt.")
|
||||
# Signal handler will take care of cleanup
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
|
||||
# Also trigger cleanup for unhandled exceptions
|
||||
signal_handler(signal.SIGTERM, None)
|
||||
finally:
|
||||
# Final cleanup in case any environments or processes remain
|
||||
logger.info("Main process final cleanup...")
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info(f"Closing environment in final cleanup...")
|
||||
env.close()
|
||||
logger.info(f"Environment closed successfully in final cleanup")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during final environment cleanup: {e}")
|
||||
|
||||
# First try gentle termination
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
logger.info(f"Terminating process {p.name}...")
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error terminating process: {e}")
|
||||
|
||||
# Wait a moment for processes to terminate
|
||||
time.sleep(1)
|
||||
|
||||
# Then force kill if needed
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
logger.info(f"Force killing process {p.name}...")
|
||||
os.kill(p.pid, signal.SIGKILL)
|
||||
logger.info(f"Process {p.name} force killed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error force killing process: {e}")
|
||||
|
|
@ -0,0 +1,58 @@
EXP_NAME="xxx"
export AWS_SECRET_ACCESS_KEY="xxx"
export AWS_ACCESS_KEY_ID="xxx"
export AWS_REGION="us-east-1"
export AWS_SUBNET_ID="xxx"
export AWS_SECURITY_GROUP_ID="xxx"
# >> logs/${EXP_NAME}.log 2>&1
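# Note: the "xxx" values above and the --*_url/--*_api_key flags below are
# placeholders; fill in your own experiment name, AWS credentials, and
# model endpoints before running.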
python run_multienv_os_symphony.py \
    --provider_name "aws" \
    --region "us-east-1" \
    --client_password "osworld-public-evaluation" \
    --headless \
    --num_envs 7 \
    --max_steps 50 \
    --benchmark osworld \
    --domain "all" \
    --test_all_meta_path evaluation_examples/test_nogdrive.json \
    --result_dir "results" \
    --tool_config mm_agents/os_symphony/tool/all_tool_config.yaml \
    --orchestrator_provider "openai" \
    --orchestrator_model "gpt-5" \
    --orchestrator_url "xxx" \
    --orchestrator_api_key "xxx" \
    --orchestrator_temperature 0.1 \
    --orchestrator_keep_first_image \
    --max_trajectory_length 8 \
    --grounder_provider "vllm" \
    --grounder_model "UI-TARS-1.5-7B" \
    --grounder_api_key "none" \
    --grounder_url "xxx" \
    --grounding_smart_resize \
    --grounding_width 1920 \
    --grounding_height 1080 \
    --coder_provider "openai" \
    --coder_model "gpt-5" \
    --coder_url "xxx" \
    --coder_api_key "xxx" \
    --coder_temperature 0.1 \
    --coder_budget 20 \
    --memoryer_provider "openai" \
    --memoryer_model "gpt-5" \
    --memoryer_url "xxx" \
    --memoryer_api_key "xxx" \
    --memoryer_temperature 0.1 \
    --memoryer_max_images 8 \
    --searcher_provider "openai" \
    --searcher_model "gpt-5" \
    --searcher_url "xxx" \
    --searcher_api_key "xxx" \
    --searcher_temperature 0.1 \
    --searcher_type "vlm" \
    --searcher_engine "google" \
    --searcher_budget 20 \
    --searcher_screen_width 1920 \
    --searcher_screen_height 1080 \
    --sleep_after_execution 3 \
    --exp_name ${EXP_NAME} \
    --enable_reflection >> logs/${EXP_NAME}.log 2>&1