Compare commits
6 Commits
main
...
uitars/dev
| Author | SHA1 | Date |
|---|---|---|
|
|
bfddbcff27 | |
|
|
b5a57c2d58 | |
|
|
c3f3329fc2 | |
|
|
826c0ef945 | |
|
|
80b80617c4 | |
|
|
0f1ef6d9b7 |
|
|
@ -77,7 +77,8 @@ class AWSProvider(Provider):
|
|||
else:
|
||||
logger.warning("No public IP address available for VNC access")
|
||||
|
||||
return private_ip_address
|
||||
#return private_ip_address
|
||||
return public_ip_address
|
||||
return '' # Return an empty string if no IP address is found
|
||||
except ClientError as e:
|
||||
logger.error(f"Failed to retrieve IP address for the instance {path_to_vm}: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
|||
"step_num": step_idx + 1,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,956 @@
|
|||
import ast
|
||||
import base64
|
||||
from openai import OpenAI
|
||||
import math
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
import numpy as np
|
||||
import base64
|
||||
from loguru import logger
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
from PIL import Image
|
||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import (
|
||||
filter_nodes,
|
||||
)
|
||||
|
||||
UITARS_ACTION_SPACE = """
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished()
|
||||
"""
|
||||
|
||||
UITARS_CALL_USR_ACTION_SPACE = """
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished()
|
||||
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
||||
"""
|
||||
|
||||
UITARS_NORMAL_ACTION_SPACE = """
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
"""
|
||||
|
||||
UITARS_USR_PROMPT_NOTHOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
## Output Format
|
||||
```
|
||||
Action: ...
|
||||
```
|
||||
## Action Space
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished()
|
||||
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
UITARS_USR_PROMPT_THOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
{action_space}
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
FINISH_WORD = "finished"
|
||||
WAIT_WORD = "wait"
|
||||
ENV_FAIL_WORD = "error_env"
|
||||
CALL_USER = "call_user"
|
||||
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 100 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
# 定义一个函数来解析每个 action
|
||||
def parse_action(action_str):
    """Parse one action string, e.g. "click(start_box='(1,2)')", into a dict
    of the form ``{'function': name, 'args': {kw: value, ...}}``.

    Only keyword arguments with constant values are captured; any other
    argument value maps to None. Returns None when the string is not a
    valid single function-call expression.
    """
    try:
        # Parse into an AST; mode='eval' accepts exactly one expression.
        tree = ast.parse(action_str, mode='eval')
        if not isinstance(tree, ast.Expression):
            raise ValueError("Not an expression")

        call = tree.body
        if not isinstance(call, ast.Call):
            raise ValueError("Not a function call")

        # Resolve the callee name for both plain names and attribute access.
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
        elif isinstance(call.func, ast.Attribute):
            func_name = call.func.attr
        else:
            func_name = None

        # Collect keyword arguments; assume constant values.
        kwargs = {}
        for kw_node in call.keywords:
            if isinstance(kw_node.value, ast.Constant):
                kwargs[kw_node.arg] = kw_node.value.value
            elif isinstance(kw_node.value, ast.Str):  # legacy Python compat
                kwargs[kw_node.arg] = kw_node.value.s
            else:
                kwargs[kw_node.arg] = None

        return {
            'function': func_name,
            'args': kwargs
        }

    except Exception as e:
        print(f"Failed to parse action '{action_str}': {e}")
        return None
|
||||
|
||||
def escape_single_quotes(text):
    """Backslash-escape every single quote that is not already escaped."""
    # Negative lookbehind: match quotes NOT preceded by a backslash.
    unescaped_quote = r"(?<!\\)'"
    return re.sub(unescaped_quote, r"\\'", text)
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of 'factor' closest to 'number'."""
    return factor * round(number / factor)
|
||||
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of 'factor' that is >= 'number'."""
    return factor * math.ceil(number / factor)
|
||||
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of 'factor' that is <= 'number'."""
    return factor * math.floor(number / factor)
|
||||
|
||||
|
||||
def linear_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Linearly rescale (height, width) into the [min_pixels, max_pixels] range.

    The scale factor is the square root of the pixel-count ratio, so the
    aspect ratio is preserved and relative coordinates can be reused on the
    resized image without conversion. `factor` is accepted for signature
    parity with smart_resize but is not used here.
    """
    if width * height > max_pixels:
        # Shrink so the pixel count drops to (at most) max_pixels.
        shrink = math.sqrt(max_pixels / (width * height))
        width, height = int(width * shrink), int(height * shrink)
    if width * height < min_pixels:
        # Grow so the pixel count reaches (at least) min_pixels.
        grow = math.sqrt(min_pixels / (width * height))
        width, height = math.ceil(width * grow), math.ceil(height * grow)

    return height, width
|
||||
|
||||
def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Rescale (height, width) such that:

    1. both dimensions are divisible by 'factor';
    2. the total pixel count lies within ['min_pixels', 'max_pixels'];
    3. the aspect ratio is preserved as closely as possible.

    Raises:
        ValueError: when the aspect ratio exceeds MAX_RATIO.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides, rounding down to a factor multiple.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Too small: grow both sides, rounding up to a factor multiple.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
|
||||
|
||||
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
    """Parse a raw model response into a list of structured action dicts.

    Each returned dict has the keys: reflection, thought, action_type,
    action_inputs and text. start_box/end_box coordinates are normalized to
    [0, 1] (qwen25vl emits absolute pixels of the smart-resized image; other
    model types emit coordinates pre-scaled by `factor`).

    Raises:
        ValueError: when an action string cannot be parsed.
    """
    text = text.strip()
    if model_type == "qwen25vl":
        smart_resize_height, smart_resize_width = smart_resize(
            origin_resized_height,
            origin_resized_width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

    # Choose the regex that matches the leading reasoning section.
    if text.startswith("Reflection:"):
        thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Reflection: "
    elif text.startswith("Action_Summary:"):
        thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Action_Summary: "
    else:
        # Covers both an explicit "Thought:" prefix and the default case.
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
        thought_hint = "Thought: "

    reflection, thought = None, None
    thought_match = re.search(thought_pattern, text, re.DOTALL)
    if thought_match:
        if len(thought_match.groups()) == 1:
            thought = thought_match.group(1).strip()
        elif len(thought_match.groups()) == 2:
            thought = thought_match.group(2).strip()
            reflection = thought_match.group(1).strip()
    assert "Action:" in text
    action_str = text.split("Action:")[-1]

    # Actions are separated by blank lines; type() content needs its single
    # quotes escaped before the string can be parsed as a Python call.
    all_action = []
    for raw_action in action_str.split("\n\n"):
        if "type(content" in raw_action:
            def _extract_content(match):
                return match.group(1)  # the raw value of content

            content = re.sub(r"type\(content='(.*?)'\)", _extract_content, raw_action)
            raw_action = "type(content='" + escape_single_quotes(content) + "')"
        all_action.append(raw_action)

    parsed_actions = [parse_action(action.replace("\n", "\\n").lstrip()) for action in all_action]
    actions = []
    for action_instance, raw_str in zip(parsed_actions, all_action):
        if action_instance == None:
            print(f"Action can't parse: {raw_str}")
            raise ValueError(f"Action can't parse: {raw_str}")
        action_type = action_instance["function"]
        params = action_instance["args"]

        action_inputs = {}
        for param_name, param in params.items():
            if param == "":
                continue
            param = param.lstrip()  # drop leading whitespace
            action_inputs[param_name.strip()] = param

            # start_box / end_box hold "(x1,y1[,x2,y2])"-style coordinates.
            if "start_box" in param_name or "end_box" in param_name:
                # Remove parentheses and split on commas.
                numbers = param.replace("(", "").replace(")", "").split(",")

                # qwen25vl outputs absolute coordinates; other models output
                # relative coordinates scaled by `factor`.
                if model_type == "qwen25vl":
                    float_numbers = []
                    for num_idx, num in enumerate(numbers):
                        value = float(num)
                        if (num_idx + 1) % 2 == 0:
                            # even positions are y values
                            float_numbers.append(float(value / smart_resize_height))
                        else:
                            # odd positions are x values
                            float_numbers.append(float(value / smart_resize_width))
                else:
                    float_numbers = [float(num) / factor for num in numbers]

                # A bare point (x, y) is expanded to a degenerate box.
                if len(float_numbers) == 2:
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
                action_inputs[param_name.strip()] = str(float_numbers)

        actions.append({
            "reflection": reflection,
            "thought": thought,
            "action_type": action_type,
            "action_inputs": action_inputs,
            "text": text
        })
    return actions
|
||||
|
||||
def parsing_response_to_pyautogui_code(responses, image_height: int, image_width: int, input_swap: bool = True) -> str:
    '''
    Convert parsed model action dict(s) into an executable pyautogui code string.

    Args:
        responses: a single action dict, or a list of them, shaped like:
            {
                "action_type": "hotkey",
                "action_inputs": {"hotkey": "v ctrl", "start_box": None, "end_box": None}
            }
        image_height: screen height used to de-normalize y coordinates.
        image_width: screen width used to de-normalize x coordinates.
        input_swap: when True, type() is implemented as a clipboard paste
            (pyperclip + ctrl-v); otherwise pyautogui.write is used.

    Returns:
        The generated pyautogui code string ("DONE" for a finished action).
    '''
    # Alias maps shared by the hotkey and press branches.
    _ARROW_ALIASES = {"arrowleft": "left", "arrowright": "right", "arrowup": "up", "arrowdown": "down"}

    pyautogui_code = f"import pyautogui\nimport time\n"
    if isinstance(responses, dict):
        responses = [responses]
    for response_id, response in enumerate(responses):
        observation = response.get("observation", "")
        thought = response.get("thought", "")

        if response_id == 0:
            # Record the model's reasoning as a leading comment block.
            pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
        else:
            # Small pause between consecutive actions.
            pyautogui_code += f"\ntime.sleep(1)\n"

        action_type = response.get("action_type")
        action_inputs = response.get("action_inputs", {})

        if action_type == "hotkey":
            # Key combo may arrive under "key" or "hotkey".
            if "key" in action_inputs:
                hotkey = action_inputs.get("key", "")
            else:
                hotkey = action_inputs.get("hotkey", "")

            # Normalize arrow-key aliases to pyautogui names.
            hotkey = _ARROW_ALIASES.get(hotkey, hotkey)

            if hotkey:
                keys = hotkey.split()  # keys are space-separated
                convert_keys = [' ' if key == "space" else key for key in keys]
                pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"

        elif action_type == "press":
            # Key may arrive under "key" or "press".
            if "key" in action_inputs:
                key_to_press = action_inputs.get("key", "")
            else:
                key_to_press = action_inputs.get("press", "")

            # BUGFIX: this branch previously applied the alias mapping to an
            # undefined `hotkey` variable, raising NameError whenever a press
            # action was processed before any hotkey action. Normalize the
            # actual key instead.
            key_to_press = _ARROW_ALIASES.get(key_to_press, key_to_press)
            if key_to_press == "space":
                key_to_press = " "

            if key_to_press:
                # Simulate pressing a single key.
                pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"

        elif action_type == "keyup":
            key_to_up = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"

        elif action_type == "keydown":
            key_to_down = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"

        elif action_type == "type":
            # Typing, optionally routed through the clipboard.
            content = action_inputs.get("content", "")
            content = escape_single_quotes(content)
            stripped_content = content
            # A trailing newline means "submit" — strip it and press enter.
            if content.endswith("\n") or content.endswith("\\n"):
                stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
            if content:
                if input_swap:
                    pyautogui_code += f"\nimport pyperclip"
                    pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
                    pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
                    pyautogui_code += f"\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += f"\npyautogui.press('enter')"
                else:
                    pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
                    pyautogui_code += f"\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += f"\npyautogui.press('enter')"

        elif action_type in ["drag", "select"]:
            # Drag from the center of start_box to the center of end_box.
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")
            if start_box and end_box:
                # Boxes are normalized "[x1, y1, x2, y2]" strings; eval is
                # kept from the original implementation — model output is
                # assumed trusted here (flagged, not replaced).
                x1, y1, x2, y2 = eval(start_box)
                sx = round(float((x1 + x2) / 2) * image_width, 3)
                sy = round(float((y1 + y2) / 2) * image_height, 3)
                x1, y1, x2, y2 = eval(end_box)
                ex = round(float((x1 + x2) / 2) * image_width, 3)
                ey = round(float((y1 + y2) / 2) * image_height, 3)
                pyautogui_code += (
                    f"\npyautogui.moveTo({sx}, {sy})\n"
                    f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
                )

        elif action_type == "scroll":
            # Scroll at the box center when given, otherwise at the cursor.
            start_box = action_inputs.get("start_box")
            if start_box:
                x1, y1, x2, y2 = eval(start_box)  # see eval note above
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)
            else:
                x = None
                y = None
            direction = action_inputs.get("direction", "")

            if x is None:
                if "up" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(5)"
                elif "down" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(-5)"
            else:
                if "up" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
                elif "down" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})"

        elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
            # Mouse actions at the center of start_box.
            start_box = action_inputs.get("start_box")
            start_box = str(start_box)
            if start_box:
                start_box = eval(start_box)  # see eval note above
                if len(start_box) == 4:
                    x1, y1, x2, y2 = start_box
                elif len(start_box) == 2:
                    x1, y1 = start_box
                    x2 = x1
                    y2 = y1
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)
                if action_type == "left_single" or action_type == "click":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
                elif action_type == "left_double":
                    pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
                elif action_type == "right_single":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
                elif action_type == "hover":
                    pyautogui_code += f"\npyautogui.moveTo({x}, {y})"

        elif action_type in ["finished"]:
            # Finishing replaces all generated code with the DONE sentinel.
            pyautogui_code = f"DONE"

        else:
            pyautogui_code += f"\n# Unrecognized action type: {action_type}"

    return pyautogui_code
|
||||
|
||||
def add_box_token(input_string):
    """Wrap start_box/end_box coordinates in <|box_start|>/<|box_end|> tokens.

    Strings without both "Action: " and "start_box=" are returned unchanged.
    """
    if "Action: " not in input_string or "start_box=" not in input_string:
        return input_string

    # Everything up to (and including) the first "Action: " is kept verbatim.
    prefix = input_string.split("Action: ")[0] + "Action: "
    rebuilt = []
    for action in input_string.split("Action: ")[1:]:
        action = action.strip()
        # Find every coordinate pair; the pattern tolerates a space after
        # the comma, but the replacement target does not — preserved quirk.
        for coord_type, x, y in re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action):
            action = action.replace(
                f"{coord_type}='({x},{y})'",
                f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'",
            )
        rebuilt.append(action)

    return prefix + "\n\n".join(rebuilt)
|
||||
|
||||
def pil_to_base64(image):
    """Encode a PIL image as a base64 PNG string."""
    buf = BytesIO()
    image.save(buf, format="PNG")  # other formats such as "JPEG" also work
    return base64.b64encode(buf.getvalue()).decode("utf-8")
|
||||
|
||||
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
    """Flatten an accessibility-tree XML string into a tab-separated table.

    Args:
        accessibility_tree: XML document text for the desktop a11y tree.
        platform: "ubuntu" or "windows"; selects the attribute namespaces.

    Returns:
        A string with a header line plus one row per filtered node: tag,
        name, text, class, description, position and size.

    Raises:
        ValueError: for an unsupported platform.
    """
    if platform == "ubuntu":
        ns_attributes = attributes_ns_ubuntu
        ns_state = state_ns_ubuntu
        ns_component = component_ns_ubuntu
        ns_value = value_ns_ubuntu
    elif platform == "windows":
        ns_attributes = attributes_ns_windows
        ns_state = state_ns_windows
        ns_component = component_ns_windows
        ns_value = value_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")

    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
    rows = [
        "tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"
    ]

    # Render each node as one tab-separated table row.
    for node in filtered_nodes:
        # Prefer the node's own text; quote CSV-style when it contains '"'.
        if node.text:
            text = (
                node.text
                if '"' not in node.text
                else '"{:}"'.format(node.text.replace('"', '""'))
            )

        # Windows edit controls keep their text in the value attribute.
        # NOTE(review): class_ns_windows is consulted here even on the
        # ubuntu path — preserved from the original; confirm it is defined
        # for both platforms.
        elif node.get("{{{:}}}class".format(class_ns_windows), "").endswith(
            "EditWrapper"
        ) and node.get("{{{:}}}value".format(ns_value)):
            node_text = node.get("{{{:}}}value".format(ns_value), "")
            text = (
                node_text
                if '"' not in node_text
                else '"{:}"'.format(node_text.replace('"', '""'))
            )
        else:
            text = '""'

        rows.append(
            "{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format(
                node.tag,
                node.get("name", ""),
                text,
                (
                    node.get("{{{:}}}class".format(ns_attributes), "")
                    if platform == "ubuntu"
                    else node.get("{{{:}}}class".format(class_ns_windows), "")
                ),
                node.get("{{{:}}}description".format(ns_attributes), ""),
                node.get("{{{:}}}screencoord".format(ns_component), ""),
                node.get("{{{:}}}size".format(ns_component), ""),
            )
        )

    return "\n".join(rows)
|
||||
|
||||
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
    """Return the linearized tree unchanged.

    Token-based truncation (previously via tiktoken's gpt-4 encoding) is
    currently disabled; `max_tokens` is accepted for interface
    compatibility but not used.
    """
    return linearized_accessibility_tree
|
||||
|
||||
|
||||
class UITARSAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
runtime_conf: Dict,
|
||||
platform="ubuntu",
|
||||
action_space="pyautogui",
|
||||
observation_type="screenshot",
|
||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||
max_trajectory_length=50,
|
||||
a11y_tree_max_tokens=10000,
|
||||
model_type="qwen25vl",
|
||||
**kwargs
|
||||
):
|
||||
self.model = model
|
||||
self.platform = platform
|
||||
self.action_space = action_space
|
||||
self.observation_type = observation_type
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
||||
self.model_type = model_type
|
||||
self.runtime_conf = runtime_conf
|
||||
self.temperature = self.runtime_conf["temperature"]
|
||||
self.top_k = self.runtime_conf["top_k"]
|
||||
self.top_p = self.runtime_conf["top_p"]
|
||||
self.max_tokens = self.runtime_conf["max_tokens"]
|
||||
self.infer_mode = self.runtime_conf["infer_mode"]
|
||||
self.prompt_style = self.runtime_conf["prompt_style"]
|
||||
self.input_swap = self.runtime_conf["input_swap"]
|
||||
self.language = self.runtime_conf["language"]
|
||||
self.max_pixels = self.runtime_conf["max_pixels"]
|
||||
self.min_pixels = self.runtime_conf["min_pixels"]
|
||||
self.callusr_tolerance = self.runtime_conf["callusr_tolerance"]
|
||||
self.vlm = OpenAI(
|
||||
base_url=os.environ['DOUBAO_API_URL'],
|
||||
api_key=os.environ['DOUBAO_API_KEY'],
|
||||
)
|
||||
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
self.history_images = []
|
||||
self.history_responses = []
|
||||
|
||||
self.prompt_action_space = UITARS_ACTION_SPACE
|
||||
self.action_parse_res_factor = 1000
|
||||
if self.infer_mode == "qwen2vl_user":
|
||||
self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
|
||||
elif self.infer_mode == "qwen25vl_normal":
|
||||
self.prompt_action_space = UITARS_NORMAL_ACTION_SPACE
|
||||
|
||||
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
|
||||
|
||||
if self.prompt_style == "qwen2vl_user" or self.prompt_style == "qwen25vl_normal":
|
||||
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
|
||||
|
||||
elif self.prompt_style == "qwen2vl_no_thought":
|
||||
self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
|
||||
|
||||
|
||||
if "history_n" in self.runtime_conf:
|
||||
self.history_n = self.runtime_conf["history_n"]
|
||||
else:
|
||||
self.history_n = 5
|
||||
|
||||
self.cur_callusr_count = 0
|
||||
|
||||
def reset(self, runtime_logger=None):
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
self.history_images = []
|
||||
self.history_responses = []
|
||||
|
||||
|
||||
def predict(
|
||||
self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
|
||||
) -> List:
|
||||
"""
|
||||
Predict the next action(s) based on the current observation.
|
||||
"""
|
||||
|
||||
# Append trajectory
|
||||
# print(len(self.observations), len(self.actions), len(self.actions))
|
||||
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
|
||||
self.thoughts
|
||||
), "The number of observations and actions should be the same."
|
||||
|
||||
if len(self.observations) > self.max_trajectory_length:
|
||||
if self.max_trajectory_length == 0:
|
||||
_observations = []
|
||||
_actions = []
|
||||
_thoughts = []
|
||||
else:
|
||||
_observations = self.observations[-self.max_trajectory_length :]
|
||||
_actions = self.actions[-self.max_trajectory_length :]
|
||||
_thoughts = self.thoughts[-self.max_trajectory_length :]
|
||||
else:
|
||||
_observations = self.observations
|
||||
_actions = self.actions
|
||||
_thoughts = self.thoughts
|
||||
|
||||
|
||||
self.history_images.append(obs["screenshot"])
|
||||
|
||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||
base64_image = obs["screenshot"]
|
||||
try:
|
||||
linearized_accessibility_tree = (
|
||||
linearize_accessibility_tree(
|
||||
accessibility_tree=obs["accessibility_tree"],
|
||||
platform=self.platform,
|
||||
)
|
||||
if self.observation_type == "screenshot_a11y_tree"
|
||||
else None
|
||||
)
|
||||
except:
|
||||
linearized_accessibility_tree = None
|
||||
# logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(
|
||||
linearized_accessibility_tree, self.a11y_tree_max_tokens
|
||||
)
|
||||
|
||||
if self.observation_type == "screenshot_a11y_tree":
|
||||
self.observations.append(
|
||||
{
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree,
|
||||
}
|
||||
)
|
||||
else:
|
||||
self.observations.append(
|
||||
{"screenshot": base64_image, "accessibility_tree": None}
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid observation_type type: " + self.observation_type
|
||||
) # 1}}}
|
||||
|
||||
if self.infer_mode == "qwen2vl_user" or self.infer_mode == "qwen25vl_normal":
|
||||
user_prompt = self.prompt_template.format(
|
||||
instruction=instruction,
|
||||
action_space=self.prompt_action_space,
|
||||
language=self.language
|
||||
)
|
||||
elif self.infer_mode == "qwen2vl_no_thought":
|
||||
user_prompt = self.prompt_template.format(
|
||||
instruction=instruction
|
||||
)
|
||||
|
||||
if len(self.history_images) > self.history_n:
|
||||
self.history_images = self.history_images[-self.history_n:]
|
||||
|
||||
messages, images = [], []
|
||||
if isinstance(self.history_images, bytes):
|
||||
self.history_images = [self.history_images]
|
||||
elif isinstance(self.history_images, np.ndarray):
|
||||
self.history_images = list(self.history_images)
|
||||
elif isinstance(self.history_images, list):
|
||||
pass
|
||||
else:
|
||||
raise TypeError(f"Unidentified images type: {type(self.history_images)}")
|
||||
|
||||
for turn, image in enumerate(self.history_images):
|
||||
if len(images) >= self.history_n:
|
||||
break
|
||||
try:
|
||||
image = Image.open(BytesIO(image))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error opening image: {e}")
|
||||
|
||||
if image.width * image.height > self.max_pixels:
|
||||
"""
|
||||
如果图片超过/低于像素限制,则计算一个缩放因子resize_factor,使图片的像素数缩小到等于或小于max_pixels。这个缩放因子是通过开平方根计算的,确保纵横比保持不变,这样原始的相对坐标可以不经转换直接复用
|
||||
"""
|
||||
resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
|
||||
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
||||
image = image.resize((width, height))
|
||||
if image.width * image.height < self.min_pixels:
|
||||
resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
|
||||
width, height = math.ceil(image.width * resize_factor), math.ceil(image.height * resize_factor)
|
||||
image = image.resize((width, height))
|
||||
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
images.append(image)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_prompt}]
|
||||
}
|
||||
]
|
||||
|
||||
image_num = 0
|
||||
if len(self.history_responses) > 0:
|
||||
for history_idx, history_response in enumerate(self.history_responses):
|
||||
# send at most history_n images to the model
|
||||
if history_idx + self.history_n > len(self.history_responses):
|
||||
|
||||
cur_image = images[image_num]
|
||||
encoded_string = pil_to_base64(cur_image)
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}]
|
||||
})
|
||||
image_num += 1
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": add_box_token(history_response)}]
|
||||
})
|
||||
|
||||
cur_image = images[image_num]
|
||||
encoded_string = pil_to_base64(cur_image)
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}]
|
||||
})
|
||||
image_num += 1
|
||||
|
||||
else:
|
||||
cur_image = images[image_num]
|
||||
encoded_string = pil_to_base64(cur_image)
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}]
|
||||
})
|
||||
image_num += 1
|
||||
|
||||
try_times = 3
|
||||
origin_resized_height = images[-1].height
|
||||
origin_resized_width = images[-1].width
|
||||
temperature = self.temperature
|
||||
top_k = self.top_k
|
||||
while True:
|
||||
if try_times <= 0:
|
||||
print(f"Reach max retry times to fetch response from client, as error flag.")
|
||||
return "client error", ["DONE"]
|
||||
try:
|
||||
response = self.vlm.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
frequency_penalty=1,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=self.top_p
|
||||
)
|
||||
print("*" * 20)
|
||||
print("Response:")
|
||||
print(response.choices[0].message.content)
|
||||
print("*" * 20)
|
||||
prediction = response.choices[0].message.content.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error when fetching response from client: {e}")
|
||||
prediction = None
|
||||
try_times -= 1
|
||||
|
||||
try:
|
||||
parsed_responses = parse_action_to_structure_output(
|
||||
prediction,
|
||||
self.action_parse_res_factor,
|
||||
origin_resized_height,
|
||||
origin_resized_width,
|
||||
self.model_type,
|
||||
self.max_pixels,
|
||||
self.min_pixels
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error when parsing response from client: {e}")
|
||||
# If fail to parse the model response, we use sampling parameters to avoid it
|
||||
prediction = None
|
||||
try_times -= 1
|
||||
temperature = 1
|
||||
top_k = -1
|
||||
|
||||
if prediction is None:
|
||||
return "client error", ["DONE"]
|
||||
|
||||
self.history_responses.append(prediction)
|
||||
self.thoughts.append(prediction)
|
||||
|
||||
try:
|
||||
parsed_responses = parse_action_to_structure_output(
|
||||
prediction,
|
||||
self.action_parse_res_factor,
|
||||
origin_resized_height,
|
||||
origin_resized_width,
|
||||
self.model_type,
|
||||
self.max_pixels,
|
||||
self.min_pixels
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Parsing action error: {prediction}, with error:\n{e}")
|
||||
return f"Parsing action error: {prediction}, with error:\n{e}", ["DONE"]
|
||||
|
||||
actions = []
|
||||
last_image = Image.open(BytesIO(self.history_images[-1]))
|
||||
obs_image_height = last_image.height
|
||||
obs_image_width = last_image.width
|
||||
for parsed_response in parsed_responses:
|
||||
if "action_type" in parsed_response:
|
||||
|
||||
if parsed_response["action_type"] == FINISH_WORD:
|
||||
self.actions.append(actions)
|
||||
|
||||
return prediction, ["DONE"]
|
||||
|
||||
elif parsed_response["action_type"] == WAIT_WORD:
|
||||
self.actions.append(actions)
|
||||
return prediction, ["WAIT"]
|
||||
|
||||
elif parsed_response["action_type"] == ENV_FAIL_WORD:
|
||||
self.actions.append(actions)
|
||||
return prediction, ["FAIL"]
|
||||
|
||||
elif parsed_response["action_type"] == CALL_USER:
|
||||
if self.callusr_tolerance > self.cur_callusr_count:
|
||||
self.actions.append(actions)
|
||||
self.cur_callusr_count += 1
|
||||
return prediction, ["WAIT"]
|
||||
else:
|
||||
self.actions.append(actions)
|
||||
return prediction, ["FAIL"]
|
||||
|
||||
pyautogui_code = parsing_response_to_pyautogui_code(
|
||||
parsed_response,
|
||||
obs_image_height,
|
||||
obs_image_width,
|
||||
self.input_swap
|
||||
)
|
||||
actions.append(pyautogui_code)
|
||||
|
||||
self.actions.append(actions)
|
||||
|
||||
if len(self.history_responses) >= self.max_trajectory_length:
|
||||
# Default to FAIL if exceed max steps
|
||||
actions = ["FAIL"]
|
||||
|
||||
return prediction, actions
|
||||
|
|
@ -613,7 +613,7 @@ class UITarsAgent:
|
|||
self,
|
||||
# Model settings
|
||||
model: str,
|
||||
|
||||
model_type: str,
|
||||
# Generation settings
|
||||
max_tokens: int,
|
||||
top_p: Optional[float],
|
||||
|
|
@ -672,7 +672,7 @@ class UITarsAgent:
|
|||
self.system_prompt = COMPUTER_USE_NO_THINKING
|
||||
|
||||
self.action_parse_res_factor = 1000
|
||||
self.model_type = "doubao"
|
||||
self.model_type = model_type
|
||||
self.history_n = 5
|
||||
self.top_p = top_p
|
||||
self.temperature = temperature
|
||||
|
|
@ -6,7 +6,7 @@ import re
|
|||
import xml.etree.ElementTree as ET
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
|
||||
import os
|
||||
import backoff
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
|
@ -28,22 +28,16 @@ from mm_agents.prompts import (
|
|||
UITARS_CALL_USR_ACTION_SPACE,
|
||||
UITARS_USR_PROMPT_NOTHOUGHT,
|
||||
UITARS_USR_PROMPT_THOUGHT,
|
||||
UITARS_NORMAL_ACTION_SPACE
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
from loguru import logger
|
||||
|
||||
FINISH_WORD = "finished"
|
||||
WAIT_WORD = "wait"
|
||||
ENV_FAIL_WORD = "error_env"
|
||||
CALL_USER = "call_user"
|
||||
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 100 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
pure_text_settings = ["a11y_tree"]
|
||||
|
||||
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
|
||||
|
|
@ -109,68 +103,8 @@ def escape_single_quotes(text):
|
|||
pattern = r"(?<!\\)'"
|
||||
return re.sub(pattern, r"\\'", text)
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
def linear_resize(
|
||||
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
||||
) -> tuple[int, int]:
|
||||
if width * height > max_pixels:
|
||||
"""
|
||||
如果图片超过/低于像素限制,则计算一个缩放因子resize_factor,使图片的像素数缩小到等于或小于max_pixels。这个缩放因子是通过开平方根计算的,确保纵横比保持不变,这样原始的相对坐标可以不经转换直接复用
|
||||
"""
|
||||
resize_factor = math.sqrt(max_pixels / (width * height))
|
||||
width, height = int(width * resize_factor), int(height * resize_factor)
|
||||
if width * height < min_pixels:
|
||||
resize_factor = math.sqrt(min_pixels / (width * height))
|
||||
width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
|
||||
|
||||
return height, width
|
||||
|
||||
def smart_resize(
|
||||
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
if max(height, width) / min(height, width) > MAX_RATIO:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
return h_bar, w_bar
|
||||
|
||||
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
|
||||
def parse_action_qwen2vl(text, factor, image_height, image_width):
|
||||
text = text.strip()
|
||||
if model_type == "qwen25vl":
|
||||
smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
# 正则表达式匹配 Action 字符串
|
||||
if text.startswith("Thought:"):
|
||||
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
|
||||
|
|
@ -182,8 +116,10 @@ def parse_action_to_structure_output(text, factor, origin_resized_height, origin
|
|||
thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
|
||||
thought_hint = "Action_Summary: "
|
||||
else:
|
||||
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
|
||||
thought_hint = "Thought: "
|
||||
# 修复:当没有明确的"Thought:"标识时,提取Action:之前的所有内容作为思考
|
||||
thought_pattern = r"(.+?)(?=\s*Action:|$)"
|
||||
thought_hint = ""
|
||||
|
||||
reflection, thought = None, None
|
||||
thought_match = re.search(thought_pattern, text, re.DOTALL)
|
||||
if thought_match:
|
||||
|
|
@ -218,7 +154,7 @@ def parse_action_to_structure_output(text, factor, origin_resized_height, origin
|
|||
for action_instance, raw_str in zip(parsed_actions, all_action):
|
||||
if action_instance == None:
|
||||
print(f"Action can't parse: {raw_str}")
|
||||
raise ValueError(f"Action can't parse: {raw_str}")
|
||||
continue
|
||||
action_type = action_instance["function"]
|
||||
params = action_instance["args"]
|
||||
|
||||
|
|
@ -236,18 +172,7 @@ def parse_action_to_structure_output(text, factor, origin_resized_height, origin
|
|||
numbers = ori_box.replace("(", "").replace(")", "").split(",")
|
||||
|
||||
# Convert to float and scale by 1000
|
||||
# Qwen2.5vl output absolute coordinates, qwen2vl output relative coordinates
|
||||
if model_type == "qwen25vl":
|
||||
float_numbers = []
|
||||
for num_idx, num in enumerate(numbers):
|
||||
num = float(num)
|
||||
if (num_idx + 1) % 2 == 0:
|
||||
float_numbers.append(float(num/smart_resize_height))
|
||||
else:
|
||||
float_numbers.append(float(num/smart_resize_width))
|
||||
else:
|
||||
float_numbers = [float(num) / factor for num in numbers]
|
||||
|
||||
float_numbers = [float(num) / factor for num in numbers]
|
||||
if len(float_numbers) == 2:
|
||||
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
|
||||
action_inputs[param_name.strip()] = str(float_numbers)
|
||||
|
|
@ -296,7 +221,7 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
|
|||
if response_id == 0:
|
||||
pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
|
||||
else:
|
||||
pyautogui_code += f"\ntime.sleep(1)\n"
|
||||
pyautogui_code += f"\ntime.sleep(3)\n"
|
||||
|
||||
action_dict = response
|
||||
action_type = action_dict.get("action_type")
|
||||
|
|
@ -309,79 +234,25 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
|
|||
else:
|
||||
hotkey = action_inputs.get("hotkey", "")
|
||||
|
||||
if hotkey == "arrowleft":
|
||||
hotkey = "left"
|
||||
|
||||
elif hotkey == "arrowright":
|
||||
hotkey = "right"
|
||||
|
||||
elif hotkey == "arrowup":
|
||||
hotkey = "up"
|
||||
|
||||
elif hotkey == "arrowdown":
|
||||
hotkey = "down"
|
||||
|
||||
if hotkey:
|
||||
# Handle other hotkeys
|
||||
keys = hotkey.split() # Split the keys by space
|
||||
convert_keys = []
|
||||
for key in keys:
|
||||
if key == "space":
|
||||
key = ' '
|
||||
convert_keys.append(key)
|
||||
pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"
|
||||
pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in keys])})"
|
||||
|
||||
elif action_type == "press":
|
||||
# Parsing press action
|
||||
if "key" in action_inputs:
|
||||
key_to_press = action_inputs.get("key", "")
|
||||
else:
|
||||
key_to_press = action_inputs.get("press", "")
|
||||
|
||||
if hotkey == "arrowleft":
|
||||
hotkey = "left"
|
||||
|
||||
elif hotkey == "arrowright":
|
||||
hotkey = "right"
|
||||
|
||||
elif hotkey == "arrowup":
|
||||
hotkey = "up"
|
||||
|
||||
elif hotkey == "arrowdown":
|
||||
hotkey = "down"
|
||||
|
||||
elif hotkey == "space":
|
||||
hotkey = " "
|
||||
|
||||
if key_to_press:
|
||||
# Simulate pressing a single key
|
||||
pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"
|
||||
|
||||
elif action_type == "keyup":
|
||||
key_to_up = action_inputs.get("key", "")
|
||||
pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"
|
||||
|
||||
elif action_type == "keydown":
|
||||
key_to_down = action_inputs.get("key", "")
|
||||
pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"
|
||||
|
||||
elif action_type == "type":
|
||||
# Parsing typing action using clipboard
|
||||
content = action_inputs.get("content", "")
|
||||
content = escape_single_quotes(content)
|
||||
stripped_content = content
|
||||
if content.endswith("\n") or content.endswith("\\n"):
|
||||
stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
|
||||
if content:
|
||||
if input_swap:
|
||||
pyautogui_code += f"\nimport pyperclip"
|
||||
pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
|
||||
pyautogui_code += f"\npyperclip.copy('{content.strip()}')"
|
||||
pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
|
||||
pyautogui_code += f"\ntime.sleep(0.5)\n"
|
||||
if content.endswith("\n") or content.endswith("\\n"):
|
||||
pyautogui_code += f"\npyautogui.press('enter')"
|
||||
else:
|
||||
pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
|
||||
pyautogui_code += f"\npyautogui.write('{content.strip()}', interval=0.1)"
|
||||
pyautogui_code += f"\ntime.sleep(0.5)\n"
|
||||
if content.endswith("\n") or content.endswith("\\n"):
|
||||
pyautogui_code += f"\npyautogui.press('enter')"
|
||||
|
|
@ -460,29 +331,6 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
|
|||
|
||||
return pyautogui_code
|
||||
|
||||
def add_box_token(input_string):
|
||||
# Step 1: Split the string into individual actions
|
||||
if "Action: " in input_string and "start_box=" in input_string:
|
||||
suffix = input_string.split("Action: ")[0] + "Action: "
|
||||
actions = input_string.split("Action: ")[1:]
|
||||
processed_actions = []
|
||||
for action in actions:
|
||||
action = action.strip()
|
||||
# Step 2: Extract coordinates (start_box or end_box) using regex
|
||||
coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)
|
||||
|
||||
updated_action = action # Start with the original action
|
||||
for coord_type, x, y in coordinates:
|
||||
# Convert x and y to integers
|
||||
updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'")
|
||||
processed_actions.append(updated_action)
|
||||
|
||||
# Step 5: Reconstruct the final string
|
||||
final_string = suffix + "\n\n".join(processed_actions)
|
||||
else:
|
||||
final_string = input_string
|
||||
return final_string
|
||||
|
||||
def pil_to_base64(image):
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format="PNG") # 你可以改成 "JPEG" 等格式
|
||||
|
|
@ -558,51 +406,48 @@ def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
|
|||
class UITARSAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
platform="ubuntu",
|
||||
max_tokens=1000,
|
||||
top_p=0.9,
|
||||
top_k=1.0,
|
||||
temperature=0.0,
|
||||
action_space="pyautogui",
|
||||
observation_type="screenshot",
|
||||
observation_type="screenshot_a11y_tree",
|
||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||
max_trajectory_length=50,
|
||||
a11y_tree_max_tokens=10000,
|
||||
model_type="qwen25vl",
|
||||
runtime_conf: dict = {
|
||||
"infer_mode": "qwen25vl_normal",
|
||||
"prompt_style": "qwen25vl_normal",
|
||||
"infer_mode": "qwen2vl_user",
|
||||
"prompt_style": "qwen2vl_user",
|
||||
"input_swap": True,
|
||||
"language": "Chinese",
|
||||
"max_steps": 50,
|
||||
"history_n": 5,
|
||||
"max_pixels": 16384*28*28,
|
||||
"min_pixels": 100*28*28,
|
||||
"callusr_tolerance": 3,
|
||||
"temperature": 0.0,
|
||||
"top_k": -1,
|
||||
"top_p": 0.9,
|
||||
"max_tokens": 500
|
||||
|
||||
"screen_height": 1080,
|
||||
"screen_width": 1920
|
||||
}
|
||||
):
|
||||
self.model = model
|
||||
self.platform = platform
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.temperature = temperature
|
||||
self.action_space = action_space
|
||||
self.observation_type = observation_type
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
||||
self.model_type = model_type
|
||||
self.runtime_conf = runtime_conf
|
||||
self.vlm = OpenAI(
|
||||
base_url="http://127.0.0.1:8000/v1",
|
||||
api_key="empty",
|
||||
base_url=os.environ['DOUBAO_API_URL'],
|
||||
api_key=os.environ['DOUBAO_API_KEY'],
|
||||
) # should replace with your UI-TARS server api
|
||||
self.temperature = self.runtime_conf["temperature"]
|
||||
self.top_k = self.runtime_conf["top_k"]
|
||||
self.top_p = self.runtime_conf["top_p"]
|
||||
self.max_tokens = self.runtime_conf["max_tokens"]
|
||||
self.infer_mode = self.runtime_conf["infer_mode"]
|
||||
self.prompt_style = self.runtime_conf["prompt_style"]
|
||||
self.input_swap = self.runtime_conf["input_swap"]
|
||||
self.language = self.runtime_conf["language"]
|
||||
self.max_pixels = self.runtime_conf["max_pixels"]
|
||||
self.min_pixels = self.runtime_conf["min_pixels"]
|
||||
self.callusr_tolerance = self.runtime_conf["callusr_tolerance"]
|
||||
self.max_steps = max_trajectory_length
|
||||
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
|
|
@ -611,15 +456,14 @@ class UITARSAgent:
|
|||
self.history_responses = []
|
||||
|
||||
self.prompt_action_space = UITARS_ACTION_SPACE
|
||||
self.customize_action_parser = parse_action_qwen2vl
|
||||
self.action_parse_res_factor = 1000
|
||||
if self.infer_mode == "qwen2vl_user":
|
||||
self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
|
||||
elif self.infer_mode == "qwen25vl_normal":
|
||||
self.prompt_action_space = UITARS_NORMAL_ACTION_SPACE
|
||||
|
||||
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
|
||||
|
||||
if self.prompt_style == "qwen2vl_user" or self.prompt_style == "qwen25vl_normal":
|
||||
if self.prompt_style == "qwen2vl_user":
|
||||
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
|
||||
|
||||
elif self.prompt_style == "qwen2vl_no_thought":
|
||||
|
|
@ -630,8 +474,6 @@ class UITARSAgent:
|
|||
self.history_n = self.runtime_conf["history_n"]
|
||||
else:
|
||||
self.history_n = 5
|
||||
|
||||
self.cur_callusr_count = 0
|
||||
|
||||
def predict(
|
||||
self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
|
||||
|
|
@ -660,18 +502,9 @@ class UITARSAgent:
|
|||
_actions = self.actions
|
||||
_thoughts = self.thoughts
|
||||
|
||||
for previous_obs, previous_action, previous_thought in zip(
|
||||
_observations, _actions, _thoughts
|
||||
):
|
||||
# {{{1
|
||||
if self.observation_type == "screenshot_a11y_tree":
|
||||
_screenshot = previous_obs["screenshot"]
|
||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid observation_type type: " + self.observation_type
|
||||
) # 1}}}
|
||||
|
||||
if last_action_after_obs is not None and self.infer_mode == "double_image":
|
||||
self.history_images.append(last_action_after_obs["screenshot"])
|
||||
|
||||
self.history_images.append(obs["screenshot"])
|
||||
|
||||
|
|
@ -712,7 +545,7 @@ class UITARSAgent:
|
|||
"Invalid observation_type type: " + self.observation_type
|
||||
) # 1}}}
|
||||
|
||||
if self.infer_mode == "qwen2vl_user" or self.infer_mode == "qwen25vl_normal":
|
||||
if self.infer_mode == "qwen2vl_user":
|
||||
user_prompt = self.prompt_template.format(
|
||||
instruction=instruction,
|
||||
action_space=self.prompt_action_space,
|
||||
|
|
@ -726,6 +559,8 @@ class UITARSAgent:
|
|||
if len(self.history_images) > self.history_n:
|
||||
self.history_images = self.history_images[-self.history_n:]
|
||||
|
||||
max_pixels = 2116800
|
||||
min_pixels = 3136
|
||||
messages, images = [], []
|
||||
if isinstance(self.history_images, bytes):
|
||||
self.history_images = [self.history_images]
|
||||
|
|
@ -735,24 +570,28 @@ class UITARSAgent:
|
|||
pass
|
||||
else:
|
||||
raise TypeError(f"Unidentified images type: {type(self.history_images)}")
|
||||
max_image_nums_under_32k = int(32768*0.75/max_pixels*28*28)
|
||||
if len(self.history_images) > max_image_nums_under_32k:
|
||||
num_of_images = min(5, len(self.history_images))
|
||||
max_pixels = int(32768*0.75) // num_of_images
|
||||
|
||||
for turn, image in enumerate(self.history_images):
|
||||
if len(images) >= self.history_n:
|
||||
if len(images) >= 5:
|
||||
break
|
||||
try:
|
||||
image = Image.open(BytesIO(image))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error opening image: {e}")
|
||||
|
||||
if image.width * image.height > self.max_pixels:
|
||||
if image.width * image.height > max_pixels:
|
||||
"""
|
||||
如果图片超过/低于像素限制,则计算一个缩放因子resize_factor,使图片的像素数缩小到等于或小于max_pixels。这个缩放因子是通过开平方根计算的,确保纵横比保持不变,这样原始的相对坐标可以不经转换直接复用
|
||||
"""
|
||||
resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
|
||||
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
|
||||
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
||||
image = image.resize((width, height))
|
||||
if image.width * image.height < self.min_pixels:
|
||||
resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
|
||||
if image.width * image.height < min_pixels:
|
||||
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
|
||||
width, height = math.ceil(image.width * resize_factor), math.ceil(image.height * resize_factor)
|
||||
image = image.resize((width, height))
|
||||
|
||||
|
|
@ -788,7 +627,7 @@ class UITARSAgent:
|
|||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": [add_box_token(history_response)]
|
||||
"content": history_response
|
||||
})
|
||||
|
||||
cur_image = images[image_num]
|
||||
|
|
@ -809,79 +648,59 @@ class UITARSAgent:
|
|||
image_num += 1
|
||||
|
||||
try_times = 3
|
||||
origin_resized_height = images[-1].height
|
||||
origin_resized_width = images[-1].width
|
||||
temperature = self.temperature
|
||||
top_k = self.top_k
|
||||
while True:
|
||||
if try_times <= 0:
|
||||
print(f"Reach max retry times to fetch response from client, as error flag.")
|
||||
return "client error", ["DONE"], []
|
||||
return "client error", ["DONE"]
|
||||
try:
|
||||
|
||||
response = self.vlm.chat.completions.create(
|
||||
model="ui-tars",
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
frequency_penalty=1,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=temperature,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p
|
||||
)
|
||||
# print(response.choices[0].message.content)
|
||||
prediction = response.choices[0].message.content.strip()
|
||||
except Exception as e:
|
||||
print(f"Error when fetching response from client, with response: {response}")
|
||||
prediction = None
|
||||
try_times -= 1
|
||||
|
||||
try:
|
||||
parsed_responses = parse_action_to_structure_output(
|
||||
print("Response:")
|
||||
print(response.choices[0].message.content)
|
||||
|
||||
prediction = response.choices[0].message.content
|
||||
parsed_responses = self.customize_action_parser(
|
||||
prediction,
|
||||
self.action_parse_res_factor,
|
||||
origin_resized_height,
|
||||
origin_resized_width,
|
||||
self.model_type,
|
||||
self.max_pixels,
|
||||
self.min_pixels
|
||||
self.runtime_conf["screen_height"],
|
||||
self.runtime_conf["screen_width"]
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error when parsing response from client, with response: {response}")
|
||||
# If fail to parse the model response, we use sampling parameters to avoid it
|
||||
logger.exception(f"Error when fetching response from client, with response: {e}")
|
||||
prediction = None
|
||||
try_times -= 1
|
||||
temperature = 1
|
||||
top_k = -1
|
||||
|
||||
if prediction is None:
|
||||
return "client error", ["DONE"]
|
||||
|
||||
|
||||
self.history_responses.append(prediction)
|
||||
self.thoughts.append(prediction)
|
||||
|
||||
try:
|
||||
parsed_responses = parse_action_to_structure_output(
|
||||
parsed_responses = self.customize_action_parser(
|
||||
prediction,
|
||||
self.action_parse_res_factor,
|
||||
origin_resized_height,
|
||||
origin_resized_width,
|
||||
self.model_type,
|
||||
self.max_pixels,
|
||||
self.min_pixels
|
||||
self.runtime_conf["screen_height"],
|
||||
self.runtime_conf["screen_width"]
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Parsing action error: {prediction}, with error:\n{e}")
|
||||
return f"Parsing action error: {prediction}, with error:\n{e}", ["DONE"]
|
||||
|
||||
actions = []
|
||||
last_image = Image.open(BytesIO(self.history_images[-1]))
|
||||
obs_image_height = last_image.height
|
||||
obs_image_width = last_image.width
|
||||
for parsed_response in parsed_responses:
|
||||
if "action_type" in parsed_response:
|
||||
|
||||
if parsed_response["action_type"] == FINISH_WORD:
|
||||
self.actions.append(actions)
|
||||
|
||||
return prediction, ["DONE"]
|
||||
|
||||
elif parsed_response["action_type"] == WAIT_WORD:
|
||||
|
|
@ -893,18 +712,13 @@ class UITARSAgent:
|
|||
return prediction, ["FAIL"]
|
||||
|
||||
elif parsed_response["action_type"] == CALL_USER:
|
||||
if self.callusr_tolerance > self.cur_callusr_count:
|
||||
self.actions.append(actions)
|
||||
self.cur_callusr_count += 1
|
||||
return prediction, ["WAIT"]
|
||||
else:
|
||||
self.actions.append(actions)
|
||||
return prediction, ["FAIL"]
|
||||
self.actions.append(actions)
|
||||
return prediction, ["FAIL"]
|
||||
|
||||
pyautogui_code = parsing_response_to_pyautogui_code(
|
||||
parsed_response,
|
||||
obs_image_height,
|
||||
obs_image_width,
|
||||
self.runtime_conf["screen_height"],
|
||||
self.runtime_conf["screen_width"],
|
||||
self.input_swap
|
||||
)
|
||||
actions.append(pyautogui_code)
|
||||
|
|
@ -917,7 +731,6 @@ class UITARSAgent:
|
|||
|
||||
return prediction, actions
|
||||
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.constant,
|
||||
# here you should add more model exceptions as you want,
|
||||
|
|
@ -947,4 +760,4 @@ class UITARSAgent:
|
|||
self.actions = []
|
||||
self.observations = []
|
||||
self.history_images = []
|
||||
self.history_responses = []
|
||||
self.history_responses = []
|
||||
|
|
@ -0,0 +1,539 @@
|
|||
from __future__ import annotations
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import time
|
||||
from typing import List
|
||||
from multiprocessing import Process, Manager
|
||||
from multiprocessing import current_process
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.uitars_agent import UITARSAgent
|
||||
import os
|
||||
|
||||
|
||||
# Global variables for signal handling
|
||||
active_environments = []
|
||||
processes = []
|
||||
is_terminating = False
|
||||
|
||||
# load the environment variables from .env file
|
||||
if os.path.exists(".env"):
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Logger Configs {{{ #
|
||||
def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation on the benchmark"
|
||||
)
|
||||
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action_space", type=str, default="pyautogui", help="Action type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--observation_type",
|
||||
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
default="screenshot",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
|
||||
parser.add_argument("--max_steps", type=int, default=15)
|
||||
|
||||
# evaluation config
|
||||
parser.add_argument(
|
||||
"--test_config_base_dir", type=str, default="evaluation_examples"
|
||||
)
|
||||
|
||||
# lm config
|
||||
# lm config
|
||||
parser.add_argument("--model", type=str, default="uitars-72b-dpo", help="Model name")
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--top_p", type=float, default=0.9)
|
||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
||||
parser.add_argument("--stop_token", type=str, default=None)
|
||||
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=3, help="The max number of trajectory steps.")
|
||||
|
||||
# example config
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
parser.add_argument(
|
||||
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
|
||||
)
|
||||
|
||||
# logging related
|
||||
parser.add_argument("--result_dir", type=str, default="./results")
|
||||
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
||||
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||
default='INFO', help="Set the logging level")
|
||||
# aws config
|
||||
parser.add_argument(
|
||||
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--client_password", type=str, default="", help="Client password"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_width", type=int, default=1920, help="Screen width"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_height", type=int, default=1080, help="Screen height"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
args = config() # Get command line arguments first
|
||||
|
||||
logger = logging.getLogger()
|
||||
log_level = getattr(logging, args.log_level.upper())
|
||||
logger.setLevel(log_level)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
debug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
file_handler.setLevel(logging.INFO)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setLevel(log_level)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
debug_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
|
||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.addHandler(stdout_handler)
|
||||
# }}} Logger Configs #
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
|
||||
all_tasks = []
|
||||
for domain, examples in test_all_meta.items():
|
||||
for example_id in examples:
|
||||
all_tasks.append((domain, example_id))
|
||||
return all_tasks
|
||||
|
||||
|
||||
def process_signal_handler(signum, frame, env_idx):
|
||||
"""Signal handler for child processes to gracefully shut down their environments."""
|
||||
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
|
||||
|
||||
# Get the active_environments from the caller's frame
|
||||
local_vars = frame.f_locals
|
||||
active_environments = local_vars.get('active_environments', [])
|
||||
|
||||
# Close environment in the current process context
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info(f"Process {env_idx + 1} closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Process {env_idx + 1} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
|
||||
|
||||
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
REGION = args.region
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=REGION,
|
||||
snapshot_name=ami_id,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=args.client_password
|
||||
)
|
||||
active_environments.append(env)
|
||||
args.max_trajectory_length = args.max_steps
|
||||
agent = UITARSAgent(
|
||||
model=args.model,
|
||||
max_tokens=args.max_tokens,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
action_space=args.action_space,
|
||||
observation_type=args.observation_type,
|
||||
max_trajectory_length=args.max_trajectory_length,
|
||||
)
|
||||
|
||||
logger.info(f"Process {current_process().name} started.")
|
||||
while True:
|
||||
try:
|
||||
item = task_queue.get(timeout=5)
|
||||
except Exception:
|
||||
break
|
||||
domain, example_id = item
|
||||
try:
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
try:
|
||||
lib_run_single.run_single_example(
|
||||
agent,
|
||||
env,
|
||||
example,
|
||||
args.max_steps,
|
||||
example["instruction"],
|
||||
args,
|
||||
example_result_dir,
|
||||
shared_scores,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
try:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
except Exception as rec_e:
|
||||
logger.error(f"Failed to end recording: {rec_e}")
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"{domain}/{example_id} - {e}"}
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
logger.info(f"{current_process().name} cleaning up environment...")
|
||||
try:
|
||||
if env:
|
||||
env.close()
|
||||
logger.info(f"{current_process().name} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
|
||||
global is_terminating, active_environments, processes
|
||||
|
||||
# Avoid duplicate handling
|
||||
if is_terminating:
|
||||
return
|
||||
|
||||
is_terminating = True
|
||||
logger.info(f"Received signal {signum}. Gracefully shutting down...")
|
||||
|
||||
# Close all registered environments in the main process
|
||||
for env in active_environments:
|
||||
try:
|
||||
logger.info(f"Closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing environment: {e}")
|
||||
|
||||
# Send termination signal to all child processes first
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Sending termination signal to process {p.name}...")
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending termination signal to process: {e}")
|
||||
|
||||
# Allow a short time for processes to handle their own cleanup
|
||||
time.sleep(1)
|
||||
|
||||
# Forcefully terminate any processes that didn't exit
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Forcefully terminating process {p.name}...")
|
||||
import signal as sig
|
||||
os.kill(p.pid, sig.SIGKILL)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forcefully terminating process: {e}")
|
||||
|
||||
logger.info("Shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
global processes
|
||||
logger.info("Args: %s", args)
|
||||
all_tasks = distribute_tasks(test_all_meta)
|
||||
logger.info(f"Total tasks: {len(all_tasks)}")
|
||||
with Manager() as manager:
|
||||
shared_scores = manager.list()
|
||||
task_queue = manager.Queue()
|
||||
for item in all_tasks:
|
||||
task_queue.put(item)
|
||||
num_envs = args.num_envs
|
||||
processes = []
|
||||
for i in range(num_envs):
|
||||
p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-{i+1}"
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
processes.append(p)
|
||||
logger.info(f"Started process {p.name} with PID {p.pid}")
|
||||
try:
|
||||
while True:
|
||||
alive_count = 0
|
||||
for idx, p in enumerate(processes):
|
||||
if not p.is_alive():
|
||||
logger.warning(f"Process {p.name} died, restarting...")
|
||||
new_p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-Restart-{idx+1}"
|
||||
)
|
||||
new_p.daemon = True
|
||||
new_p.start()
|
||||
processes[idx] = new_p
|
||||
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||
else:
|
||||
alive_count += 1
|
||||
if task_queue.empty():
|
||||
logger.info("All tasks finished.")
|
||||
break
|
||||
if alive_count == 0:
|
||||
logger.error("All processes died, exiting.")
|
||||
break
|
||||
time.sleep(5)
|
||||
for p in processes:
|
||||
p.join()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Terminating process {p.name} due to error...")
|
||||
p.terminate()
|
||||
except Exception as term_e:
|
||||
logger.error(f"Error terminating process {p.name}: {term_e}")
|
||||
raise
|
||||
scores = list(shared_scores)
|
||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||
|
||||
|
||||
def get_unfinished(
|
||||
action_space, use_model, observation_type, result_dir, total_file_json
|
||||
):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
|
||||
if not os.path.exists(target_dir):
|
||||
return total_file_json
|
||||
|
||||
finished = {}
|
||||
for domain in os.listdir(target_dir):
|
||||
finished[domain] = []
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
if example_id == "onboard":
|
||||
continue
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" not in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
for file in os.listdir(example_path):
|
||||
os.remove(os.path.join(example_path, file))
|
||||
else:
|
||||
finished[domain].append(example_id)
|
||||
|
||||
if not finished:
|
||||
return total_file_json
|
||||
|
||||
for domain, examples in finished.items():
|
||||
if domain in total_file_json:
|
||||
total_file_json[domain] = [
|
||||
x for x in total_file_json[domain] if x not in examples
|
||||
]
|
||||
|
||||
return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
if not os.path.exists(target_dir):
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
|
||||
all_result = []
|
||||
|
||||
for domain in os.listdir(target_dir):
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" in os.listdir(example_path):
|
||||
# empty all files under example_id
|
||||
try:
|
||||
all_result.append(
|
||||
float(
|
||||
open(
|
||||
os.path.join(example_path, "result.txt"), "r"
|
||||
).read()
|
||||
)
|
||||
)
|
||||
except:
|
||||
all_result.append(0.0)
|
||||
|
||||
if not all_result:
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
else:
|
||||
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
|
||||
return all_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Register signal handlers for graceful termination
|
||||
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
|
||||
|
||||
try:
|
||||
args = config()
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt.")
|
||||
# Signal handler will take care of cleanup
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
|
||||
# Also trigger cleanup for unhandled exceptions
|
||||
signal_handler(signal.SIGTERM, None)
|
||||
finally:
|
||||
# Final cleanup in case any environments or processes remain
|
||||
logger.info("Main process final cleanup...")
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info(f"Closing environment in final cleanup...")
|
||||
env.close()
|
||||
logger.info(f"Environment closed successfully in final cleanup")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during final environment cleanup: {e}")
|
||||
|
||||
# First try gentle termination
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
logger.info(f"Terminating process {p.name}...")
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error terminating process: {e}")
|
||||
|
||||
# Wait a moment for processes to terminate
|
||||
time.sleep(1)
|
||||
|
||||
# Then force kill if needed
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
logger.info(f"Force killing process {p.name}...")
|
||||
os.kill(p.pid, signal.SIGKILL)
|
||||
logger.info(f"Process {p.name} force killed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error force killing process: {e}")
|
||||
|
|
@ -0,0 +1,581 @@
|
|||
from __future__ import annotations
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import time
|
||||
from typing import List
|
||||
from multiprocessing import Process, Manager
|
||||
from multiprocessing import current_process
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.uitars15_v1 import UITARSAgent
|
||||
import os
|
||||
|
||||
# Global variables for signal handling
|
||||
active_environments = []
|
||||
processes = []
|
||||
is_terminating = False
|
||||
|
||||
# load the environment variables from .env file
|
||||
if os.path.exists(".env"):
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Logger Configs {{{ #
|
||||
def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation on the benchmark"
|
||||
)
|
||||
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action_space", type=str, default="pyautogui", help="Action type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--observation_type",
|
||||
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
default="screenshot",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
|
||||
parser.add_argument("--max_steps", type=int, default=15)
|
||||
|
||||
# evaluation config
|
||||
parser.add_argument(
|
||||
"--test_config_base_dir", type=str, default="evaluation_examples"
|
||||
)
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model", type=str, default="uitars15-7b")
|
||||
parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen25vl", "qwen2vl"])
|
||||
parser.add_argument("--infer_mode", type=str, default="qwen25vl_normal", choices=["qwen25vl_normal", "qwen2vl_user"])
|
||||
parser.add_argument("--prompt_style", type=str, default="qwen25vl_normal")
|
||||
parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content")
|
||||
parser.add_argument("--language", type=str, default="Chinese")
|
||||
parser.add_argument("--max_pixels", type=float, default=16384*28*28)
|
||||
parser.add_argument("--min_pixels", type=float, default=100*28*28)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--top_p", type=float, default=0.9)
|
||||
parser.add_argument("--top_k", type=int, default=-1)
|
||||
parser.add_argument("--history_n", type=int, default=5)
|
||||
parser.add_argument("--callusr_tolerance", type=int, default=3)
|
||||
parser.add_argument("--max_tokens", type=int, default=500)
|
||||
parser.add_argument("--stop_token", type=str, default=None)
|
||||
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
|
||||
parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.")
|
||||
|
||||
# example config
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
parser.add_argument(
|
||||
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
|
||||
)
|
||||
|
||||
# logging related
|
||||
parser.add_argument("--result_dir", type=str, default="./results")
|
||||
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
||||
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||
default='INFO', help="Set the logging level")
|
||||
# aws config
|
||||
parser.add_argument(
|
||||
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--client_password", type=str, default="", help="Client password"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_width", type=int, default=1920, help="Screen width"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_height", type=int, default=1080, help="Screen height"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
args = config() # Get command line arguments first
|
||||
|
||||
logger = logging.getLogger()
|
||||
log_level = getattr(logging, args.log_level.upper())
|
||||
logger.setLevel(log_level)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
debug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
file_handler.setLevel(logging.INFO)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setLevel(log_level)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
debug_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
|
||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.addHandler(stdout_handler)
|
||||
# }}} Logger Configs #
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
|
||||
all_tasks = []
|
||||
for domain, examples in test_all_meta.items():
|
||||
for example_id in examples:
|
||||
all_tasks.append((domain, example_id))
|
||||
return all_tasks
|
||||
|
||||
|
||||
def process_signal_handler(signum, frame, env_idx):
|
||||
"""Signal handler for child processes to gracefully shut down their environments."""
|
||||
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
|
||||
|
||||
# Get the active_environments from the caller's frame
|
||||
local_vars = frame.f_locals
|
||||
active_environments = local_vars.get('active_environments', [])
|
||||
|
||||
# Close environment in the current process context
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info(f"Process {env_idx + 1} closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Process {env_idx + 1} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
|
||||
|
||||
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
REGION = args.region
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=REGION,
|
||||
snapshot_name=ami_id,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=args.client_password
|
||||
)
|
||||
active_environments.append(env)
|
||||
args.max_trajectory_length = args.max_steps
|
||||
if args.infer_mode == "qwen25vl_normal":
|
||||
runtime_conf: dict = {
|
||||
"infer_mode": "qwen25vl_normal",
|
||||
"prompt_style": "qwen25vl_normal",
|
||||
"input_swap": True,
|
||||
"language": "Chinese",
|
||||
"history_n": 5,
|
||||
"max_pixels": 16384*28*28,
|
||||
"min_pixels": 100*28*28,
|
||||
"callusr_tolerance": 3,
|
||||
"temperature": 0.0,
|
||||
"top_k": -1,
|
||||
"top_p": 0.9,
|
||||
"max_tokens": 1000
|
||||
|
||||
}
|
||||
elif args.infer_mode == "qwen2vl_user":
|
||||
runtime_conf: dict = {
|
||||
"infer_mode": "qwen2vl_user",
|
||||
"prompt_style": "qwen2vl_user",
|
||||
"input_swap": True,
|
||||
"language": "Chinese",
|
||||
"history_n": 5,
|
||||
"max_pixels": 2116800,
|
||||
"min_pixels": 3136,
|
||||
"callusr_tolerance": 3,
|
||||
"temperature": 0.0,
|
||||
"top_k": -1,
|
||||
"top_p": 0.9,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown infer_mode: {args.infer_mode}")
|
||||
|
||||
agent = UITARSAgent(
|
||||
model=args.model,
|
||||
action_space=args.action_space,
|
||||
observation_type=args.observation_type,
|
||||
max_trajectory_length=args.max_trajectory_length,
|
||||
model_type=args.model_type,
|
||||
runtime_conf = runtime_conf
|
||||
)
|
||||
|
||||
logger.info(f"Process {current_process().name} started.")
|
||||
while True:
|
||||
try:
|
||||
item = task_queue.get(timeout=5)
|
||||
except Exception:
|
||||
break
|
||||
domain, example_id = item
|
||||
try:
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
try:
|
||||
lib_run_single.run_single_example(
|
||||
agent,
|
||||
env,
|
||||
example,
|
||||
args.max_steps,
|
||||
example["instruction"],
|
||||
args,
|
||||
example_result_dir,
|
||||
shared_scores,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
try:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
except Exception as rec_e:
|
||||
logger.error(f"Failed to end recording: {rec_e}")
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"{domain}/{example_id} - {e}"}
|
||||
)
|
||||
)
|
||||
f.write("\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
logger.info(f"{current_process().name} cleaning up environment...")
|
||||
try:
|
||||
if env:
|
||||
env.close()
|
||||
logger.info(f"{current_process().name} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.

    Closes every environment registered in ``active_environments``, asks each
    worker in ``processes`` to terminate, waits briefly for self-cleanup,
    force-kills any stragglers, and finally exits the interpreter.

    Args:
        signum: Signal number delivered by the OS.
        frame: Current stack frame (unused; required by the signal API).
    """
    global is_terminating, active_environments, processes

    # Avoid duplicate handling if a second signal arrives mid-shutdown.
    if is_terminating:
        return

    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    # Close all registered environments in the main process.
    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    # Send termination signal to all child processes first.
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")

    # Allow a short time for processes to handle their own cleanup.
    time.sleep(1)

    # Forcefully terminate any processes that didn't exit.
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                # Use the module-level `signal` import directly; the original
                # re-imported it locally (`import signal as sig`) for no benefit.
                os.kill(p.pid, signal.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    """Run all benchmark tasks across ``args.num_envs`` worker processes.

    Tasks are pushed onto a Manager queue consumed by ``run_env_tasks``
    workers; a supervisor loop restarts dead workers until the queue drains,
    then the scores accumulated in ``shared_scores`` are averaged and logged.
    """
    global processes
    logger.info("Args: %s", args)
    # distribute_tasks is defined elsewhere in this file; it turns the
    # domain -> example-ids mapping into a flat, sized task collection.
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")
    # Manager-backed list/queue so score appends and task pops are shared
    # safely across worker processes.
    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []  # module-global so signal_handler can terminate them
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"EnvProcess-{i+1}"
            )
            # Daemonize so workers die with the main process.
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
        try:
            # Supervisor loop: poll every 5s, restart dead workers while
            # tasks remain, stop when the queue is empty or all workers died.
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        # NOTE(review): a worker that exits cleanly after the
                        # queue drains is restarted once here before the
                        # empty-queue check below breaks the loop.
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"EnvProcess-Restart-{idx+1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            # Wait for the (possibly restarted) workers to finish their
            # in-flight tasks.
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            # Re-raise so the signal handler / caller performs cleanup.
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            # Best-effort terminate of still-alive workers before re-raising.
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        # Snapshot the shared list before the Manager shuts down.
        scores = list(shared_scores)
        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||
|
||||
|
||||
def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    """Filter ``total_file_json`` down to examples that still need to run.

    An example counts as finished when its directory under
    ``result_dir/action_space/observation_type/use_model`` contains
    ``result.txt``. Directories without it are treated as aborted runs and
    their files are deleted so the example restarts cleanly. The special
    "onboard" entry is ignored. The mapping is modified in place and returned.
    """
    run_root = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(run_root):
        return total_file_json

    completed = {}
    for domain_name in os.listdir(run_root):
        completed[domain_name] = []
        domain_dir = os.path.join(run_root, domain_name)
        if not os.path.isdir(domain_dir):
            continue
        for task_id in os.listdir(domain_dir):
            if task_id == "onboard":
                continue
            task_dir = os.path.join(domain_dir, task_id)
            if not os.path.isdir(task_dir):
                continue
            if "result.txt" in os.listdir(task_dir):
                completed[domain_name].append(task_id)
            else:
                # Partial run: wipe its artifacts so the task is retried.
                for leftover in os.listdir(task_dir):
                    os.remove(os.path.join(task_dir, leftover))

    if not completed:
        return total_file_json

    for domain_name, done_ids in completed.items():
        if domain_name in total_file_json:
            remaining = [t for t in total_file_json[domain_name] if t not in done_ids]
            total_file_json[domain_name] = remaining

    return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    """Collect per-example scores already on disk and print a success-rate summary.

    Walks ``result_dir/action_space/observation_type/use_model/<domain>/<id>/``
    and parses each ``result.txt`` as a float. Unreadable or unparseable files
    count as 0.0, matching the original best-effort behavior.

    Returns:
        list[float] of collected scores, or None when no run/result exists yet.
        (``total_file_json`` is accepted for signature parity but unused.)
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in os.listdir(domain_path):
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue
            if "result.txt" in os.listdir(example_path):
                # `with` closes the handle (the original leaked it), and the
                # narrowed except replaces the original bare `except:` so only
                # I/O and float-parse failures are scored as 0.0.
                try:
                    with open(os.path.join(example_path, "result.txt"), "r") as f:
                        all_result.append(float(f.read()))
                except (OSError, ValueError):
                    all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
    return all_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Register signal handlers for graceful termination
|
||||
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
|
||||
|
||||
try:
|
||||
args = config()
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt.")
|
||||
# Signal handler will take care of cleanup
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
|
||||
# Also trigger cleanup for unhandled exceptions
|
||||
signal_handler(signal.SIGTERM, None)
|
||||
finally:
|
||||
# Final cleanup in case any environments or processes remain
|
||||
logger.info("Main process final cleanup...")
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info(f"Closing environment in final cleanup...")
|
||||
env.close()
|
||||
logger.info(f"Environment closed successfully in final cleanup")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during final environment cleanup: {e}")
|
||||
|
||||
# First try gentle termination
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
logger.info(f"Terminating process {p.name}...")
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error terminating process: {e}")
|
||||
|
||||
# Wait a moment for processes to terminate
|
||||
time.sleep(1)
|
||||
|
||||
# Then force kill if needed
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
logger.info(f"Force killing process {p.name}...")
|
||||
os.kill(p.pid, signal.SIGKILL)
|
||||
logger.info(f"Process {p.name} force killed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error force killing process: {e}")
|
||||
|
|
@ -8,31 +8,13 @@ import sys
|
|||
import signal
|
||||
import time
|
||||
from typing import List, Dict
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Process, Manager
|
||||
from multiprocessing import current_process
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.uitars15_agent import UITarsAgent
|
||||
|
||||
import shutil
|
||||
from mm_agents.uitars15_v2 import UITarsAgent
|
||||
import os
|
||||
|
||||
# def clear_cache():
|
||||
# cache_path = "cache"
|
||||
|
||||
# try:
|
||||
# if os.path.exists(cache_path):
|
||||
# logger.info(f"Deleting cache directory: {cache_path}")
|
||||
# shutil.rmtree(cache_path)
|
||||
# logger.info(f"Cache directory deleted successfully")
|
||||
# else:
|
||||
# logger.info(f"Cache directory {cache_path} does not exist")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error deleting cache directory: {e}")
|
||||
|
||||
# clear_cache()
|
||||
|
||||
# Global variables for signal handling
|
||||
active_environments = []
|
||||
|
|
@ -74,12 +56,12 @@ def config() -> argparse.Namespace:
|
|||
|
||||
# lm config
|
||||
parser.add_argument("--model", type=str, default="doubao-1-5-thinking-vision-pro-250428")
|
||||
parser.add_argument("--model_type", type=str, default="doubao", choices=["doubao", "qwen25"])
|
||||
parser.add_argument("--temperature", type=float, default=0)
|
||||
parser.add_argument("--top_p", type=float, default=None)
|
||||
parser.add_argument("--max_tokens", type=int, default=3000)
|
||||
parser.add_argument("--use_thinking", action="store_true", default=False)
|
||||
|
||||
# OpenCUAagent config
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
|
||||
parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.")
|
||||
parser.add_argument("--language", type=str, default="Chinese", help="Language for the agent.")
|
||||
|
|
@ -204,6 +186,7 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
|||
active_environments.append(env)
|
||||
agent = UITarsAgent(
|
||||
model=args.model,
|
||||
model_type=args.model_type,
|
||||
max_tokens=args.max_tokens,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
362
run_uitars.py
362
run_uitars.py
|
|
@ -1,362 +0,0 @@
|
|||
"""Script to run end-to-end evaluation on the benchmark.
|
||||
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.uitars_agent import UITARSAgent
|
||||
|
||||
# import wandb
|
||||
|
||||
|
||||
# Logger Configs {{{ #
# Root logger fans out to four handlers: an INFO file log, a DEBUG file log,
# INFO on stdout, and a DEBUG file restricted to "desktopenv" records.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Timestamp shared by all log filenames for this run.
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

# ANSI-colored format: timestamp, level, module/line-process, message.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

# Only "desktopenv.*" records reach stdout and the sdebug file.
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

# Module-level logger used by the rest of this script.
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def config() -> argparse.Namespace:
    """Parse command-line arguments for the benchmark run.

    Reads sys.argv and returns the populated namespace. Groups: desktop
    environment setup, agent limits, language-model sampling parameters,
    example selection, and result logging.
    """
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="a11y_tree",
        help="Observation type",
    )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # agent config
    parser.add_argument("--max_trajectory_length", type=int, default=3)
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--model", type=str, default="uitars")
    parser.add_argument("--model_type", type=str, default="qwen25vl")
    parser.add_argument("--infer_mode", type=str, default="qwen25vl_normal")
    parser.add_argument("--prompt_style", type=str, default="qwen25vl_normal")
    parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content")
    parser.add_argument("--language", type=str, default="Chinese")
    # Pixel budgets expressed in multiples of 28x28 vision patches.
    parser.add_argument("--max_pixels", type=float, default=16384*28*28)
    parser.add_argument("--min_pixels", type=float, default=100*28*28)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--top_k", type=int, default=-1)
    parser.add_argument("--history_n", type=int, default=5)
    parser.add_argument("--callusr_tolerance", type=int, default=3)
    parser.add_argument("--max_tokens", type=int, default=500)
    parser.add_argument("--stop_token", type=str, default=None)

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    args = parser.parse_args()

    return args
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    """Run every example in ``test_all_meta`` sequentially in one environment.

    Builds a single UITARSAgent and DesktopEnv, then iterates domains and
    example ids, delegating each run to ``lib_run_single.run_single_example``.
    Per-example scores are appended to ``scores``; failures are logged, the
    screen recording is finalized, and an error line is appended to traj.jsonl.
    """
    scores = []
    max_steps = args.max_steps

    # log args
    logger.info("Args: %s", args)
    # set wandb project
    # cfg_args mirrors the CLI settings; wandb logging itself is commented out
    # below, so this dict is currently only mutated, never uploaded.
    cfg_args = {
        "path_to_vm": args.path_to_vm,
        "headless": args.headless,
        "action_space": args.action_space,
        "observation_type": args.observation_type,
        "screen_width": args.screen_width,
        "screen_height": args.screen_height,
        "sleep_after_execution": args.sleep_after_execution,
        "max_steps": args.max_steps,
        "max_trajectory_length": args.max_trajectory_length,
        "model": args.model,
        "model_type": args.model_type,
        "infer_mode": args.infer_mode,
        "prompt_style": args.prompt_style,
        "input_swap": args.input_swap,
        "language": args.language,
        "history_n": args.history_n,
        "max_pixels": args.max_pixels,
        "min_pixels": args.min_pixels,
        "callusr_tolerance": args.callusr_tolerance,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "max_tokens": args.max_tokens,
        "stop_token": args.stop_token,
        "result_dir": args.result_dir,
    }

    # One agent instance is reused for all examples; sampling and prompting
    # options are passed through runtime_conf.
    agent = UITARSAgent(
        action_space=args.action_space,
        observation_type=args.observation_type,
        max_trajectory_length=args.max_trajectory_length,
        model_type=args.model_type,
        runtime_conf = {
            "infer_mode": args.infer_mode,
            "prompt_style": args.prompt_style,
            "input_swap": args.input_swap,
            "language": args.language,
            "history_n": args.history_n,
            "max_pixels": args.max_pixels,
            "min_pixels": args.min_pixels,
            "callusr_tolerance": args.callusr_tolerance,
            "temperature": args.temperature,
            "top_p": args.top_p,
            "top_k": args.top_k,
            "max_tokens": args.max_tokens
        }
    )

    # One desktop VM environment, also reused across all examples.
    env = DesktopEnv(
        path_to_vm=args.path_to_vm,
        action_space=agent.action_space,
        screen_size=(args.screen_width, args.screen_height),
        headless=args.headless,
        os_type = "Ubuntu",
        require_a11y_tree=args.observation_type
        in ["a11y_tree", "screenshot_a11y_tree", "som"],
    )

    for domain in tqdm(test_all_meta, desc="Domain"):
        for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
            # Each example's task definition lives in its own JSON config.
            config_file = os.path.join(
                args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
            )
            with open(config_file, "r", encoding="utf-8") as f:
                example = json.load(f)

            logger.info(f"[Domain]: {domain}")
            logger.info(f"[Example ID]: {example_id}")

            instruction = example["instruction"]

            logger.info(f"[Instruction]: {instruction}")
            # wandb each example config settings
            cfg_args["instruction"] = instruction
            cfg_args["start_time"] = datetime.datetime.now().strftime(
                "%Y:%m:%d-%H:%M:%S"
            )
            # run.config.update(cfg_args)

            example_result_dir = os.path.join(
                args.result_dir,
                args.action_space,
                args.observation_type,
                args.model,
                domain,
                example_id,
            )
            os.makedirs(example_result_dir, exist_ok=True)
            # example start running
            try:
                lib_run_single.run_single_example(
                    agent,
                    env,
                    example,
                    max_steps,
                    instruction,
                    args,
                    example_result_dir,
                    scores,
                )
            except Exception as e:
                # NOTE(review): every exception is reported with a
                # "Time limit exceeded" message below, even when the cause
                # was something else — the logged `e` is the ground truth.
                logger.error(f"Exception in {domain}/{example_id}: {e}")
                env.controller.end_recording(
                    os.path.join(example_result_dir, "recording.mp4")
                )
                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                    f.write(
                        json.dumps(
                            {"Error": f"Time limit exceeded in {domain}/{example_id}"}
                        )
                    )
                    f.write("\n")

    env.close()
    # NOTE(review): raises ZeroDivisionError when no example ran (empty scores).
    logger.info(f"Average score: {sum(scores) / len(scores)}")
|
||||
|
||||
|
||||
def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    """Filter ``total_file_json`` down to examples that still need to run.

    An example counts as finished when its directory under
    ``result_dir/action_space/observation_type/use_model`` contains
    ``result.txt``. Directories without it are treated as aborted runs and
    their files are deleted so the example restarts cleanly. The special
    "onboard" entry is ignored. The mapping is modified in place and returned.
    """
    run_root = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(run_root):
        return total_file_json

    completed = {}
    for domain_name in os.listdir(run_root):
        completed[domain_name] = []
        domain_dir = os.path.join(run_root, domain_name)
        if not os.path.isdir(domain_dir):
            continue
        for task_id in os.listdir(domain_dir):
            if task_id == "onboard":
                continue
            task_dir = os.path.join(domain_dir, task_id)
            if not os.path.isdir(task_dir):
                continue
            if "result.txt" in os.listdir(task_dir):
                completed[domain_name].append(task_id)
            else:
                # Partial run: wipe its artifacts so the task is retried.
                for leftover in os.listdir(task_dir):
                    os.remove(os.path.join(task_dir, leftover))

    if not completed:
        return total_file_json

    for domain_name, done_ids in completed.items():
        if domain_name in total_file_json:
            remaining = [t for t in total_file_json[domain_name] if t not in done_ids]
            total_file_json[domain_name] = remaining

    return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    """Collect per-example scores already on disk and print a success-rate summary.

    Walks ``result_dir/action_space/observation_type/use_model/<domain>/<id>/``
    and parses each ``result.txt`` as a float. Unreadable or unparseable files
    count as 0.0, matching the original best-effort behavior.

    Returns:
        list[float] of collected scores, or None when no run/result exists yet.
        (``total_file_json`` is accepted for signature parity but unused.)
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in os.listdir(domain_path):
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue
            if "result.txt" in os.listdir(example_path):
                # `with` closes the handle (the original leaked it), and the
                # narrowed except replaces the original bare `except:` so only
                # I/O and float-parse failures are scored as 0.0.
                try:
                    with open(os.path.join(example_path, "result.txt"), "r") as f:
                        all_result.append(float(f.read()))
                except (OSError, ValueError):
                    all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
    return all_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### The complete version of the list of examples #######
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
args = config()
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
Loading…
Reference in New Issue