Compare commits

...

6 Commits

Author SHA1 Message Date
XinyuanWangCS bfddbcff27 Merge branch 'main' into uitars/dev 2025-07-30 14:03:43 +00:00
XinyuanWangCS b5a57c2d58 merge main 2025-07-27 05:48:15 +00:00
XinyuanWangCS c3f3329fc2 add all the uitars agents:
1. run_multienv_uitars.py: Qwen2VL-based UITARS models
2. run_multienv_uitars15_v1.py: UITARS1.5-7B
3. run_multienv_uitars15_v2.py: SeedVL1.5 thinking/non-thinking
2025-07-27 05:42:57 +00:00
Jiaqi 826c0ef945 Merge branch 'main' into jq/dev 2025-07-24 08:13:07 +00:00
Jiaqi 80b80617c4 os task fix: set the default dim screen time to be 300s 2025-07-24 08:12:45 +00:00
Jiaqi 0f1ef6d9b7 use aws pub ip 2025-07-24 06:04:01 +00:00
9 changed files with 2155 additions and 643 deletions

View File

@@ -77,7 +77,8 @@ class AWSProvider(Provider):
else:
logger.warning("No public IP address available for VNC access")
return private_ip_address
#return private_ip_address
return public_ip_address
return '' # Return an empty string if no IP address is found
except ClientError as e:
logger.error(f"Failed to retrieve IP address for the instance {path_to_vm}: {str(e)}")

View File

@@ -44,6 +44,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
"step_num": step_idx + 1,
"action_timestamp": action_timestamp,
"action": action,
"response": response,
"reward": reward,
"done": done,
"info": info,

956
mm_agents/uitars15_v1.py Normal file
View File

@@ -0,0 +1,956 @@
import ast
import base64
from openai import OpenAI
import math
import re
import xml.etree.ElementTree as ET
from io import BytesIO
from typing import Dict, List
import numpy as np
from loguru import logger
import os
from PIL import Image
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import (
filter_nodes,
)
UITARS_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
"""
UITARS_CALL_USR_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
"""
UITARS_NORMAL_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
"""
UITARS_USR_PROMPT_NOTHOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## User Instruction
{instruction}
"""
UITARS_USR_PROMPT_THOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
{action_space}
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
# Parse a single action string into its function name and keyword arguments
def parse_action(action_str):
try:
# Parse the string into an AST node
node = ast.parse(action_str, mode='eval')
# Make sure the node is an expression
if not isinstance(node, ast.Expression):
raise ValueError("Not an expression")
# Get the body of the expression
call = node.body
# Make sure the body is a function call
if not isinstance(call, ast.Call):
raise ValueError("Not a function call")
# Get the function name
if isinstance(call.func, ast.Name):
func_name = call.func.id
elif isinstance(call.func, ast.Attribute):
func_name = call.func.attr
else:
func_name = None
# Collect the keyword arguments
kwargs = {}
for kw in call.keywords:
key = kw.arg
# Handle values of different types; constants are assumed here
if isinstance(kw.value, ast.Constant):
value = kw.value.value
elif isinstance(kw.value, ast.Str): # For compatibility with older Python versions
value = kw.value.s
else:
value = None
kwargs[key] = value
return {
'function': func_name,
'args': kwargs
}
except Exception as e:
print(f"Failed to parse action '{action_str}': {e}")
return None
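# Illustrative example (not part of the diff) of the structure parse_action returns:
# >>> parse_action("click(start_box='(500,300)')")
# {'function': 'click', 'args': {'start_box': '(500,300)'}}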
def escape_single_quotes(text):
# Match single quotes that are not already escaped (i.e., not preceded by a backslash)
pattern = r"(?<!\\)'"
return re.sub(pattern, r"\\'", text)
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def linear_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
if width * height > max_pixels:
"""
If the image is above (or below) the pixel limits, compute a scaling factor (resize_factor) that shrinks the pixel count to at most max_pixels. The factor is obtained via a square root so the aspect ratio is preserved, which lets the original relative coordinates be reused without conversion.
"""
resize_factor = math.sqrt(max_pixels / (width * height))
width, height = int(width * resize_factor), int(height * resize_factor)
if width * height < min_pixels:
resize_factor = math.sqrt(min_pixels / (width * height))
width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
return height, width
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
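# Quick sanity check (illustrative, not part of the diff) with the defaults IMAGE_FACTOR=28, MIN_PIXELS=100*28*28, MAX_PIXELS=16384*28*28:
# >>> smart_resize(1080, 1920)
# (1092, 1932) # both dimensions rounded to multiples of 28, within the pixel budget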
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
text = text.strip()
if model_type == "qwen25vl":
smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
# Regex patterns for matching the Action string
if text.startswith("Thought:"):
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
thought_hint = "Thought: "
elif text.startswith("Reflection:"):
thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
thought_hint = "Reflection: "
elif text.startswith("Action_Summary:"):
thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
thought_hint = "Action_Summary: "
else:
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
thought_hint = "Thought: "
reflection, thought = None, None
thought_match = re.search(thought_pattern, text, re.DOTALL)
if thought_match:
if len(thought_match.groups()) == 1:
thought = thought_match.group(1).strip()
elif len(thought_match.groups()) == 2:
thought = thought_match.group(2).strip()
reflection = thought_match.group(1).strip()
assert "Action:" in text
action_str = text.split("Action:")[-1]
tmp_all_action = action_str.split("\n\n")
all_action = []
for action_str in tmp_all_action:
if "type(content" in action_str:
# Use a regex to match the string inside content and escape its single quotes
def escape_quotes(match):
content = match.group(1) # Extract the value of content
return content
# Apply the regex substitution
pattern = r"type\(content='(.*?)'\)" # 匹配 type(content='...')
content = re.sub(pattern, escape_quotes, action_str)
# Process the extracted string
action_str = escape_single_quotes(content)
action_str = "type(content='" + action_str + "')"
all_action.append(action_str)
parsed_actions = [parse_action(action.replace("\n","\\n").lstrip()) for action in all_action]
actions = []
for action_instance, raw_str in zip(parsed_actions, all_action):
if action_instance is None:
print(f"Action can't parse: {raw_str}")
raise ValueError(f"Action can't parse: {raw_str}")
action_type = action_instance["function"]
params = action_instance["args"]
# import pdb; pdb.set_trace()
action_inputs = {}
for param_name, param in params.items():
if param == "": continue
param = param.lstrip() # Strip leading whitespace
# Handle the start_box / end_box parameter format '<bbox>x1 y1 x2 y2</bbox>'
action_inputs[param_name.strip()] = param
if "start_box" in param_name or "end_box" in param_name:
ori_box = param
# Remove parentheses and split the string by commas
numbers = ori_box.replace("(", "").replace(")", "").split(",")
# Convert to float and scale by 1000
# Qwen2.5vl output absolute coordinates, qwen2vl output relative coordinates
if model_type == "qwen25vl":
float_numbers = []
for num_idx, num in enumerate(numbers):
num = float(num)
if (num_idx + 1) % 2 == 0:
float_numbers.append(float(num/smart_resize_height))
else:
float_numbers.append(float(num/smart_resize_width))
else:
float_numbers = [float(num) / factor for num in numbers]
if len(float_numbers) == 2:
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
action_inputs[param_name.strip()] = str(float_numbers)
# import pdb; pdb.set_trace()
actions.append({
"reflection": reflection,
"thought": thought,
"action_type": action_type,
"action_inputs": action_inputs,
"text": text
})
return actions
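# Illustrative end-to-end example (not part of the diff) for a qwen2vl-style response, where coordinates are relative and factor=1000:
# >>> parse_action_to_structure_output(
# ... "Thought: click the OK button\nAction: click(start_box='(500,300)')",
# ... 1000, 1080, 1920, "qwen2vl")
# [{'reflection': None, 'thought': 'click the OK button', 'action_type': 'click',
# 'action_inputs': {'start_box': '[0.5, 0.3, 0.5, 0.3]'},
# 'text': "Thought: click the OK button\nAction: click(start_box='(500,300)')"}]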
def parsing_response_to_pyautogui_code(responses, image_height: int, image_width: int, input_swap: bool = True) -> str:
'''
Parse the model output into OSWorld actions and generate a pyautogui code string.
Args:
responses: a dict (or list of dicts) containing the model output, structured like:
{
"action_type": "hotkey",
"action_inputs": {
"hotkey": "v ctrl",
"start_box": None,
"end_box": None
}
}
Returns:
The generated pyautogui code string.
'''
pyautogui_code = f"import pyautogui\nimport time\n"
if isinstance(responses, dict):
responses = [responses]
for response_id, response in enumerate(responses):
if "observation" in response:
observation = response["observation"]
else:
observation = ""
if "thought" in response:
thought = response["thought"]
else:
thought = ""
if response_id == 0:
pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
else:
pyautogui_code += f"\ntime.sleep(1)\n"
action_dict = response
action_type = action_dict.get("action_type")
action_inputs = action_dict.get("action_inputs", {})
if action_type == "hotkey":
# Parsing hotkey action
if "key" in action_inputs:
hotkey = action_inputs.get("key", "")
else:
hotkey = action_inputs.get("hotkey", "")
if hotkey == "arrowleft":
hotkey = "left"
elif hotkey == "arrowright":
hotkey = "right"
elif hotkey == "arrowup":
hotkey = "up"
elif hotkey == "arrowdown":
hotkey = "down"
if hotkey:
# Handle other hotkeys
keys = hotkey.split() # Split the keys by space
convert_keys = []
for key in keys:
if key == "space":
key = ' '
convert_keys.append(key)
pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"
elif action_type == "press":
# Parsing press action
if "key" in action_inputs:
key_to_press = action_inputs.get("key", "")
else:
key_to_press = action_inputs.get("press", "")
if hotkey == "arrowleft":
hotkey = "left"
elif hotkey == "arrowright":
hotkey = "right"
elif hotkey == "arrowup":
hotkey = "up"
elif hotkey == "arrowdown":
hotkey = "down"
elif hotkey == "space":
hotkey = " "
if key_to_press:
# Simulate pressing a single key
pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"
elif action_type == "keyup":
key_to_up = action_inputs.get("key", "")
pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"
elif action_type == "keydown":
key_to_down = action_inputs.get("key", "")
pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"
elif action_type == "type":
# Parsing typing action using clipboard
content = action_inputs.get("content", "")
content = escape_single_quotes(content)
stripped_content = content
if content.endswith("\n") or content.endswith("\\n"):
stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
if content:
if input_swap:
pyautogui_code += f"\nimport pyperclip"
pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
pyautogui_code += f"\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += f"\npyautogui.press('enter')"
else:
pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
pyautogui_code += f"\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += f"\npyautogui.press('enter')"
elif action_type in ["drag", "select"]:
# Parsing drag or select action based on start and end_boxes
start_box = action_inputs.get("start_box")
end_box = action_inputs.get("end_box")
if start_box and end_box:
x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2]
sx = round(float((x1 + x2) / 2) * image_width, 3)
sy = round(float((y1 + y2) / 2) * image_height, 3)
x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2]
ex = round(float((x1 + x2) / 2) * image_width, 3)
ey = round(float((y1 + y2) / 2) * image_height, 3)
pyautogui_code += (
f"\npyautogui.moveTo({sx}, {sy})\n"
f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
)
elif action_type == "scroll":
# Parsing scroll action
start_box = action_inputs.get("start_box")
if start_box:
x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2]
x = round(float((x1 + x2) / 2) * image_width, 3)
y = round(float((y1 + y2) / 2) * image_height, 3)
# # Click the target region first, then scroll
# pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
else:
x = None
y = None
direction = action_inputs.get("direction", "")
if x is None:
if "up" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(5)"
elif "down" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(-5)"
else:
if "up" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
elif "down" in direction.lower():
pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})"
elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
# Parsing mouse click actions
start_box = action_inputs.get("start_box")
start_box = str(start_box)
if start_box:
start_box = eval(start_box)
if len(start_box) == 4:
x1, y1, x2, y2 = start_box # Assuming box is in [x1, y1, x2, y2]
elif len(start_box) == 2:
x1, y1 = start_box
x2 = x1
y2 = y1
x = round(float((x1 + x2) / 2) * image_width, 3)
y = round(float((y1 + y2) / 2) * image_height, 3)
if action_type == "left_single" or action_type == "click":
pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
elif action_type == "left_double":
pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
elif action_type == "right_single":
pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
elif action_type == "hover":
pyautogui_code += f"\npyautogui.moveTo({x}, {y})"
elif action_type in ["finished"]:
pyautogui_code = f"DONE"
else:
pyautogui_code += f"\n# Unrecognized action type: {action_type}"
return pyautogui_code
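# Illustrative usage (not part of the diff): one parsed click on a 1920x1080 screenshot with input_swap disabled;
# the generated script ends with "pyautogui.click(960.0, 324.0, button='left')":
# >>> parsing_response_to_pyautogui_code(
# ... {"action_type": "click", "action_inputs": {"start_box": "[0.5, 0.3, 0.5, 0.3]"}},
# ... 1080, 1920, input_swap=False)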
def add_box_token(input_string):
# Step 1: Split the string into individual actions
if "Action: " in input_string and "start_box=" in input_string:
suffix = input_string.split("Action: ")[0] + "Action: "
actions = input_string.split("Action: ")[1:]
processed_actions = []
for action in actions:
action = action.strip()
# Step 2: Extract coordinates (start_box or end_box) using regex
coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)
updated_action = action # Start with the original action
for coord_type, x, y in coordinates:
# Convert x and y to integers
updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'")
processed_actions.append(updated_action)
# Step 5: Reconstruct the final string
final_string = suffix + "\n\n".join(processed_actions)
else:
final_string = input_string
return final_string
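# Illustrative example (not part of the diff): add_box_token re-wraps bare coordinates with the special box tokens expected in the message history:
# >>> add_box_token("Thought: ok\nAction: click(start_box='(500,300)')")
# "Thought: ok\nAction: click(start_box='<|box_start|>(500,300)<|box_end|>')"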
def pil_to_base64(image):
buffer = BytesIO()
image.save(buffer, format="PNG") # You can switch this to "JPEG" or another format
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
if platform == "ubuntu":
_attributes_ns = attributes_ns_ubuntu
_state_ns = state_ns_ubuntu
_component_ns = component_ns_ubuntu
_value_ns = value_ns_ubuntu
elif platform == "windows":
_attributes_ns = attributes_ns_windows
_state_ns = state_ns_windows
_component_ns = component_ns_windows
_value_ns = value_ns_windows
else:
raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
linearized_accessibility_tree = [
"tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"
]
# Linearize the accessibility tree nodes into a table format
for node in filtered_nodes:
if node.text:
text = (
node.text
if '"' not in node.text
else '"{:}"'.format(node.text.replace('"', '""'))
)
elif node.get("{{{:}}}class".format(class_ns_windows), "").endswith(
"EditWrapper"
) and node.get("{{{:}}}value".format(_value_ns)):
node_text = node.get("{{{:}}}value".format(_value_ns), "")
text = (
node_text
if '"' not in node_text
else '"{:}"'.format(node_text.replace('"', '""'))
)
else:
text = '""'
linearized_accessibility_tree.append(
"{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format(
node.tag,
node.get("name", ""),
text,
(
node.get("{{{:}}}class".format(_attributes_ns), "")
if platform == "ubuntu"
else node.get("{{{:}}}class".format(class_ns_windows), "")
),
node.get("{{{:}}}description".format(_attributes_ns), ""),
node.get("{{{:}}}screencoord".format(_component_ns), ""),
node.get("{{{:}}}size".format(_component_ns), ""),
)
)
return "\n".join(linearized_accessibility_tree)
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
# enc = tiktoken.encoding_for_model("gpt-4")
# tokens = enc.encode(linearized_accessibility_tree)
# if len(tokens) > max_tokens:
# linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
# linearized_accessibility_tree += "[...]\n"
return linearized_accessibility_tree
class UITARSAgent:
def __init__(
self,
model: str,
runtime_conf: Dict,
platform="ubuntu",
action_space="pyautogui",
observation_type="screenshot",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
max_trajectory_length=50,
a11y_tree_max_tokens=10000,
model_type="qwen25vl",
**kwargs
):
self.model = model
self.platform = platform
self.action_space = action_space
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.a11y_tree_max_tokens = a11y_tree_max_tokens
self.model_type = model_type
self.runtime_conf = runtime_conf
self.temperature = self.runtime_conf["temperature"]
self.top_k = self.runtime_conf["top_k"]
self.top_p = self.runtime_conf["top_p"]
self.max_tokens = self.runtime_conf["max_tokens"]
self.infer_mode = self.runtime_conf["infer_mode"]
self.prompt_style = self.runtime_conf["prompt_style"]
self.input_swap = self.runtime_conf["input_swap"]
self.language = self.runtime_conf["language"]
self.max_pixels = self.runtime_conf["max_pixels"]
self.min_pixels = self.runtime_conf["min_pixels"]
self.callusr_tolerance = self.runtime_conf["callusr_tolerance"]
self.vlm = OpenAI(
base_url=os.environ['DOUBAO_API_URL'],
api_key=os.environ['DOUBAO_API_KEY'],
)
self.thoughts = []
self.actions = []
self.observations = []
self.history_images = []
self.history_responses = []
self.prompt_action_space = UITARS_ACTION_SPACE
self.action_parse_res_factor = 1000
if self.infer_mode == "qwen2vl_user":
self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
elif self.infer_mode == "qwen25vl_normal":
self.prompt_action_space = UITARS_NORMAL_ACTION_SPACE
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
if self.prompt_style == "qwen2vl_user" or self.prompt_style == "qwen25vl_normal":
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
elif self.prompt_style == "qwen2vl_no_thought":
self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
if "history_n" in self.runtime_conf:
self.history_n = self.runtime_conf["history_n"]
else:
self.history_n = 5
self.cur_callusr_count = 0
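# Illustrative construction (not part of the diff); runtime_conf must supply every key read above.
# The model name here is hypothetical; the other values mirror defaults seen elsewhere in this diff:
# agent = UITARSAgent(
# model="uitars-1.5-7b", # hypothetical model name
# runtime_conf={
# "temperature": 0.0, "top_k": -1, "top_p": 0.9, "max_tokens": 500,
# "infer_mode": "qwen25vl_normal", "prompt_style": "qwen25vl_normal",
# "input_swap": True, "language": "English",
# "max_pixels": 16384*28*28, "min_pixels": 100*28*28,
# "callusr_tolerance": 3,
# },
# )
# (DOUBAO_API_URL and DOUBAO_API_KEY must also be set in the environment for the OpenAI client above.)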
def reset(self, runtime_logger=None):
self.thoughts = []
self.actions = []
self.observations = []
self.history_images = []
self.history_responses = []
def predict(
self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
) -> List:
"""
Predict the next action(s) based on the current observation.
"""
# Append trajectory
# print(len(self.observations), len(self.actions), len(self.actions))
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
self.thoughts
), "The number of observations and actions should be the same."
if len(self.observations) > self.max_trajectory_length:
if self.max_trajectory_length == 0:
_observations = []
_actions = []
_thoughts = []
else:
_observations = self.observations[-self.max_trajectory_length :]
_actions = self.actions[-self.max_trajectory_length :]
_thoughts = self.thoughts[-self.max_trajectory_length :]
else:
_observations = self.observations
_actions = self.actions
_thoughts = self.thoughts
self.history_images.append(obs["screenshot"])
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = obs["screenshot"]
try:
linearized_accessibility_tree = (
linearize_accessibility_tree(
accessibility_tree=obs["accessibility_tree"],
platform=self.platform,
)
if self.observation_type == "screenshot_a11y_tree"
else None
)
except:
linearized_accessibility_tree = None
# logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
if linearized_accessibility_tree:
linearized_accessibility_tree = trim_accessibility_tree(
linearized_accessibility_tree, self.a11y_tree_max_tokens
)
if self.observation_type == "screenshot_a11y_tree":
self.observations.append(
{
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree,
}
)
else:
self.observations.append(
{"screenshot": base64_image, "accessibility_tree": None}
)
else:
raise ValueError(
"Invalid observation_type type: " + self.observation_type
) # 1}}}
if self.infer_mode == "qwen2vl_user" or self.infer_mode == "qwen25vl_normal":
user_prompt = self.prompt_template.format(
instruction=instruction,
action_space=self.prompt_action_space,
language=self.language
)
elif self.infer_mode == "qwen2vl_no_thought":
user_prompt = self.prompt_template.format(
instruction=instruction
)
if len(self.history_images) > self.history_n:
self.history_images = self.history_images[-self.history_n:]
messages, images = [], []
if isinstance(self.history_images, bytes):
self.history_images = [self.history_images]
elif isinstance(self.history_images, np.ndarray):
self.history_images = list(self.history_images)
elif isinstance(self.history_images, list):
pass
else:
raise TypeError(f"Unidentified images type: {type(self.history_images)}")
for turn, image in enumerate(self.history_images):
if len(images) >= self.history_n:
break
try:
image = Image.open(BytesIO(image))
except Exception as e:
raise RuntimeError(f"Error opening image: {e}")
if image.width * image.height > self.max_pixels:
"""
If the image is above (or below) the pixel limits, compute a scaling factor (resize_factor) that shrinks the pixel count to at most max_pixels. The factor is obtained via a square root so the aspect ratio is preserved, which lets the original relative coordinates be reused without conversion.
"""
resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height))
if image.width * image.height < self.min_pixels:
resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
width, height = math.ceil(image.width * resize_factor), math.ceil(image.height * resize_factor)
image = image.resize((width, height))
if image.mode != "RGB":
image = image.convert("RGB")
images.append(image)
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}]
},
{
"role": "user",
"content": [{"type": "text", "text": user_prompt}]
}
]
image_num = 0
if len(self.history_responses) > 0:
for history_idx, history_response in enumerate(self.history_responses):
# send at most history_n images to the model
if history_idx + self.history_n > len(self.history_responses):
cur_image = images[image_num]
encoded_string = pil_to_base64(cur_image)
messages.append({
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}]
})
image_num += 1
messages.append({
"role": "assistant",
"content": [{"type": "text", "text": add_box_token(history_response)}]
})
cur_image = images[image_num]
encoded_string = pil_to_base64(cur_image)
messages.append({
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}]
})
image_num += 1
else:
cur_image = images[image_num]
encoded_string = pil_to_base64(cur_image)
messages.append({
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}]
})
image_num += 1
try_times = 3
origin_resized_height = images[-1].height
origin_resized_width = images[-1].width
temperature = self.temperature
top_k = self.top_k
while True:
if try_times <= 0:
print(f"Reach max retry times to fetch response from client, as error flag.")
return "client error", ["DONE"]
try:
response = self.vlm.chat.completions.create(
model=self.model,
messages=messages,
frequency_penalty=1,
max_tokens=self.max_tokens,
temperature=temperature,
top_p=self.top_p
)
print("*" * 20)
print("Response:")
print(response.choices[0].message.content)
print("*" * 20)
prediction = response.choices[0].message.content.strip()
except Exception as e:
logger.exception(f"Error when fetching response from client: {e}")
prediction = None
try_times -= 1
try:
parsed_responses = parse_action_to_structure_output(
prediction,
self.action_parse_res_factor,
origin_resized_height,
origin_resized_width,
self.model_type,
self.max_pixels,
self.min_pixels
)
break
except Exception as e:
print(f"Error when parsing response from client: {e}")
# If fail to parse the model response, we use sampling parameters to avoid it
prediction = None
try_times -= 1
temperature = 1
top_k = -1
if prediction is None:
return "client error", ["DONE"]
self.history_responses.append(prediction)
self.thoughts.append(prediction)
try:
parsed_responses = parse_action_to_structure_output(
prediction,
self.action_parse_res_factor,
origin_resized_height,
origin_resized_width,
self.model_type,
self.max_pixels,
self.min_pixels
)
except Exception as e:
print(f"Parsing action error: {prediction}, with error:\n{e}")
return f"Parsing action error: {prediction}, with error:\n{e}", ["DONE"]
actions = []
last_image = Image.open(BytesIO(self.history_images[-1]))
obs_image_height = last_image.height
obs_image_width = last_image.width
for parsed_response in parsed_responses:
if "action_type" in parsed_response:
if parsed_response["action_type"] == FINISH_WORD:
self.actions.append(actions)
return prediction, ["DONE"]
elif parsed_response["action_type"] == WAIT_WORD:
self.actions.append(actions)
return prediction, ["WAIT"]
elif parsed_response["action_type"] == ENV_FAIL_WORD:
self.actions.append(actions)
return prediction, ["FAIL"]
elif parsed_response["action_type"] == CALL_USER:
if self.callusr_tolerance > self.cur_callusr_count:
self.actions.append(actions)
self.cur_callusr_count += 1
return prediction, ["WAIT"]
else:
self.actions.append(actions)
return prediction, ["FAIL"]
pyautogui_code = parsing_response_to_pyautogui_code(
parsed_response,
obs_image_height,
obs_image_width,
self.input_swap
)
actions.append(pyautogui_code)
self.actions.append(actions)
if len(self.history_responses) >= self.max_trajectory_length:
# Default to FAIL if the max number of steps is exceeded
actions = ["FAIL"]
return prediction, actions

View File

@@ -613,7 +613,7 @@ class UITarsAgent:
self,
# Model settings
model: str,
model_type: str,
# Generation settings
max_tokens: int,
top_p: Optional[float],
@@ -672,7 +672,7 @@ class UITarsAgent:
self.system_prompt = COMPUTER_USE_NO_THINKING
self.action_parse_res_factor = 1000
self.model_type = "doubao"
self.model_type = model_type
self.history_n = 5
self.top_p = top_p
self.temperature = temperature

View File

@@ -6,7 +6,7 @@ import re
import xml.etree.ElementTree as ET
from io import BytesIO
from typing import Dict, List
import os
import backoff
import numpy as np
from PIL import Image
@@ -28,22 +28,16 @@ from mm_agents.prompts import (
UITARS_CALL_USR_ACTION_SPACE,
UITARS_USR_PROMPT_NOTHOUGHT,
UITARS_USR_PROMPT_THOUGHT,
UITARS_NORMAL_ACTION_SPACE
)
logger = logging.getLogger("desktopenv.agent")
from loguru import logger
FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
pure_text_settings = ["a11y_tree"]
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
@@ -109,68 +103,8 @@ def escape_single_quotes(text):
pattern = r"(?<!\\)'"
return re.sub(pattern, r"\\'", text)
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def linear_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
if width * height > max_pixels:
"""
If the image is above (or below) the pixel limits, compute a scaling factor (resize_factor) that shrinks the pixel count to at most max_pixels. The factor is obtained via a square root so the aspect ratio is preserved, which lets the original relative coordinates be reused without conversion.
"""
resize_factor = math.sqrt(max_pixels / (width * height))
width, height = int(width * resize_factor), int(height * resize_factor)
if width * height < min_pixels:
resize_factor = math.sqrt(min_pixels / (width * height))
width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
return height, width
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
def parse_action_qwen2vl(text, factor, image_height, image_width):
text = text.strip()
if model_type == "qwen25vl":
smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
# Regex patterns for matching the Action string
if text.startswith("Thought:"):
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
@@ -182,8 +116,10 @@ def parse_action_to_structure_output(text, factor, origin_resized_height, origin
thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
thought_hint = "Action_Summary: "
else:
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
thought_hint = "Thought: "
# Fix: when there is no explicit "Thought:" marker, treat everything before "Action:" as the thought
thought_pattern = r"(.+?)(?=\s*Action:|$)"
thought_hint = ""
reflection, thought = None, None
thought_match = re.search(thought_pattern, text, re.DOTALL)
if thought_match:
@@ -218,7 +154,7 @@ def parse_action_to_structure_output(text, factor, origin_resized_height, origin
for action_instance, raw_str in zip(parsed_actions, all_action):
if action_instance == None:
print(f"Action can't parse: {raw_str}")
raise ValueError(f"Action can't parse: {raw_str}")
continue
action_type = action_instance["function"]
params = action_instance["args"]
@@ -236,18 +172,7 @@ def parse_action_to_structure_output(text, factor, origin_resized_height, origin
numbers = ori_box.replace("(", "").replace(")", "").split(",")
# Convert to float and scale by 1000
# Qwen2.5vl output absolute coordinates, qwen2vl output relative coordinates
if model_type == "qwen25vl":
float_numbers = []
for num_idx, num in enumerate(numbers):
num = float(num)
if (num_idx + 1) % 2 == 0:
float_numbers.append(float(num/smart_resize_height))
else:
float_numbers.append(float(num/smart_resize_width))
else:
float_numbers = [float(num) / factor for num in numbers]
float_numbers = [float(num) / factor for num in numbers]
if len(float_numbers) == 2:
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
action_inputs[param_name.strip()] = str(float_numbers)
@@ -296,7 +221,7 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
if response_id == 0:
pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
else:
pyautogui_code += f"\ntime.sleep(1)\n"
pyautogui_code += f"\ntime.sleep(3)\n"
action_dict = response
action_type = action_dict.get("action_type")
@@ -309,79 +234,25 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
else:
hotkey = action_inputs.get("hotkey", "")
if hotkey == "arrowleft":
hotkey = "left"
elif hotkey == "arrowright":
hotkey = "right"
elif hotkey == "arrowup":
hotkey = "up"
elif hotkey == "arrowdown":
hotkey = "down"
if hotkey:
# Handle other hotkeys
keys = hotkey.split() # Split the keys by space
convert_keys = []
for key in keys:
if key == "space":
key = ' '
convert_keys.append(key)
pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"
pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in keys])})"
elif action_type == "press":
# Parsing press action
if "key" in action_inputs:
key_to_press = action_inputs.get("key", "")
else:
key_to_press = action_inputs.get("press", "")
if hotkey == "arrowleft":
hotkey = "left"
elif hotkey == "arrowright":
hotkey = "right"
elif hotkey == "arrowup":
hotkey = "up"
elif hotkey == "arrowdown":
hotkey = "down"
elif hotkey == "space":
hotkey = " "
if key_to_press:
# Simulate pressing a single key
pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"
elif action_type == "keyup":
key_to_up = action_inputs.get("key", "")
pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"
elif action_type == "keydown":
key_to_down = action_inputs.get("key", "")
pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"
elif action_type == "type":
# Parsing typing action using clipboard
content = action_inputs.get("content", "")
content = escape_single_quotes(content)
stripped_content = content
if content.endswith("\n") or content.endswith("\\n"):
stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
if content:
if input_swap:
pyautogui_code += f"\nimport pyperclip"
pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
pyautogui_code += f"\npyperclip.copy('{content.strip()}')"
pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
pyautogui_code += f"\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += f"\npyautogui.press('enter')"
else:
pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
pyautogui_code += f"\npyautogui.write('{content.strip()}', interval=0.1)"
pyautogui_code += f"\ntime.sleep(0.5)\n"
if content.endswith("\n") or content.endswith("\\n"):
pyautogui_code += f"\npyautogui.press('enter')"
@@ -460,29 +331,6 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
return pyautogui_code
def add_box_token(input_string):
# Step 1: Split the string into individual actions
if "Action: " in input_string and "start_box=" in input_string:
suffix = input_string.split("Action: ")[0] + "Action: "
actions = input_string.split("Action: ")[1:]
processed_actions = []
for action in actions:
action = action.strip()
# Step 2: Extract coordinates (start_box or end_box) using regex
coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)
updated_action = action # Start with the original action
for coord_type, x, y in coordinates:
# Convert x and y to integers
updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'")
processed_actions.append(updated_action)
# Step 5: Reconstruct the final string
final_string = suffix + "\n\n".join(processed_actions)
else:
final_string = input_string
return final_string
def pil_to_base64(image):
buffer = BytesIO()
image.save(buffer, format="PNG") # You can switch this to "JPEG" or another format
@@ -558,51 +406,48 @@ def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
class UITARSAgent:
def __init__(
self,
model: str,
platform="ubuntu",
max_tokens=1000,
top_p=0.9,
top_k=1.0,
temperature=0.0,
action_space="pyautogui",
observation_type="screenshot",
observation_type="screenshot_a11y_tree",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
max_trajectory_length=50,
a11y_tree_max_tokens=10000,
model_type="qwen25vl",
runtime_conf: dict = {
"infer_mode": "qwen25vl_normal",
"prompt_style": "qwen25vl_normal",
"infer_mode": "qwen2vl_user",
"prompt_style": "qwen2vl_user",
"input_swap": True,
"language": "Chinese",
"max_steps": 50,
"history_n": 5,
"max_pixels": 16384*28*28,
"min_pixels": 100*28*28,
"callusr_tolerance": 3,
"temperature": 0.0,
"top_k": -1,
"top_p": 0.9,
"max_tokens": 500
"screen_height": 1080,
"screen_width": 1920
}
):
self.model = model
self.platform = platform
self.max_tokens = max_tokens
self.top_p = top_p
self.top_k = top_k
self.temperature = temperature
self.action_space = action_space
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.a11y_tree_max_tokens = a11y_tree_max_tokens
self.model_type = model_type
self.runtime_conf = runtime_conf
self.vlm = OpenAI(
base_url="http://127.0.0.1:8000/v1",
api_key="empty",
base_url=os.environ['DOUBAO_API_URL'],
api_key=os.environ['DOUBAO_API_KEY'],
) # should replace with your UI-TARS server api
self.temperature = self.runtime_conf["temperature"]
self.top_k = self.runtime_conf["top_k"]
self.top_p = self.runtime_conf["top_p"]
self.max_tokens = self.runtime_conf["max_tokens"]
self.infer_mode = self.runtime_conf["infer_mode"]
self.prompt_style = self.runtime_conf["prompt_style"]
self.input_swap = self.runtime_conf["input_swap"]
self.language = self.runtime_conf["language"]
self.max_pixels = self.runtime_conf["max_pixels"]
self.min_pixels = self.runtime_conf["min_pixels"]
self.callusr_tolerance = self.runtime_conf["callusr_tolerance"]
self.max_steps = max_trajectory_length
self.thoughts = []
self.actions = []
@@ -611,15 +456,14 @@ class UITARSAgent:
self.history_responses = []
self.prompt_action_space = UITARS_ACTION_SPACE
self.customize_action_parser = parse_action_qwen2vl
self.action_parse_res_factor = 1000
if self.infer_mode == "qwen2vl_user":
self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
elif self.infer_mode == "qwen25vl_normal":
self.prompt_action_space = UITARS_NORMAL_ACTION_SPACE
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
if self.prompt_style == "qwen2vl_user" or self.prompt_style == "qwen25vl_normal":
if self.prompt_style == "qwen2vl_user":
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
elif self.prompt_style == "qwen2vl_no_thought":
@@ -630,8 +474,6 @@ class UITARSAgent:
self.history_n = self.runtime_conf["history_n"]
else:
self.history_n = 5
self.cur_callusr_count = 0
def predict(
self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
@@ -660,18 +502,9 @@ class UITARSAgent:
_actions = self.actions
_thoughts = self.thoughts
for previous_obs, previous_action, previous_thought in zip(
_observations, _actions, _thoughts
):
# {{{1
if self.observation_type == "screenshot_a11y_tree":
_screenshot = previous_obs["screenshot"]
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
else:
raise ValueError(
"Invalid observation_type type: " + self.observation_type
) # 1}}}
if last_action_after_obs is not None and self.infer_mode == "double_image":
self.history_images.append(last_action_after_obs["screenshot"])
self.history_images.append(obs["screenshot"])
@@ -712,7 +545,7 @@ class UITARSAgent:
"Invalid observation_type type: " + self.observation_type
) # 1}}}
if self.infer_mode == "qwen2vl_user" or self.infer_mode == "qwen25vl_normal":
if self.infer_mode == "qwen2vl_user":
user_prompt = self.prompt_template.format(
instruction=instruction,
action_space=self.prompt_action_space,
@@ -726,6 +559,8 @@ class UITARSAgent:
if len(self.history_images) > self.history_n:
self.history_images = self.history_images[-self.history_n:]
max_pixels = 2116800
min_pixels = 3136
messages, images = [], []
if isinstance(self.history_images, bytes):
self.history_images = [self.history_images]
@@ -735,24 +570,28 @@
pass
else:
raise TypeError(f"Unidentified images type: {type(self.history_images)}")
max_image_nums_under_32k = int(32768*0.75/max_pixels*28*28)
if len(self.history_images) > max_image_nums_under_32k:
num_of_images = min(5, len(self.history_images))
max_pixels = int(32768*0.75) // num_of_images
for turn, image in enumerate(self.history_images):
if len(images) >= self.history_n:
if len(images) >= 5:
break
try:
image = Image.open(BytesIO(image))
except Exception as e:
raise RuntimeError(f"Error opening image: {e}")
if image.width * image.height > self.max_pixels:
if image.width * image.height > max_pixels:
"""
If the image is above (or below) the pixel limits, compute a scaling factor (resize_factor) that shrinks the pixel count to at most max_pixels. The factor is obtained via a square root so the aspect ratio is preserved, which lets the original relative coordinates be reused without conversion.
"""
resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height))
if image.width * image.height < self.min_pixels:
resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
if image.width * image.height < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width, height = math.ceil(image.width * resize_factor), math.ceil(image.height * resize_factor)
image = image.resize((width, height))
@@ -788,7 +627,7 @@ class UITARSAgent:
messages.append({
"role": "assistant",
"content": [add_box_token(history_response)]
"content": history_response
})
cur_image = images[image_num]
@@ -809,79 +648,59 @@ class UITARSAgent:
image_num += 1
try_times = 3
origin_resized_height = images[-1].height
origin_resized_width = images[-1].width
temperature = self.temperature
top_k = self.top_k
while True:
if try_times <= 0:
print(f"Reach max retry times to fetch response from client, as error flag.")
return "client error", ["DONE"], []
return "client error", ["DONE"]
try:
response = self.vlm.chat.completions.create(
model="ui-tars",
model=self.model,
messages=messages,
frequency_penalty=1,
max_tokens=self.max_tokens,
temperature=temperature,
temperature=self.temperature,
top_p=self.top_p
)
# print(response.choices[0].message.content)
prediction = response.choices[0].message.content.strip()
except Exception as e:
print(f"Error when fetching response from client, with response: {response}")
prediction = None
try_times -= 1
try:
parsed_responses = parse_action_to_structure_output(
print("Response:")
print(response.choices[0].message.content)
prediction = response.choices[0].message.content
parsed_responses = self.customize_action_parser(
prediction,
self.action_parse_res_factor,
origin_resized_height,
origin_resized_width,
self.model_type,
self.max_pixels,
self.min_pixels
self.runtime_conf["screen_height"],
self.runtime_conf["screen_width"]
)
break
except Exception as e:
print(f"Error when parsing response from client, with response: {response}")
# If fail to parse the model response, we use sampling parameters to avoid it
logger.exception(f"Error when fetching response from client, with response: {e}")
prediction = None
try_times -= 1
temperature = 1
top_k = -1
if prediction is None:
return "client error", ["DONE"]
self.history_responses.append(prediction)
self.thoughts.append(prediction)
try:
parsed_responses = parse_action_to_structure_output(
parsed_responses = self.customize_action_parser(
prediction,
self.action_parse_res_factor,
origin_resized_height,
origin_resized_width,
self.model_type,
self.max_pixels,
self.min_pixels
self.runtime_conf["screen_height"],
self.runtime_conf["screen_width"]
)
except Exception as e:
print(f"Parsing action error: {prediction}, with error:\n{e}")
return f"Parsing action error: {prediction}, with error:\n{e}", ["DONE"]
actions = []
last_image = Image.open(BytesIO(self.history_images[-1]))
obs_image_height = last_image.height
obs_image_width = last_image.width
for parsed_response in parsed_responses:
if "action_type" in parsed_response:
if parsed_response["action_type"] == FINISH_WORD:
self.actions.append(actions)
return prediction, ["DONE"]
elif parsed_response["action_type"] == WAIT_WORD:
@@ -893,18 +712,13 @@ class UITARSAgent:
return prediction, ["FAIL"]
elif parsed_response["action_type"] == CALL_USER:
if self.callusr_tolerance > self.cur_callusr_count:
self.actions.append(actions)
self.cur_callusr_count += 1
return prediction, ["WAIT"]
else:
self.actions.append(actions)
return prediction, ["FAIL"]
self.actions.append(actions)
return prediction, ["FAIL"]
pyautogui_code = parsing_response_to_pyautogui_code(
parsed_response,
obs_image_height,
obs_image_width,
self.runtime_conf["screen_height"],
self.runtime_conf["screen_width"],
self.input_swap
)
actions.append(pyautogui_code)
@@ -917,7 +731,6 @@ class UITARSAgent:
return prediction, actions
@backoff.on_exception(
backoff.constant,
# here you should add more model exceptions as you want,
@@ -947,4 +760,4 @@ class UITARSAgent:
self.actions = []
self.observations = []
self.history_images = []
self.history_responses = []
self.history_responses = []

539
run_multienv_uitars.py Normal file
View File

@@ -0,0 +1,539 @@
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List
from multiprocessing import Process, Manager, Queue
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.uitars_agent import UITARSAgent
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ #
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
parser.add_argument("--max_steps", type=int, default=15)
# evaluation config
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config
parser.add_argument("--model", type=str, default="uitars-72b-dpo", help="Model name")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--stop_token", type=str, default=None)
parser.add_argument("--max_trajectory_length", type=int, default=3, help="The max number of trajectory steps.")
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO', help="Set the logging level")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
args = parser.parse_args()
return args
args = config() # Get command line arguments first
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
return all_tasks
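# Illustrative example (not part of the diff): the nested test_all_meta mapping is flattened into (domain, example_id) tuples for the shared task queue:
# >>> distribute_tasks({"chrome": ["a1", "b2"], "os": ["c3"]})
# [('chrome', 'a1'), ('chrome', 'b2'), ('os', 'c3')]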
def process_signal_handler(signum, frame, env_idx):
"""Signal handler for child processes to gracefully shut down their environments."""
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
# Get the active_environments from the caller's frame
local_vars = frame.f_locals
active_environments = local_vars.get('active_environments', [])
# Close environment in the current process context
for env in active_environments:
if env is not None:
try:
logger.info(f"Process {env_idx + 1} closing environment...")
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
try:
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
args.max_trajectory_length = args.max_steps
agent = UITARSAgent(
model=args.model,
max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
)
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
global is_terminating, active_environments, processes
# Avoid duplicate handling
if is_terminating:
return
is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
# Close all registered environments in the main process
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
# Send termination signal to all child processes first
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
# Allow a short time for processes to handle their own cleanup
time.sleep(1)
# Forcefully terminate any processes that didn't exit
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
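# Portability note: signal.SIGKILL and the os.kill() call above are POSIX-only;
# on Windows, multiprocessing.Process.kill() (Python 3.7+) is the closest
# equivalent.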
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = []
for i in range(num_envs):
p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
try:
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
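# The resume logic above keys entirely off the presence of result.txt. The
# directory layout it scans looks like this (names are illustrative):
#
#     results/<action_space>/<observation_type>/<model>/<domain>/<example_id>/
#         result.txt present -> example counted as finished and filtered out
#         result.txt missing -> every file under example_id is deleted so the
#                               example is re-run from scratch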
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# empty all files under example_id
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Register signal handlers for graceful termination
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
try:
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
# Signal handler will take care of cleanup
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
# Also trigger cleanup for unhandled exceptions
signal_handler(signal.SIGTERM, None)
finally:
# Final cleanup in case any environments or processes remain
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info(f"Closing environment in final cleanup...")
env.close()
logger.info(f"Environment closed successfully in final cleanup")
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
# First try gentle termination
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Terminating process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
# Wait a moment for processes to terminate
time.sleep(1)
# Then force kill if needed
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Force killing process {p.name}...")
os.kill(p.pid, signal.SIGKILL)
logger.info(f"Process {p.name} force killed")
except Exception as e:
logger.error(f"Error force killing process: {e}")

run_multienv_uitars15_v1.py Normal file
View File

@ -0,0 +1,581 @@
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List
from multiprocessing import Process, Manager, Queue
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.uitars15_v1 import UITARSAgent
import os
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ #
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
parser.add_argument("--max_steps", type=int, default=15)
# evaluation config
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config
parser.add_argument("--model", type=str, default="uitars15-7b")
parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen25vl", "qwen2vl"])
parser.add_argument("--infer_mode", type=str, default="qwen25vl_normal", choices=["qwen25vl_normal", "qwen2vl_user"])
parser.add_argument("--prompt_style", type=str, default="qwen25vl_normal")
parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content")
parser.add_argument("--language", type=str, default="Chinese")
parser.add_argument("--max_pixels", type=float, default=16384*28*28)
parser.add_argument("--min_pixels", type=float, default=100*28*28)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--top_k", type=int, default=-1)
parser.add_argument("--history_n", type=int, default=5)
parser.add_argument("--callusr_tolerance", type=int, default=3)
parser.add_argument("--max_tokens", type=int, default=500)
parser.add_argument("--stop_token", type=str, default=None)
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.")
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO', help="Set the logging level")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
args = parser.parse_args()
return args
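# A typical (illustrative) invocation of this runner; every flag below is
# defined in config() above, values are placeholders:
#
#     python run_multienv_uitars15_v1.py \
#         --model uitars15-7b --infer_mode qwen25vl_normal \
#         --num_envs 4 --max_steps 15 --headless \
#         --client_password <password> --result_dir ./results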
args = config() # Get command line arguments first
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
os.makedirs("logs", exist_ok=True)  # logging.FileHandler raises if logs/ is missing
file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
return all_tasks
def process_signal_handler(signum, frame, env_idx):
"""Signal handler for child processes to gracefully shut down their environments."""
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
# Get the active_environments from the caller's frame
local_vars = frame.f_locals
active_environments = local_vars.get('active_environments', [])
# Close environment in the current process context
for env in active_environments:
if env is not None:
try:
logger.info(f"Process {env_idx + 1} closing environment...")
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
try:
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
args.max_trajectory_length = args.max_steps
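        # NOTE: the assignment above discards any --max_trajectory_length given
        # on the CLI, and the runtime_conf dicts below are hardcoded, so flags
        # such as --language, --temperature, --top_p, --max_pixels and
        # --input_swap are effectively ignored here (see the sketch after this
        # if/elif block).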
if args.infer_mode == "qwen25vl_normal":
runtime_conf: dict = {
"infer_mode": "qwen25vl_normal",
"prompt_style": "qwen25vl_normal",
"input_swap": True,
"language": "Chinese",
"history_n": 5,
"max_pixels": 16384*28*28,
"min_pixels": 100*28*28,
"callusr_tolerance": 3,
"temperature": 0.0,
"top_k": -1,
"top_p": 0.9,
"max_tokens": 1000
}
elif args.infer_mode == "qwen2vl_user":
runtime_conf: dict = {
"infer_mode": "qwen2vl_user",
"prompt_style": "qwen2vl_user",
"input_swap": True,
"language": "Chinese",
"history_n": 5,
"max_pixels": 2116800,
"min_pixels": 3136,
"callusr_tolerance": 3,
"temperature": 0.0,
"top_k": -1,
"top_p": 0.9,
"max_tokens": 1000
}
else:
raise ValueError(f"Unknown infer_mode: {args.infer_mode}")
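        # A hypothetical alternative sketch (not the committed behavior): build
        # runtime_conf from the parsed flags so the CLI values take effect:
        #
        #     runtime_conf = {
        #         "infer_mode": args.infer_mode,
        #         "prompt_style": args.prompt_style,
        #         "input_swap": args.input_swap,
        #         "language": args.language,
        #         "history_n": args.history_n,
        #         "max_pixels": args.max_pixels,
        #         "min_pixels": args.min_pixels,
        #         "callusr_tolerance": args.callusr_tolerance,
        #         "temperature": args.temperature,
        #         "top_k": args.top_k,
        #         "top_p": args.top_p,
        #         "max_tokens": args.max_tokens,
        #     }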
agent = UITARSAgent(
model=args.model,
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
model_type=args.model_type,
            runtime_conf=runtime_conf,
)
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
global is_terminating, active_environments, processes
# Avoid duplicate handling
if is_terminating:
return
is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
# Close all registered environments in the main process
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
# Send termination signal to all child processes first
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
# Allow a short time for processes to handle their own cleanup
time.sleep(1)
# Forcefully terminate any processes that didn't exit
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager:
shared_scores = manager.list()
task_queue = manager.Queue()
for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = []
for i in range(num_envs):
p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
)
p.daemon = True
p.start()
processes.append(p)
logger.info(f"Started process {p.name} with PID {p.pid}")
try:
while True:
alive_count = 0
for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# empty all files under example_id
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Register signal handlers for graceful termination
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
try:
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
# Signal handler will take care of cleanup
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
# Also trigger cleanup for unhandled exceptions
signal_handler(signal.SIGTERM, None)
finally:
# Final cleanup in case any environments or processes remain
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info(f"Closing environment in final cleanup...")
env.close()
logger.info(f"Environment closed successfully in final cleanup")
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
# First try gentle termination
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Terminating process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
# Wait a moment for processes to terminate
time.sleep(1)
# Then force kill if needed
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Force killing process {p.name}...")
os.kill(p.pid, signal.SIGKILL)
logger.info(f"Process {p.name} force killed")
except Exception as e:
logger.error(f"Error force killing process: {e}")

run_multienv_uitars15_v2.py
View File

@ -8,31 +8,13 @@ import sys
import signal
import time
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.uitars15_agent import UITarsAgent
import shutil
from mm_agents.uitars15_v2 import UITarsAgent
import os
# def clear_cache():
# cache_path = "cache"
# try:
# if os.path.exists(cache_path):
# logger.info(f"Deleting cache directory: {cache_path}")
# shutil.rmtree(cache_path)
# logger.info(f"Cache directory deleted successfully")
# else:
# logger.info(f"Cache directory {cache_path} does not exist")
# except Exception as e:
# logger.error(f"Error deleting cache directory: {e}")
# clear_cache()
# Global variables for signal handling
active_environments = []
@ -74,12 +56,12 @@ def config() -> argparse.Namespace:
# lm config
parser.add_argument("--model", type=str, default="doubao-1-5-thinking-vision-pro-250428")
parser.add_argument("--model_type", type=str, default="doubao", choices=["doubao", "qwen25"])
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--max_tokens", type=int, default=3000)
parser.add_argument("--use_thinking", action="store_true", default=False)
# OpenCUAagent config
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.")
parser.add_argument("--language", type=str, default="Chinese", help="Language for the agent.")
@ -204,6 +186,7 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
active_environments.append(env)
agent = UITarsAgent(
model=args.model,
model_type=args.model_type,
max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,

View File

@ -1,362 +0,0 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import argparse
import datetime
import json
import logging
import os
import sys
from tqdm import tqdm
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.uitars_agent import UITARSAgent
# import wandb
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="a11y_tree",
help="Observation type",
)
parser.add_argument("--screen_width", type=int, default=1920)
parser.add_argument("--screen_height", type=int, default=1080)
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15)
# agent config
parser.add_argument("--max_trajectory_length", type=int, default=3)
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config
parser.add_argument("--model", type=str, default="uitars")
parser.add_argument("--model_type", type=str, default="qwen25vl")
parser.add_argument("--infer_mode", type=str, default="qwen25vl_normal")
parser.add_argument("--prompt_style", type=str, default="qwen25vl_normal")
parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content")
parser.add_argument("--language", type=str, default="Chinese")
parser.add_argument("--max_pixels", type=float, default=16384*28*28)
parser.add_argument("--min_pixels", type=float, default=100*28*28)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--top_k", type=int, default=-1)
parser.add_argument("--history_n", type=int, default=5)
parser.add_argument("--callusr_tolerance", type=int, default=3)
parser.add_argument("--max_tokens", type=int, default=500)
parser.add_argument("--stop_token", type=str, default=None)
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
args = parser.parse_args()
return args
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
scores = []
max_steps = args.max_steps
# log args
logger.info("Args: %s", args)
# set wandb project
cfg_args = {
"path_to_vm": args.path_to_vm,
"headless": args.headless,
"action_space": args.action_space,
"observation_type": args.observation_type,
"screen_width": args.screen_width,
"screen_height": args.screen_height,
"sleep_after_execution": args.sleep_after_execution,
"max_steps": args.max_steps,
"max_trajectory_length": args.max_trajectory_length,
"model": args.model,
"model_type": args.model_type,
"infer_mode": args.infer_mode,
"prompt_style": args.prompt_style,
"input_swap": args.input_swap,
"language": args.language,
"history_n": args.history_n,
"max_pixels": args.max_pixels,
"min_pixels": args.min_pixels,
"callusr_tolerance": args.callusr_tolerance,
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": args.top_k,
"max_tokens": args.max_tokens,
"stop_token": args.stop_token,
"result_dir": args.result_dir,
}
agent = UITARSAgent(
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
model_type=args.model_type,
runtime_conf = {
"infer_mode": args.infer_mode,
"prompt_style": args.prompt_style,
"input_swap": args.input_swap,
"language": args.language,
"history_n": args.history_n,
"max_pixels": args.max_pixels,
"min_pixels": args.min_pixels,
"callusr_tolerance": args.callusr_tolerance,
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": args.top_k,
"max_tokens": args.max_tokens
}
)
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=agent.action_space,
screen_size=(args.screen_width, args.screen_height),
headless=args.headless,
os_type = "Ubuntu",
require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"],
)
for domain in tqdm(test_all_meta, desc="Domain"):
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Domain]: {domain}")
logger.info(f"[Example ID]: {example_id}")
instruction = example["instruction"]
logger.info(f"[Instruction]: {instruction}")
# wandb each example config settings
cfg_args["instruction"] = instruction
cfg_args["start_time"] = datetime.datetime.now().strftime(
"%Y:%m:%d-%H:%M:%S"
)
# run.config.update(cfg_args)
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
# example start running
try:
lib_run_single.run_single_example(
agent,
env,
example,
max_steps,
instruction,
args,
example_result_dir,
scores,
)
except Exception as e:
logger.error(f"Exception in {domain}/{example_id}: {e}")
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
)
)
f.write("\n")
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# empty all files under example_id
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)