Compare commits: main...djlu/qwenv (1 commit, SHA1 0080d3cf08)

mm_agents/qwen25vl_agent.py
@@ -0,0 +1,582 @@
import base64
import json
import logging
import time
import os
from io import BytesIO
from typing import Dict, List, Tuple

import backoff
import openai
from PIL import Image
from requests.exceptions import SSLError
from google.api_core.exceptions import (
    InvalidArgument,
    ResourceExhausted,
    InternalServerError,
    BadRequest,
)
from mm_agents.utils.qwen_vl_utils import smart_resize


# Module-level logger; reset() can swap in an externally configured logger.
logger = logging.getLogger("desktopenv.qwen25vl_agent")

MAX_RETRY_TIMES = 5


def encode_image(image_content):
    return base64.b64encode(image_content).decode("utf-8")


def process_image(image_bytes):
    """
    Process an image for Qwen VL models.
    Resize the image to the dimensions expected by the model.

    Args:
        image_bytes: Raw image bytes

    Returns:
        Base64-encoded string of the processed image
    """
    # Open image from bytes
    image = Image.open(BytesIO(image_bytes))
    width, height = image.size

    # Calculate resized dimensions
    resized_height, resized_width = smart_resize(
        height=height,
        width=width
    )

    # Resize the image
    image = image.resize((resized_width, resized_height))

    # Convert to bytes
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    processed_bytes = buffer.getvalue()

    # Return base64-encoded string
    return base64.b64encode(processed_bytes).decode('utf-8')
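

# Illustrative numbers (hand-computed from the smart_resize defaults,
# factor=28 and max_pixels=14*14*4*1280): a 1920x1080 screenshot becomes
# 1316x728, so a model click at (658, 364) in processed-image space must be
# scaled back to (960, 540) on the real screen; parse_response() below does
# exactly this rescaling via adjust_coordinates().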


class Qwen25VLAgent:

    def __init__(
        self,
        platform="ubuntu",
        planner_model="gpt-4o",
        executor_model="qwen2.5vl",
        max_tokens=1500,
        top_p=0.9,
        temperature=0.5,
        action_space="pyautogui",
        observation_type="screenshot",
        history_n=4,  # Number of previous interactions to include in full detail
    ):
        self.platform = platform
        self.planner_model = planner_model
        self.executor_model = executor_model
        assert self.executor_model is not None, "Executor model cannot be None"
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.history_n = history_n  # Control how many previous interactions to include
        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.responses = []  # Store model responses
        self.screenshots = []  # Store processed screenshots

    def predict(self, instruction: str, obs: Dict) -> Tuple[str, List[str]]:
        """
        Predict the next action(s) based on the current observation.
        """
        # Process the screenshot image
        screenshot_bytes = obs["screenshot"]

        # Display original dimensions
        image = Image.open(BytesIO(screenshot_bytes))
        width, height = image.size
        print(f"Original screen resolution: {width}x{height}")

        # Process the image
        processed_image = process_image(screenshot_bytes)
        processed_img = Image.open(BytesIO(base64.b64decode(processed_image)))
        processed_width, processed_height = processed_img.size
        print(f"Processed image resolution: {processed_width}x{processed_height}")

        # Save the current screenshot to history
        self.screenshots.append(processed_image)

        # Calculate history window start index
        current_step = len(self.actions)
        history_start_idx = max(0, current_step - self.history_n)

        # Build the previous-actions string; only actions that fall outside the
        # full-detail history window are summarized here.
        previous_actions = []
        for i in range(history_start_idx):
            if i < len(self.actions):
                previous_actions.append(f"Step {i+1}: {self.actions[i]}")
        previous_actions_str = "\n".join(previous_actions) if previous_actions else "None"
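        # Example: with history_n=4 and six completed steps, steps 1-2 appear
        # above only as one-line summaries, while steps 3-6 are replayed later
        # in this method as full screenshot + response turns.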

        # System prompt with tool definition
        tools_def = {
            "type": "function",
            "function": {
                "name_for_human": "computer_use",
                "name": "computer_use",
                "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.",
                "parameters": {
                    "properties": {
                        "action": {
                            "description": "The action to perform.",
                            "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag",
                                     "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"],
                            "type": "string"
                        },
                        "keys": {"description": "Required only by `action=key`.", "type": "array"},
                        "text": {"description": "Required only by `action=type`.", "type": "string"},
                        "coordinate": {"description": "The x,y coordinates for mouse actions.", "type": "array"},
                        "pixels": {"description": "The amount of scrolling.", "type": "number"},
                        "time": {"description": "The seconds to wait.", "type": "number"},
                        "status": {
                            "description": "The status of the task.",
                            "type": "string",
                            "enum": ["success", "failure"]
                        }
                    },
                    "required": ["action"],
                    "type": "object"
                },
                "args_format": "Format the arguments as a JSON object."
            }
        }
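
        # An example tool call under this schema (illustrative; coordinates are
        # in processed-image space and are rescaled later by parse_response()):
        #   <tool_call>
        #   {"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [658, 364]}}
        #   </tool_call>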

        system_prompt = """You are a helpful assistant

# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
""" + json.dumps(tools_def) + """
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>"""

        # Create instruction prompt
        instruction_prompt = f"""
Please generate the next move according to the UI screenshot, instruction and previous actions.

Instruction: {instruction}

Previous actions:
{previous_actions_str}"""

        # Initialize messages with system prompt
        messages = [
            {
                "role": "system",
                "content": [{
                    "type": "text",
                    "text": system_prompt
                }]
            }
        ]

        # Add history responses and images within the history window
        history_len = min(self.history_n, len(self.responses))
        if history_len > 0:
            # Only include the most recent history_n steps
            history_responses = self.responses[-history_len:]
            history_screenshots = self.screenshots[-history_len - 1:-1]  # One extra slot: the screenshot that preceded each response

            # Add history in conversation format
            for idx in range(history_len):
                # Add the screenshot (user message)
                if idx < len(history_screenshots):
                    screenshot_b64 = history_screenshots[idx]

                    # If this is the first history item, include the instruction prompt
                    if idx == 0:
                        messages.append({
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/png;base64,{screenshot_b64}"
                                    }
                                },
                                {
                                    "type": "text",
                                    "text": instruction_prompt
                                }
                            ]
                        })
                    else:
                        messages.append({
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/png;base64,{screenshot_b64}"
                                    }
                                }
                            ]
                        })

                # Add the model's response for that step (assistant message)
                messages.append({
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": history_responses[idx]}
                    ]
                })

            # Add the current screenshot without the instruction (it is already in the history)
            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{processed_image}"
                        }
                    }
                ]
            })
        else:
            # If no history, just add the current screenshot with the instruction prompt
            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{processed_image}"
                        }
                    },
                    {
                        "type": "text",
                        "text": instruction_prompt
                    }
                ]
            })

        # append_text = f"""Step {current_step+1}: Thought:"""
        append_text = "Thought:"
        messages.append({"role": "assistant", "content": [{"type": "text", "text": append_text}]})

        # Call the LLM
        response = self.call_llm(
            {
                "model": self.executor_model,
                "messages": messages,
                "max_tokens": self.max_tokens,
                "top_p": self.top_p,
                "temperature": self.temperature,
            },
            self.executor_model,
        )

        logger.info(f"Qwen25VL Output: {response}")

        # Save response to history
        self.responses.append(response)

        # Parse response and extract pyautogui code
        low_level_instruction, pyautogui_code = self.parse_response(
            response,
            width,
            height,
            processed_width,
            processed_height
        )

        logger.info(f"Low level instruction: {low_level_instruction}")
        logger.info(f"Pyautogui code: {pyautogui_code}")

        # Add the action to history
        self.actions.append(low_level_instruction)

        return response, pyautogui_code

    def parse_response(self, response: str, original_width: int = None, original_height: int = None,
                       processed_width: int = None, processed_height: int = None) -> Tuple[str, List[str]]:
        """
        Parse the LLM response and convert it into a low-level action and pyautogui code.

        Args:
            response: Raw response string from the model
            original_width: Width of the original screenshot
            original_height: Height of the original screenshot
            processed_width: Width of the processed image
            processed_height: Height of the processed image

        Returns:
            Tuple of (low_level_instruction, list of pyautogui commands)
        """
        low_level_instruction = ""
        pyautogui_code = []

        if response is None or not response.strip():
            return low_level_instruction, pyautogui_code

        # Helper to adjust coordinates based on original and processed dimensions
        def adjust_coordinates(x: float, y: float) -> Tuple[int, int]:
            """
            Adjust coordinates from processed-image space to original-image space.
            """
            if all([original_width, original_height, processed_width, processed_height]):
                # Calculate the scale factors between original and processed images
                x_scale = original_width / processed_width
                y_scale = original_height / processed_height

                # Apply scaling to get coordinates in original image space
                adjusted_x = int(x * x_scale)
                adjusted_y = int(y * y_scale)

                return adjusted_x, adjusted_y
            else:
                # If any dimension is missing, return the original coordinates
                return int(x), int(y)

        # Inner helper to process tool calls
        def process_tool_call(json_str: str) -> None:
            """Process a single tool-call JSON string."""
            try:
                # Parse the JSON
                tool_call = json.loads(json_str)
                if tool_call.get("name") == "computer_use":
                    # Convert computer_use actions to pyautogui commands
                    args = tool_call["arguments"]
                    action = args["action"]

                    if action == "left_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.click({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.click()")

                    elif action == "right_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.rightClick({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.rightClick()")

                    elif action == "middle_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.middleClick({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.middleClick()")

                    elif action == "double_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.doubleClick({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.doubleClick()")

                    elif action == "type":
                        text = args.get("text", "")
                        # Use repr() so quotes inside the text do not break the generated command
                        pyautogui_code.append(f"pyautogui.typewrite({text!r})")

                    elif action == "key":
                        keys = args.get("keys", [])
                        # Fix possible formatting issues in the keys parameter
                        if isinstance(keys, list):
                            # Clean up any formatting issues in the keys
                            cleaned_keys = []
                            for key in keys:
                                # Check if the key has the "keys=[" prefix or "]" suffix
                                if isinstance(key, str):
                                    # Remove "keys=[" prefix if present
                                    if key.startswith("keys=["):
                                        key = key[6:]
                                    # Remove "]" suffix if present
                                    if key.endswith("]"):
                                        key = key[:-1]
                                    # Handle case where the string contains a representation of list items
                                    if key.startswith("['") or key.startswith("[\""):
                                        key = key[2:] if len(key) > 2 else key
                                    if key.endswith("']") or key.endswith("\"]"):
                                        key = key[:-2] if len(key) > 2 else key
                                    # Strip any extra whitespace
                                    key = key.strip()
                                    # Add to cleaned keys
                                    cleaned_keys.append(key)
                                else:
                                    cleaned_keys.append(key)
                            keys = cleaned_keys

                        # Format the keys for a hotkey or press command
                        keys_str = ", ".join([f"'{key}'" for key in keys])
                        if len(keys) > 1:
                            pyautogui_code.append(f"pyautogui.hotkey({keys_str})")
                        else:
                            pyautogui_code.append(f"pyautogui.press({keys_str})")

                    elif action == "scroll":
                        pixels = args.get("pixels", 0)
                        pyautogui_code.append(f"pyautogui.scroll({pixels})")

                    elif action == "wait":
                        pyautogui_code.append("WAIT")  # Special code for wait action

                    elif action == "terminate":
                        pyautogui_code.append("DONE")  # Special code for termination

                    elif action == "mouse_move":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.moveTo({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.moveTo(0, 0)")

                    elif action == "left_click_drag":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            duration = args.get("duration", 0.5)
                            pyautogui_code.append(f"pyautogui.dragTo({adj_x}, {adj_y}, duration={duration})")
                        else:
                            pyautogui_code.append("pyautogui.dragTo(0, 0)")
            except (json.JSONDecodeError, KeyError) as e:
                logger.error(f"Failed to parse tool call: {e}")

        # Parse the response line by line
        lines = response.split('\n')
        inside_tool_call = False
        current_tool_call = []

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Extract the low-level instruction from lines starting with "Action:" or similar
            if line.lower().startswith(("action:", "step", "i will", "i'll", "now i")):
                if not low_level_instruction:
                    # Only store the first action description as the low-level instruction
                    low_level_instruction = line
                continue

            # Handle lines inside tool-call markers
            if line.startswith("<tool_call>"):
                inside_tool_call = True
                continue
            elif line.startswith("</tool_call>"):
                if current_tool_call:
                    # Process the collected tool call
                    process_tool_call("\n".join(current_tool_call))
                    current_tool_call = []
                inside_tool_call = False
                continue

            if inside_tool_call:
                current_tool_call.append(line)
                continue

            # Try to parse individual lines as JSON
            if line.startswith("{") and line.endswith("}"):
                try:
                    json_obj = json.loads(line)
                    if "name" in json_obj and "arguments" in json_obj:
                        process_tool_call(line)
                except json.JSONDecodeError:
                    pass

        # Process any remaining tool-call content
        if current_tool_call:
            process_tool_call("\n".join(current_tool_call))

        # If we still don't have a low-level instruction, generate a default one
        if not low_level_instruction and len(pyautogui_code) > 0:
            first_cmd = pyautogui_code[0]
            # Sentinels like "WAIT"/"DONE" carry no "pyautogui." prefix
            action_type = first_cmd.split(".", 1)[1].split("(", 1)[0] if "." in first_cmd else first_cmd
            low_level_instruction = f"Performing {action_type} action"

        return low_level_instruction, pyautogui_code
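
    # Illustrative parse (assuming a 1920x1080 original and a 1316x728
    # processed image):
    #   self.parse_response('Action: click the File menu\n'
    #                       '<tool_call>\n'
    #                       '{"name": "computer_use", "arguments":'
    #                       ' {"action": "left_click", "coordinate": [658, 364]}}\n'
    #                       '</tool_call>', 1920, 1080, 1316, 728)
    #   -> ("Action: click the File menu", ["pyautogui.click(960, 540)"])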

    @backoff.on_exception(
        backoff.constant,
        # Add more model-specific exceptions here as needed, but never add the
        # generic `Exception`: generic failures must propagate to the caller so
        # that each example stays within its time limit.
        (
            # General exceptions
            SSLError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            # Google exceptions
            InvalidArgument,
            ResourceExhausted,
            InternalServerError,
            BadRequest,
            # Groq exceptions
            # todo: check
        ),
        interval=30,
        max_tries=10,
    )
    def call_llm(self, payload, model):
        messages = payload["messages"]
        base_url = "your_base_url"  # fill in your OpenAI-compatible endpoint
        api_key = "your_api_key"    # fill in your API key

        client = openai.OpenAI(
            base_url=base_url,
            api_key=api_key
        )

        for _ in range(MAX_RETRY_TIMES):
            logger.info("Generating content with Qwen model: %s", model)
            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    top_p=self.top_p
                )
                return response.choices[0].message.content
            except Exception as e:
                logger.error(f"Error calling Qwen model: {e}")
                time.sleep(5)
                continue
        return ""

    def reset(self, _logger=None):
        global logger
        logger = (_logger if _logger is not None else
                  logging.getLogger("desktopenv.qwen25vl_agent"))

        self.thoughts = []
        self.action_descriptions = []
        self.actions = []
        self.observations = []
        self.responses = []  # Reset responses
        self.screenshots = []  # Reset screenshots
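

# Minimal usage sketch (illustrative; `png_bytes` is a hypothetical raw PNG
# screenshot, which in OSWorld would come from DesktopEnv.reset()/step()):
#   agent = Qwen25VLAgent(executor_model="qwen2.5vl")
#   agent.reset()
#   response, codes = agent.predict("Open the terminal", {"screenshot": png_bytes})
#   # codes is e.g. ["pyautogui.click(960, 540)"], or the ["WAIT"] / ["DONE"] sentinels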


mm_agents/utils/qwen_vl_utils.py
@@ -0,0 +1,271 @@
import math


def round_by_factor(number: int, factor: int) -> int:
    """Return the integer closest to `number` that is divisible by `factor`."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest integer greater than or equal to `number` that is divisible by `factor`."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest integer less than or equal to `number` that is divisible by `factor`."""
    return math.floor(number / factor) * factor


def smart_resize(height, width, factor=28, min_pixels=56 * 56, max_pixels=14 * 14 * 4 * 1280, max_long_side=8192):
    """Rescale the image so that it satisfies the following conditions:
    1. both height and width are divisible by `factor`;
    2. the total pixel count stays within [min_pixels, max_pixels];
    3. the longest side is capped at `max_long_side`;
    4. the aspect ratio stays essentially unchanged.
    """
    if height < 2 or width < 2:
        raise ValueError(f"height:{height} and width:{width} must both be at least 2")
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(f"absolute aspect ratio must be smaller than 200, got {height} / {width}")

    if max(height, width) > max_long_side:
        beta = max(height, width) / max_long_side
        height, width = int(height / beta), int(width / beta)

    h_bar = round_by_factor(height, factor)
    w_bar = round_by_factor(width, factor)
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
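

# Worked example with the defaults (hand-checked): for a 1920x1080 input,
# rounding to the 28-grid gives 1932x1092, which exceeds max_pixels
# (14*14*4*1280 = 1,003,520), so both sides are scaled down by
# beta = sqrt(1920*1080 / max_pixels), roughly 1.437, then floored back onto
# the grid:
#   smart_resize(1080, 1920) == (728, 1316)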


def update_image_size_(image_ele: dict, min_tokens=1, max_tokens=12800, merge_base=2, patch_size=14):
    """Update the size information in `image_ele` according to min_tokens and max_tokens.

    Args:
        image_ele (dict):
            - image_ele["image"]: str, path to the image
            - image_ele["height"]: int, original image height
            - image_ele["width"]: int, original image width

    Returns:
        The updated image_ele, with the following new key-value pairs:
        dict:
            - image_ele["resized_height"]: int, actual height fed to the model
            - image_ele["resized_width"]: int, actual width fed to the model
            - image_ele["seq_len"]: int, sequence length the image occupies in the model
    """
    height, width = image_ele["height"], image_ele["width"]
    pixels_per_token = patch_size * patch_size * merge_base * merge_base
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=merge_base * patch_size,
        min_pixels=pixels_per_token * min_tokens,
        max_pixels=pixels_per_token * max_tokens,
        max_long_side=50000,
    )
    image_ele.update(
        {
            "resized_height": resized_height,
            "resized_width": resized_width,
            "seq_len": resized_height * resized_width // pixels_per_token + 2,
        }
    )
    return image_ele


def _convert_bbox_format_from_abs_origin(bbox, image_ele: dict, *, tgt_format: str):
    x1, y1, x2, y2 = bbox
    if tgt_format == "abs_origin":
        new_bbox = [int(x1), int(y1), int(x2), int(y2)]
    elif tgt_format == "abs_resized":
        new_bbox = [
            int(x1 / image_ele["width"] * image_ele["resized_width"]),
            int(y1 / image_ele["height"] * image_ele["resized_height"]),
            int(x2 / image_ele["width"] * image_ele["resized_width"]),
            int(y2 / image_ele["height"] * image_ele["resized_height"]),
        ]
    elif tgt_format == "qwen-vl":
        new_bbox = [
            int(x1 / image_ele["width"] * 999),
            int(y1 / image_ele["height"] * 999),
            int(x2 / image_ele["width"] * 999),
            int(y2 / image_ele["height"] * 999),
        ]
    elif tgt_format == "rel":
        new_bbox = [
            float(x1 / image_ele["width"]),
            float(y1 / image_ele["height"]),
            float(x2 / image_ele["width"]),
            float(y2 / image_ele["height"]),
        ]
    elif tgt_format == "molmo":
        new_bbox = [
            round(x1 / image_ele["width"] * 100, ndigits=1),
            round(y1 / image_ele["height"] * 100, ndigits=1),
            round(x2 / image_ele["width"] * 100, ndigits=1),
            round(y2 / image_ele["height"] * 100, ndigits=1),
        ]
    else:
        assert False, f"Unknown tgt_format: {tgt_format}"
    return new_bbox


def _convert_bbox_format_to_abs_origin(bbox, image_ele: dict, *, src_format: str):
    x1, y1, x2, y2 = bbox
    if src_format == "abs_origin":
        new_bbox = [int(x1), int(y1), int(x2), int(y2)]
    elif src_format == "abs_resized":
        new_bbox = [
            int(x1 / image_ele["resized_width"] * image_ele["width"]),
            int(y1 / image_ele["resized_height"] * image_ele["height"]),
            int(x2 / image_ele["resized_width"] * image_ele["width"]),
            int(y2 / image_ele["resized_height"] * image_ele["height"]),
        ]
    elif src_format == "qwen-vl":
        new_bbox = [
            int(x1 / 999 * image_ele["width"]),
            int(y1 / 999 * image_ele["height"]),
            int(x2 / 999 * image_ele["width"]),
            int(y2 / 999 * image_ele["height"]),
        ]
    elif src_format == "rel":
        new_bbox = [
            int(x1 * image_ele["width"]),
            int(y1 * image_ele["height"]),
            int(x2 * image_ele["width"]),
            int(y2 * image_ele["height"]),
        ]
    elif src_format == "molmo":
        new_bbox = [
            int(x1 / 100 * image_ele["width"]),
            int(y1 / 100 * image_ele["height"]),
            int(x2 / 100 * image_ele["width"]),
            int(y2 / 100 * image_ele["height"]),
        ]
    else:
        assert False, f"Unknown src_format: {src_format}"
    return new_bbox


def convert_bbox_format(bbox, image_ele: dict, *, src_format: str, tgt_format: str):
    bbox_abs_origin = _convert_bbox_format_to_abs_origin(bbox, image_ele, src_format=src_format)
    bbox_tgt_format = _convert_bbox_format_from_abs_origin(bbox_abs_origin, image_ele, tgt_format=tgt_format)
    return bbox_tgt_format


def _convert_point_format_from_abs_origin(point, image_ele: dict, *, tgt_format: str):
    x, y = point
    if tgt_format == "abs_origin":
        new_point = [int(x), int(y)]
    elif tgt_format == "abs_resized":
        new_point = [
            int(x / image_ele["width"] * image_ele["resized_width"]),
            int(y / image_ele["height"] * image_ele["resized_height"]),
        ]
    elif tgt_format == "qwen-vl":
        new_point = [
            int(x / image_ele["width"] * 999),
            int(y / image_ele["height"] * 999),
        ]
    elif tgt_format == "rel":
        new_point = [
            float(x / image_ele["width"]),
            float(y / image_ele["height"]),
        ]
    elif tgt_format == "molmo":
        new_point = [
            round(x / image_ele["width"] * 100, ndigits=1),
            round(y / image_ele["height"] * 100, ndigits=1),
        ]
    else:
        assert False, f"Unknown tgt_format: {tgt_format}"
    return new_point


def _convert_point_format_to_abs_origin(point, image_ele: dict, *, src_format: str):
    x, y = point
    if src_format == "abs_origin":
        new_point = [int(x), int(y)]
    elif src_format == "abs_resized":
        new_point = [
            int(x / image_ele["resized_width"] * image_ele["width"]),
            int(y / image_ele["resized_height"] * image_ele["height"]),
        ]
    elif src_format == "qwen-vl":
        new_point = [
            int(x / 999 * image_ele["width"]),
            int(y / 999 * image_ele["height"]),
        ]
    elif src_format == "rel":
        new_point = [
            int(x * image_ele["width"]),
            int(y * image_ele["height"]),
        ]
    elif src_format == "molmo":
        new_point = [
            int(x / 100 * image_ele["width"]),
            int(y / 100 * image_ele["height"]),
        ]
    else:
        assert False, f"Unknown src_format: {src_format}"
    return new_point


def convert_point_format(point, image_ele: dict, *, src_format: str, tgt_format: str):
    point_abs_origin = _convert_point_format_to_abs_origin(point, image_ele, src_format=src_format)
    point_tgt_format = _convert_point_format_from_abs_origin(point_abs_origin, image_ele, tgt_format=tgt_format)
    return point_tgt_format
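

# Illustrative round trip (assuming image_ele == {"width": 1920, "height": 1080,
# "resized_width": 1316, "resized_height": 728}; only these four keys are read):
#   convert_point_format([0.5, 0.5], image_ele, src_format="rel", tgt_format="abs_resized")
#   -> [658, 364]   # via the abs_origin intermediate [960, 540]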


__all__ = [
    "update_image_size_",
    "convert_bbox_format",
    "convert_point_format",
]


if __name__ == "__main__":
    from PIL import Image

    def draw_point(image: Image.Image, point: list):
        from copy import deepcopy

        from PIL import ImageDraw

        image = deepcopy(image)
        image_draw = ImageDraw.Draw(image)
        image_draw.ellipse([point[0] - 5, point[1] - 5, point[0] + 5, point[1] + 5], fill="red")
        return image

    # image_ele = {
    #     "image": "http://ofasys-multimodal-wlcb-3.oss-cn-wulanchabu.aliyuncs.com/data/datacomp1b/image/19774238/7218d7ceb39e82e0cafc389f326e218da623a8f2.jpg",
    #     "height": 444,
    #     "width": 592,
    # }
    image_ele = {
        "image": "46d5402b2c183f996f2a13cd2016af15.png",
        "height": 1080,
        "width": 1920,
    }
    point = [0.8379917184, 0.2087912088]  # rel, keyboard 'k' in the image

    # image: Image.Image = Image.open(requests.get(image_ele["image"], stream=True).raw)
    image: Image.Image = Image.open(image_ele["image"])
    assert image.width == image_ele["width"] and image.height == image_ele["height"], f"{image.size=}, {image_ele=}"
    draw_point(image, [point[0] * image.width, point[1] * image.height]).save("image_1.png")

    # update_image_size_ must run before the resize below: it adds the
    # "resized_width"/"resized_height" keys that image.resize() consumes.
    image_ele = update_image_size_(image_ele)
    resized_image = image.resize((image_ele["resized_width"], image_ele["resized_height"]))
    point = convert_point_format(point, image_ele, src_format="rel", tgt_format="abs_resized")
    print(f"{image_ele=}\n{point=}")

    draw_point(resized_image, point).save("image_2.png")


@@ -0,0 +1,362 @@
"""Script to run end-to-end evaluation on the benchmark.

Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""

import argparse
import datetime
import json
import logging
import os
import sys
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.qwen25vl_agent import Qwen25VLAgent

# import wandb


# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

os.makedirs("logs", exist_ok=True)  # FileHandler fails if the directory is missing

file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")


def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--sleep_after_execution", type=float, default=2.0)
    parser.add_argument("--max_steps", type=int, default=20)

    # agent config
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--planner_model", type=str, default=None)
    parser.add_argument("--executor_model", type=str, default="aguvis-s1-s2-agentnet0105-mo5")
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=1500)
    parser.add_argument("--stop_token", type=str, default=None)

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")

    args = parser.parse_args()
    return args
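

# Typical invocation (illustrative; the script name is a placeholder and all
# flags are defined in config() above):
#   python run.py --headless --num_envs 2 --executor_model qwen2.5vl --max_steps 20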


def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
    """Distribute tasks evenly across environments."""
    # Flatten the tasks into a single list
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))

    # Calculate tasks per environment
    tasks_per_env = math.ceil(len(all_tasks) / num_envs)

    # Distribute tasks
    distributed_tasks = []
    for i in range(num_envs):
        env_tasks = {}
        start_idx = i * tasks_per_env
        end_idx = min((i + 1) * tasks_per_env, len(all_tasks))

        for domain, example_id in all_tasks[start_idx:end_idx]:
            if domain not in env_tasks:
                env_tasks[domain] = []
            env_tasks[domain].append(example_id)

        distributed_tasks.append(env_tasks)

    return distributed_tasks
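

# Example: 5 tasks over num_envs=2 gives tasks_per_env = ceil(5/2) = 3, so the
# first environment receives the first 3 (domain, example_id) pairs and the
# second receives the remaining 2, regrouped by domain.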


def run_env_tasks(env_idx: int, env: DesktopEnv, agent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
    """Run tasks for a single environment."""
    logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")

    for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
        for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
            config_file = os.path.join(
                args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
            )
            with open(config_file, "r", encoding="utf-8") as f:
                example = json.load(f)

            logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
            logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
            logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")

            example_result_dir = os.path.join(
                args.result_dir,
                args.action_space,
                args.observation_type,
                "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model),
                domain,
                example_id,
            )
            os.makedirs(example_result_dir, exist_ok=True)

            try:
                lib_run_single.run_single_example(
                    agent,
                    env,
                    example,
                    args.max_steps,
                    example["instruction"],
                    args,
                    example_result_dir,
                    shared_scores,
                )
            except Exception as e:
                logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
                env.controller.end_recording(
                    os.path.join(example_result_dir, "recording.mp4")
                )
                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                    f.write(
                        json.dumps(
                            {"Error": f"Exception in {domain}/{example_id}: {e}"}
                        )
                    )
                    f.write("\n")

    env.close()


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    logger.info("Args: %s", args)

    distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)

    # First, set up all environments
    logger.info("Setting up all environments...")
    envs = []
    agents = []

    for env_idx in range(args.num_envs):
        logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")

        agent = Qwen25VLAgent(
            planner_model=args.planner_model,
            executor_model=args.executor_model,
            max_tokens=args.max_tokens,
            top_p=args.top_p,
            temperature=args.temperature,
            action_space=args.action_space,
        )
        agents.append(agent)

        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=agent.action_space,
            screen_size=(args.screen_width, args.screen_height),
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type
            in ["a11y_tree", "screenshot_a11y_tree", "som"],
            provider_name="docker"
        )
        envs.append(env)

    logger.info("All environments are ready. Starting parallel task execution...")

    # Create a shared list for scores across processes
    with Manager() as manager:
        shared_scores = manager.list()

        # Create and start processes for each environment
        processes = []
        for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
            p = Process(
                target=run_env_tasks,
                args=(env_idx, env, agent, env_tasks, args, shared_scores)
            )
            processes.append(p)
            p.start()

        # Wait for all processes to complete
        for p in processes:
            p.join()

        # Convert the shared list to a regular list
        scores = list(shared_scores)

        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # No result yet: remove partial outputs so the example reruns
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        # Read the recorded score; treat unreadable files as 0.0
                        try:
                            all_result.append(
                                float(
                                    open(
                                        os.path.join(example_path, "result.txt"), "r"
                                    ).read()
                                )
                            )
                        except (ValueError, OSError):
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()

    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    if args.domain != "all":
        test_all_meta = {args.domain: test_all_meta[args.domain]}

    exp_name = "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model)

    test_file_list = get_unfinished(
        args.action_space,
        exp_name,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    left_info = ""
    for domain in test_file_list:
        left_info += f"{domain}: {len(test_file_list[domain])}\n"
    logger.info(f"Left tasks:\n{left_info}")

    get_result(
        args.action_space,
        exp_name,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    test(args, test_file_list)