Compare commits
27 Commits
main
...
wxy/opencu
| Author | SHA1 | Date |
|---|---|---|
|
|
5b44be1c55 | |
|
|
e993663b5b | |
|
|
84d98d2c9d | |
|
|
de3411e56c | |
|
|
462e79c9d1 | |
|
|
c38264c971 | |
|
|
abf267eb11 | |
|
|
acf08a15d1 | |
|
|
953d5028ea | |
|
|
66def2c7a0 | |
|
|
e5a2398549 | |
|
|
4db72ec960 | |
|
|
a68c981777 | |
|
|
e40671e53b | |
|
|
aba043b9e8 | |
|
|
80aad6c2d5 | |
|
|
923b612a6d | |
|
|
6a5e119918 | |
|
|
710201d03a | |
|
|
185fbe1398 | |
|
|
e8508e8e3b | |
|
|
51af29354b | |
|
|
73dc19c1ce | |
|
|
462e6caea0 | |
|
|
8c024e4910 | |
|
|
30c8738be9 | |
|
|
f5c4563c8e |
|
|
@ -197,18 +197,4 @@ vmware_vm_data
|
|||
|
||||
.vscode
|
||||
|
||||
dataimpulse_proxy_config.json
|
||||
|
||||
## reference and draft and debug
|
||||
reference/
|
||||
draft/
|
||||
manual_examine.py
|
||||
run_human_examine.sh
|
||||
quick_start.py
|
||||
result_multi_apps_pengxiang_transformers12evaluation_examples/settings/proxy/dataimpulse.json
|
||||
evaluation_examples/settings/proxy/dataimpulse.json
|
||||
|
||||
# Local test configurations (not for public repo)
|
||||
evaluation_examples/spiderman.json
|
||||
evaluation_examples/test_50_random_proportional.json
|
||||
evaluation_examples/test_chrome.json
|
||||
dataimpulse_proxy_config.json
|
||||
54
README.md
54
README.md
|
|
@ -106,14 +106,45 @@ We are working on supporting more 👷. Please hold tight!
|
|||
## 🚀 Quick Start
|
||||
Run the following minimal example to interact with the environment:
|
||||
|
||||
```bash
|
||||
# Basic usage with default settings
|
||||
python quickstart.py
|
||||
```python
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
|
||||
# Customize provider and VM path
|
||||
python quickstart.py --provider_name vmware --path_to_vm "path/to/your/vm.vmx"
|
||||
example = {
|
||||
"id": "94d95f96-9699-4208-98ba-3c3119edf9c2",
|
||||
"instruction": "I want to install Spotify on my current system. Could you please help me?",
|
||||
"config": [
|
||||
{
|
||||
"type": "execute",
|
||||
"parameters": {
|
||||
"command": [
|
||||
"python",
|
||||
"-c",
|
||||
"import pyautogui; import time; pyautogui.click(960, 540); time.sleep(0.5);"
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"evaluator": {
|
||||
"func": "check_include_exclude",
|
||||
"result": {
|
||||
"type": "vm_command_line",
|
||||
"command": "which spotify"
|
||||
},
|
||||
"expected": {
|
||||
"type": "rule",
|
||||
"rules": {
|
||||
"include": ["spotify"],
|
||||
"exclude": ["not found"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
env = DesktopEnv(action_space="pyautogui")
|
||||
|
||||
obs = env.reset(task_config=example)
|
||||
obs, reward, done, info = env.step("pyautogui.rightClick()")
|
||||
```
|
||||
|
||||
You will see all the logs of the system running normally, including the successful creation of the environment, completion of setup, and successful execution of actions. In the end, you will observe a successful right-click on the screen, which means you are ready to go.
|
||||
|
||||
## 🧪 Experiments
|
||||
|
|
@ -221,14 +252,3 @@ If you find this environment useful, please consider citing our work:
|
|||
primaryClass={cs.AI}
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgement for OSWorld-Verified
|
||||
Special thanks to the following institutions that provided feedback and participated in the fixes (as well as institutions that provided feedback during the process): [MoonShot AI, a.k.a. Kimi](https://www.moonshot.ai/),[Human Data](https://www.hud.so/), [OpenAI](https://openai.com/), [ByteDance Seed TARS](https://seed-tars.com/), [Anthropic](https://www.anthropic.com/), [Simular](https://www.simular.ai/), [HKU Data Intelligence Lab](https://sites.google.com/view/chaoh)
|
||||
|
||||
Special thanks to the following students who participated in the specific fixes: [Mengqi Yuan](https://yuanmengqi.github.io/), [Danyang Zhang](https://zdy023.github.io/), [Xinzhuang Xiong](https://thisisxxz.com/), [Zhennan Shen](https://scholar.google.com/citations?user=JPwg5MwAAAAJ&hl=en), [Zilong Zhou](https://github.com/adlsdztony), Yanxu Chen, [Jiaqi Deng](https://millank0817.github.io/), [Tianbao Xie](https://tianbaoxie.com/), Junda Chen, [Jixuan Chen](https://chenjix.github.io/), [Haoyuan Wu](https://www.linkedin.com/in/haoyuan-wu-240878291/).
|
||||
|
||||
Special thanks to the following students who participated in running the re-evaluation: [Mengqi Yuan](https://yuanmengqi.github.io/), [Zilong Zhou](https://github.com/adlsdztony), [Xinyuan Wang](https://xinyuanwangcs.github.io/), [Bowen Wang](https://bowenbryanwang.github.io/).
|
||||
|
||||
## You might also be interested
|
||||
|
||||
- **OSWorld-MCP**: Benchmarking MCP Tool Invocation in Computer-Use Agents. [Website](https://osworld-mcp.github.io/)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import logging
|
|||
import random
|
||||
from typing import Any, Dict, Optional
|
||||
import time
|
||||
import traceback
|
||||
import requests
|
||||
|
||||
from desktop_env.actions import KEYBOARD_KEYS
|
||||
|
|
@ -21,41 +20,17 @@ class PythonController:
|
|||
self.retry_times = 3
|
||||
self.retry_interval = 5
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_image_response(content_type: str, data: Optional[bytes]) -> bool:
|
||||
"""Quick validation for PNG/JPEG payload using magic bytes; Content-Type is advisory.
|
||||
Returns True only when bytes look like a real PNG or JPEG.
|
||||
"""
|
||||
if not isinstance(data, (bytes, bytearray)) or not data:
|
||||
return False
|
||||
# PNG magic
|
||||
if len(data) >= 8 and data[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
return True
|
||||
# JPEG magic
|
||||
if len(data) >= 3 and data[:3] == b"\xff\xd8\xff":
|
||||
return True
|
||||
# If server explicitly marks as image, accept as a weak fallback (some environments strip magic)
|
||||
if content_type and ("image/png" in content_type or "image/jpeg" in content_type or "image/jpg" in content_type):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_screenshot(self) -> Optional[bytes]:
|
||||
"""
|
||||
Gets a screenshot from the server. With the cursor. None -> no screenshot or unexpected error.
|
||||
"""
|
||||
|
||||
for attempt_idx in range(self.retry_times):
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.get(self.http_server + "/screenshot", timeout=10)
|
||||
response = requests.get(self.http_server + "/screenshot")
|
||||
if response.status_code == 200:
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
content = response.content
|
||||
if self._is_valid_image_response(content_type, content):
|
||||
logger.info("Got screenshot successfully")
|
||||
return content
|
||||
else:
|
||||
logger.error("Invalid screenshot payload (attempt %d/%d).", attempt_idx + 1, self.retry_times)
|
||||
logger.info("Retrying to get screenshot.")
|
||||
logger.info("Got screenshot successfully")
|
||||
return response.content
|
||||
else:
|
||||
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get screenshot.")
|
||||
|
|
@ -161,94 +136,13 @@ class PythonController:
|
|||
|
||||
logger.error("Failed to execute command.")
|
||||
return None
|
||||
|
||||
def run_python_script(self, script: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Executes a python script on the server.
|
||||
"""
|
||||
payload = json.dumps({"code": script})
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/run_python", headers={'Content-Type': 'application/json'},
|
||||
data=payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return {"status": "error", "message": "Failed to execute command.", "output": None, "error": response.json()["error"]}
|
||||
except requests.exceptions.ReadTimeout:
|
||||
break
|
||||
except Exception:
|
||||
logger.error("An error occurred while trying to execute the command: %s", traceback.format_exc())
|
||||
logger.info("Retrying to execute command.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to execute command.")
|
||||
return {"status": "error", "message": "Failed to execute command.", "output": "", "error": "Retry limit reached."}
|
||||
|
||||
def run_bash_script(self, script: str, timeout: int = 30, working_dir: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Executes a bash script on the server.
|
||||
|
||||
:param script: The bash script content (can be multi-line)
|
||||
:param timeout: Execution timeout in seconds (default: 30)
|
||||
:param working_dir: Working directory for script execution (optional)
|
||||
:return: Dictionary with status, output, error, and returncode, or None if failed
|
||||
"""
|
||||
payload = json.dumps({
|
||||
"script": script,
|
||||
"timeout": timeout,
|
||||
"working_dir": working_dir
|
||||
})
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(
|
||||
self.http_server + "/run_bash_script",
|
||||
headers={'Content-Type': 'application/json'},
|
||||
data=payload,
|
||||
timeout=timeout + 100 # Add buffer to HTTP timeout
|
||||
)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
logger.info("Bash script executed successfully with return code: %d", result.get("returncode", -1))
|
||||
return result
|
||||
else:
|
||||
logger.error("Failed to execute bash script. Status code: %d, response: %s",
|
||||
response.status_code, response.text)
|
||||
logger.info("Retrying to execute bash script.")
|
||||
except requests.exceptions.ReadTimeout:
|
||||
logger.error("Bash script execution timed out")
|
||||
return {
|
||||
"status": "error",
|
||||
"output": "",
|
||||
"error": f"Script execution timed out after {timeout} seconds",
|
||||
"returncode": -1
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to execute the bash script: %s", e)
|
||||
logger.info("Retrying to execute bash script.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to execute bash script after %d retries.", self.retry_times)
|
||||
return {
|
||||
"status": "error",
|
||||
"output": "",
|
||||
"error": f"Failed to execute bash script after {self.retry_times} retries",
|
||||
"returncode": -1
|
||||
}
|
||||
|
||||
def execute_action(self, action):
|
||||
def execute_action(self, action: Dict[str, Any]):
|
||||
"""
|
||||
Executes an action on the server computer.
|
||||
"""
|
||||
# Handle string actions
|
||||
if action in ['WAIT', 'FAIL', 'DONE']:
|
||||
return
|
||||
|
||||
# Handle dictionary actions
|
||||
if type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']:
|
||||
return
|
||||
|
||||
action_type = action["action_type"]
|
||||
parameters = action["parameters"] if "parameters" in action else {param: action[param] for param in action if param != 'action_type'}
|
||||
|
|
|
|||
|
|
@ -199,62 +199,26 @@ class SetupController:
|
|||
path: str = f["path"]
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
raise Exception(f"Setup Upload - Invalid local path ({local_path}).")
|
||||
logger.error(f"Setup Upload - Invalid local path ({local_path}).")
|
||||
return
|
||||
|
||||
file_size = None
|
||||
form = MultipartEncoder({
|
||||
"file_path": path,
|
||||
"file_data": (os.path.basename(path), open(local_path, "rb"))
|
||||
})
|
||||
headers = {"Content-Type": form.content_type}
|
||||
logger.debug(form.content_type)
|
||||
|
||||
# send request to server to upload file
|
||||
try:
|
||||
file_size = os.path.getsize(local_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
max_retries = 3
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info(
|
||||
f"Uploading {os.path.basename(local_path)}{f' ({file_size} bytes)' if file_size is not None else ''} "
|
||||
f"to VM at {path} (attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
logger.debug("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/upload")
|
||||
|
||||
# Open the file inside each attempt to ensure fresh stream position
|
||||
with open(local_path, "rb") as fp:
|
||||
form = MultipartEncoder({
|
||||
"file_path": path,
|
||||
"file_data": (os.path.basename(path), fp)
|
||||
})
|
||||
headers = {"Content-Type": form.content_type}
|
||||
logger.debug(form.content_type)
|
||||
|
||||
# Explicit connect/read timeout to avoid hanging forever
|
||||
response = requests.post(
|
||||
self.http_server + "/setup" + "/upload",
|
||||
headers=headers,
|
||||
data=form,
|
||||
timeout=(10, 600)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(f"File uploaded successfully: {path}")
|
||||
logger.debug("Upload response: %s", response.text)
|
||||
last_error = None
|
||||
break
|
||||
else:
|
||||
msg = f"Failed to upload file {path}. Status code: {response.status_code}, Response: {response.text}"
|
||||
logger.error(msg)
|
||||
last_error = requests.RequestException(msg)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
last_error = e
|
||||
logger.error(f"Upload attempt {attempt + 1} failed for {path}: {e}")
|
||||
|
||||
# Exponential backoff between retries
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(2 ** attempt)
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
logger.debug("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/upload")
|
||||
response = requests.post(self.http_server + "/setup" + "/upload", headers=headers, data=form)
|
||||
if response.status_code == 200:
|
||||
logger.info("Command executed successfully: %s", response.text)
|
||||
else:
|
||||
logger.error("Failed to upload file. Status code: %s", response.text)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error("An error occurred while trying to send the request: %s", e)
|
||||
|
||||
def _change_wallpaper_setup(self, path: str):
|
||||
if not path:
|
||||
|
|
@ -813,108 +777,106 @@ class SetupController:
|
|||
|
||||
def _update_browse_history_setup(self, **config):
|
||||
cache_path = os.path.join(self.cache_dir, "history_new.sqlite")
|
||||
db_url = "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/chrome/44ee5668-ecd5-4366-a6ce-c1c9b8d4e938/history_empty.sqlite?download=true"
|
||||
db_url = "https://drive.usercontent.google.com/u/0/uc?id=1Lv74QkJYDWVX0RIgg0Co-DUcoYpVL0oX&export=download" # google drive
|
||||
if not os.path.exists(cache_path):
|
||||
max_retries = 3
|
||||
downloaded = False
|
||||
e = None
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
response = requests.get(db_url, stream=True)
|
||||
response.raise_for_status()
|
||||
max_retries = 3
|
||||
downloaded = False
|
||||
e = None
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
response = requests.get(db_url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(cache_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
logger.info("File downloaded successfully")
|
||||
downloaded = True
|
||||
break
|
||||
with open(cache_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
logger.info("File downloaded successfully")
|
||||
downloaded = True
|
||||
break
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"Failed to download {db_url} caused by {e}. Retrying... ({max_retries - i - 1} attempts left)")
|
||||
if not downloaded:
|
||||
raise requests.RequestException(f"Failed to download {db_url}. No retries left. Error: {e}")
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"Failed to download {db_url} caused by {e}. Retrying... ({max_retries - i - 1} attempts left)")
|
||||
if not downloaded:
|
||||
raise requests.RequestException(f"Failed to download {db_url}. No retries left. Error: {e}")
|
||||
else:
|
||||
logger.info("File already exists in cache directory")
|
||||
# copy a new history file in the tmp folder
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
db_path = os.path.join(tmp_dir, "history_empty.sqlite")
|
||||
shutil.copy(cache_path, db_path)
|
||||
db_path = cache_path
|
||||
|
||||
history = config['history']
|
||||
history = config['history']
|
||||
|
||||
for history_item in history:
|
||||
url = history_item['url']
|
||||
title = history_item['title']
|
||||
visit_time = datetime.now() - timedelta(seconds=history_item['visit_time_from_now_in_seconds'])
|
||||
for history_item in history:
|
||||
url = history_item['url']
|
||||
title = history_item['title']
|
||||
visit_time = datetime.now() - timedelta(seconds=history_item['visit_time_from_now_in_seconds'])
|
||||
|
||||
# Chrome use ms from 1601-01-01 as timestamp
|
||||
epoch_start = datetime(1601, 1, 1)
|
||||
chrome_timestamp = int((visit_time - epoch_start).total_seconds() * 1000000)
|
||||
# Chrome use ms from 1601-01-01 as timestamp
|
||||
epoch_start = datetime(1601, 1, 1)
|
||||
chrome_timestamp = int((visit_time - epoch_start).total_seconds() * 1000000)
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO urls (url, title, visit_count, typed_count, last_visit_time, hidden)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (url, title, 1, 0, chrome_timestamp, 0))
|
||||
cursor.execute('''
|
||||
INSERT INTO urls (url, title, visit_count, typed_count, last_visit_time, hidden)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (url, title, 1, 0, chrome_timestamp, 0))
|
||||
|
||||
url_id = cursor.lastrowid
|
||||
url_id = cursor.lastrowid
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO visits (url, visit_time, from_visit, transition, segment_id, visit_duration)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (url_id, chrome_timestamp, 0, 805306368, 0, 0))
|
||||
cursor.execute('''
|
||||
INSERT INTO visits (url, visit_time, from_visit, transition, segment_id, visit_duration)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (url_id, chrome_timestamp, 0, 805306368, 0, 0))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info('Fake browsing history added successfully.')
|
||||
logger.info('Fake browsing history added successfully.')
|
||||
|
||||
controller = PythonController(self.vm_ip, self.server_port)
|
||||
controller = PythonController(self.vm_ip, self.server_port)
|
||||
|
||||
# get the path of the history file according to the platform
|
||||
os_type = controller.get_vm_platform()
|
||||
# get the path of the history file according to the platform
|
||||
os_type = controller.get_vm_platform()
|
||||
|
||||
if os_type == 'Windows':
|
||||
if os_type == 'Windows':
|
||||
chrome_history_path = controller.execute_python_command(
|
||||
"""import os; print(os.path.join(os.getenv('USERPROFILE'), "AppData", "Local", "Google", "Chrome", "User Data", "Default", "History"))""")[
|
||||
'output'].strip()
|
||||
elif os_type == 'Darwin':
|
||||
chrome_history_path = controller.execute_python_command(
|
||||
"""import os; print(os.path.join(os.getenv('HOME'), "Library", "Application Support", "Google", "Chrome", "Default", "History"))""")[
|
||||
'output'].strip()
|
||||
elif os_type == 'Linux':
|
||||
if "arm" in platform.machine():
|
||||
chrome_history_path = controller.execute_python_command(
|
||||
"""import os; print(os.path.join(os.getenv('USERPROFILE'), "AppData", "Local", "Google", "Chrome", "User Data", "Default", "History"))""")[
|
||||
"import os; print(os.path.join(os.getenv('HOME'), 'snap', 'chromium', 'common', 'chromium', 'Default', 'History'))")[
|
||||
'output'].strip()
|
||||
elif os_type == 'Darwin':
|
||||
chrome_history_path = controller.execute_python_command(
|
||||
"""import os; print(os.path.join(os.getenv('HOME'), "Library", "Application Support", "Google", "Chrome", "Default", "History"))""")[
|
||||
'output'].strip()
|
||||
elif os_type == 'Linux':
|
||||
if "arm" in platform.machine():
|
||||
chrome_history_path = controller.execute_python_command(
|
||||
"import os; print(os.path.join(os.getenv('HOME'), 'snap', 'chromium', 'common', 'chromium', 'Default', 'History'))")[
|
||||
'output'].strip()
|
||||
else:
|
||||
chrome_history_path = controller.execute_python_command(
|
||||
"import os; print(os.path.join(os.getenv('HOME'), '.config', 'google-chrome', 'Default', 'History'))")[
|
||||
'output'].strip()
|
||||
else:
|
||||
raise Exception('Unsupported operating system')
|
||||
chrome_history_path = controller.execute_python_command(
|
||||
"import os; print(os.path.join(os.getenv('HOME'), '.config', 'google-chrome', 'Default', 'History'))")[
|
||||
'output'].strip()
|
||||
else:
|
||||
raise Exception('Unsupported operating system')
|
||||
|
||||
form = MultipartEncoder({
|
||||
"file_path": chrome_history_path,
|
||||
"file_data": (os.path.basename(chrome_history_path), open(db_path, "rb"))
|
||||
})
|
||||
headers = {"Content-Type": form.content_type}
|
||||
logger.debug(form.content_type)
|
||||
form = MultipartEncoder({
|
||||
"file_path": chrome_history_path,
|
||||
"file_data": (os.path.basename(chrome_history_path), open(db_path, "rb"))
|
||||
})
|
||||
headers = {"Content-Type": form.content_type}
|
||||
logger.debug(form.content_type)
|
||||
|
||||
# send request to server to upload file
|
||||
try:
|
||||
logger.debug("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/upload")
|
||||
response = requests.post(self.http_server + "/setup" + "/upload", headers=headers, data=form)
|
||||
if response.status_code == 200:
|
||||
logger.info("Command executed successfully: %s", response.text)
|
||||
else:
|
||||
logger.error("Failed to upload file. Status code: %s", response.text)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error("An error occurred while trying to send the request: %s", e)
|
||||
# send request to server to upload file
|
||||
try:
|
||||
logger.debug("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/upload")
|
||||
response = requests.post(self.http_server + "/setup" + "/upload", headers=headers, data=form)
|
||||
if response.status_code == 200:
|
||||
logger.info("Command executed successfully: %s", response.text)
|
||||
else:
|
||||
logger.error("Failed to upload file. Status code: %s", response.text)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error("An error occurred while trying to send the request: %s", e)
|
||||
|
||||
self._execute_setup(["sudo chown -R user:user /home/user/.config/google-chrome/Default/History"], shell=True)
|
||||
self._execute_setup(["sudo chown -R user:user /home/user/.config/google-chrome/Default/History"], shell=True)
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ class DesktopEnv(gym.Env):
|
|||
# Track whether environment has been used (step/setup) to optimize snapshot revert
|
||||
# docker, aws, gcp, azure are always unused as the emulator starts from a clean state
|
||||
# vmware, virtualbox are always used as the emulator starts from a dirty state
|
||||
if self.provider_name in {"docker", "aws", "gcp", "azure", "aliyun", "volcengine"}:
|
||||
if self.provider_name in {"docker", "aws", "gcp", "azure"}:
|
||||
self.is_environment_used = False
|
||||
elif self.provider_name in {"vmware", "virtualbox"}:
|
||||
self.is_environment_used = True
|
||||
|
|
@ -172,52 +172,54 @@ class DesktopEnv(gym.Env):
|
|||
else:
|
||||
self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region, screen_size=(self.screen_width, self.screen_height))
|
||||
|
||||
self.snapshot_name = snapshot_name
|
||||
self.cache_dir_base: str = cache_dir
|
||||
# todo: add the logic to get the screen size from the VM
|
||||
self.headless = headless
|
||||
self.require_a11y_tree = require_a11y_tree
|
||||
self.require_terminal = require_terminal
|
||||
try:
|
||||
self.snapshot_name = snapshot_name
|
||||
self.cache_dir_base: str = cache_dir
|
||||
# todo: add the logic to get the screen size from the VM
|
||||
self.headless = headless
|
||||
self.require_a11y_tree = require_a11y_tree
|
||||
self.require_terminal = require_terminal
|
||||
|
||||
# Initialize emulator and controller
|
||||
logger.info("Initializing...")
|
||||
self._start_emulator()
|
||||
# Initialize emulator and controller
|
||||
logger.info("Initializing...")
|
||||
self._start_emulator()
|
||||
|
||||
# mode: human or machine
|
||||
self.instruction = None
|
||||
assert action_space in ["computer_13", "pyautogui", "claude_computer_use", "autoglm_computer_use"]
|
||||
self.action_space = action_space # todo: refactor it to the ActType
|
||||
|
||||
# episodic stuffs, like counters, will be updated or reset
|
||||
# when calling self.reset()
|
||||
self._traj_no: int = -1
|
||||
self._step_no: int = 0
|
||||
self.action_history: List[Dict[str, any]] = []
|
||||
# mode: human or machine
|
||||
self.instruction = None
|
||||
assert action_space in ["computer_13", "pyautogui", "claude_computer_use"]
|
||||
self.action_space = action_space # todo: refactor it to the ActType
|
||||
|
||||
# episodic stuffs, like counters, will be updated or reset
|
||||
# when calling self.reset()
|
||||
self._traj_no: int = -1
|
||||
self._step_no: int = 0
|
||||
self.action_history: List[Dict[str, any]] = []
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize DesktopEnv: {e}")
|
||||
# If initialization fails, we should clean up the VM
|
||||
try:
|
||||
self.close()
|
||||
self.manager.delete_vm(self.path_to_vm, self.region)
|
||||
logger.info(f"Cleaned up VM {self.path_to_vm}.")
|
||||
except Exception as cleanup_error:
|
||||
logger.error(f"Failed to clean up VM {self.path_to_vm}: {cleanup_error}")
|
||||
raise
|
||||
|
||||
def _start_emulator(self):
|
||||
try:
|
||||
# Power on the virtual machine
|
||||
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
|
||||
# Power on the virtual machine
|
||||
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
|
||||
|
||||
# Get the ip from the virtual machine, and setup the controller
|
||||
vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
|
||||
self.vm_ip = vm_ip_ports[0]
|
||||
# Get the ports from the virtual machine (for Docker provider only)
|
||||
if len(vm_ip_ports) > 1:
|
||||
self.server_port = int(vm_ip_ports[1])
|
||||
self.chromium_port = int(vm_ip_ports[2])
|
||||
self.vnc_port = int(vm_ip_ports[3])
|
||||
self.vlc_port = int(vm_ip_ports[4])
|
||||
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base, client_password=self.client_password, screen_width=self.screen_width, screen_height=self.screen_height)
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
self.provider.stop_emulator(self.path_to_vm)
|
||||
except Exception as stop_err:
|
||||
logger.warning(f"Cleanup after interrupt failed: {stop_err}")
|
||||
raise
|
||||
# Get the ip from the virtual machine, and setup the controller
|
||||
vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
|
||||
self.vm_ip = vm_ip_ports[0]
|
||||
# Get the ports from the virtual machine (for Docker provider only)
|
||||
if len(vm_ip_ports) > 1:
|
||||
self.server_port = int(vm_ip_ports[1])
|
||||
self.chromium_port = int(vm_ip_ports[2])
|
||||
self.vnc_port = int(vm_ip_ports[3])
|
||||
self.vlc_port = int(vm_ip_ports[4])
|
||||
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base, client_password=self.client_password, screen_width=self.screen_width, screen_height=self.screen_height)
|
||||
|
||||
def _revert_to_snapshot(self):
|
||||
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
|
||||
|
|
@ -391,12 +393,12 @@ class DesktopEnv(gym.Env):
|
|||
logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
|
||||
# handle the special actions
|
||||
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
|
||||
if action == 'WAIT' or (type(action) == dict and action.get('action_type') == 'WAIT'):
|
||||
if action == 'WAIT':
|
||||
time.sleep(pause)
|
||||
elif action == 'FAIL' or (type(action) == dict and action.get('action_type') == 'FAIL'):
|
||||
elif action == 'FAIL':
|
||||
done = True
|
||||
info = {"fail": True}
|
||||
elif action == 'DONE' or (type(action) == dict and action.get('action_type') == 'DONE'):
|
||||
elif action == 'DONE':
|
||||
done = True
|
||||
info = {"done": True}
|
||||
|
||||
|
|
@ -404,7 +406,7 @@ class DesktopEnv(gym.Env):
|
|||
# the set of all possible actions defined in the action representation
|
||||
self.controller.execute_action(action)
|
||||
elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
|
||||
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']):
|
||||
if action in ['WAIT', 'FAIL', 'DONE']:
|
||||
self.controller.execute_action(action)
|
||||
else:
|
||||
# the set of all possible python commands insides `pyautogui`
|
||||
|
|
@ -428,22 +430,19 @@ class DesktopEnv(gym.Env):
|
|||
"""
|
||||
|
||||
postconfig = self.evaluator.get("postconfig", [])
|
||||
self.setup_controller.setup(postconfig, self.enable_proxy)
|
||||
self.setup_controller.setup(postconfig)
|
||||
# Mark environment as used if there were postconfig setup operations
|
||||
if postconfig:
|
||||
self.is_environment_used = True
|
||||
|
||||
if self.evaluator['func'] == "infeasible":
|
||||
if len(self.action_history) > 0:
|
||||
last_action = self.action_history[-1]
|
||||
if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
|
||||
return 1
|
||||
return 0
|
||||
if len(self.action_history) > 0 and self.action_history[-1] == "FAIL":
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
if len(self.action_history) > 0:
|
||||
last_action = self.action_history[-1]
|
||||
if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
|
||||
return 0
|
||||
if len(self.action_history) > 0 and self.action_history[-1] == "FAIL":
|
||||
return 0
|
||||
|
||||
if type(self.metric) == list:
|
||||
# Multiple metrics to evaluate whether the task is successfully completed
|
||||
|
|
|
|||
|
|
@ -1,499 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
from typing import Callable, Any, Optional, Tuple
|
||||
from typing import List, Dict, Union
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from desktop_env.controllers.python import PythonController
|
||||
from desktop_env.controllers.setup import SetupController
|
||||
from desktop_env.evaluators import metrics, getters
|
||||
from desktop_env.providers import create_vm_manager_and_provider
|
||||
|
||||
logger = logging.getLogger("desktopenv.env")
|
||||
|
||||
Metric = Callable[[Any, Any], float]
|
||||
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
|
||||
|
||||
MAX_RETRIES = 5 # Maximum retries for environment setup
|
||||
|
||||
|
||||
|
||||
def _fix_pyautogui_less_than_bug(command: str) -> str:
|
||||
"""
|
||||
Fix PyAutoGUI '<' character bug by converting it to hotkey("shift", ',') calls.
|
||||
|
||||
This fixes the known PyAutoGUI issue where typing '<' produces '>' instead.
|
||||
References:
|
||||
- https://github.com/asweigart/pyautogui/issues/198
|
||||
- https://github.com/xlang-ai/OSWorld/issues/257
|
||||
|
||||
Args:
|
||||
command (str): The original pyautogui command
|
||||
|
||||
Returns:
|
||||
str: The fixed command with '<' characters handled properly
|
||||
"""
|
||||
# Pattern to match press('<') or press('\u003c') calls
|
||||
press_pattern = r'pyautogui\.press\(["\'](?:<|\\u003c)["\']\)'
|
||||
|
||||
# Handle press('<') calls
|
||||
def replace_press_less_than(match):
|
||||
return 'pyautogui.hotkey("shift", ",")'
|
||||
|
||||
# First handle press('<') calls
|
||||
command = re.sub(press_pattern, replace_press_less_than, command)
|
||||
|
||||
# Pattern to match typewrite calls with quoted strings
|
||||
typewrite_pattern = r'pyautogui\.typewrite\((["\'])(.*?)\1\)'
|
||||
|
||||
# Then handle typewrite calls
|
||||
def process_typewrite_match(match):
|
||||
quote_char = match.group(1)
|
||||
content = match.group(2)
|
||||
|
||||
# Preprocess: Try to decode Unicode escapes like \u003c to actual '<'
|
||||
# This handles cases where '<' is represented as escaped Unicode
|
||||
try:
|
||||
# Attempt to decode unicode escapes
|
||||
decoded_content = content.encode('utf-8').decode('unicode_escape')
|
||||
content = decoded_content
|
||||
except UnicodeDecodeError:
|
||||
# If decoding fails, proceed with original content to avoid breaking existing logic
|
||||
pass # English comment: Graceful degradation - fall back to original content if decoding fails
|
||||
|
||||
# Check if content contains '<'
|
||||
if '<' not in content:
|
||||
return match.group(0)
|
||||
|
||||
# Split by '<' and rebuild
|
||||
parts = content.split('<')
|
||||
result_parts = []
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
if i == 0:
|
||||
# First part
|
||||
if part:
|
||||
result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")
|
||||
else:
|
||||
# Add hotkey for '<' and then typewrite for the rest
|
||||
result_parts.append('pyautogui.hotkey("shift", ",")')
|
||||
if part:
|
||||
result_parts.append(f"pyautogui.typewrite({quote_char}{part}{quote_char})")
|
||||
|
||||
return '; '.join(result_parts)
|
||||
|
||||
command = re.sub(typewrite_pattern, process_typewrite_match, command)
|
||||
|
||||
return command
|
||||
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
"""
|
||||
DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
provider_name: str = "vmware",
|
||||
region: str = None,
|
||||
path_to_vm: str = None,
|
||||
snapshot_name: str = "init_state",
|
||||
action_space: str = "pyautogui",
|
||||
cache_dir: str = "cache",
|
||||
screen_size: Tuple[int] = (int(os.environ.get("SCREEN_WIDTH", 1920)), int(os.environ.get("SCREEN_HEIGHT", 1080))),
|
||||
headless: bool = False,
|
||||
require_a11y_tree: bool = True,
|
||||
require_terminal: bool = False,
|
||||
os_type: str = "Ubuntu",
|
||||
enable_proxy: bool = False,
|
||||
client_password: str = "",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
provider_name (str): virtualization provider name, default to "vmware"
|
||||
region (str): the region for allocate machines, work for cloud services, default to "us-east-1"
|
||||
path_to_vm (str): path to .vmx file
|
||||
snapshot_name (str): snapshot name to revert to, default to "init_state"
|
||||
action_space (str): "computer_13" | "pyautogui"
|
||||
cache_dir (str): cache directory to cache task-related stuffs like
|
||||
reference file for evaluation
|
||||
screen_size (Tuple[int]): screen size of the VM
|
||||
headless (bool): whether to run the VM in headless mode
|
||||
require_a11y_tree (bool): whether to require accessibility tree
|
||||
require_terminal (bool): whether to require terminal output
|
||||
os_type (str): operating system type, default to "Ubuntu"
|
||||
enable_proxy (bool): whether to enable proxy support, default to False
|
||||
"""
|
||||
# Initialize VM manager and vitualization provider
|
||||
self.region = region
|
||||
self.provider_name = provider_name
|
||||
self.enable_proxy = enable_proxy # Store proxy enablement setting
|
||||
if client_password == "":
|
||||
if self.provider_name == "aws":
|
||||
self.client_password = "osworld-public-evaluation"
|
||||
else:
|
||||
self.client_password = "password"
|
||||
else:
|
||||
self.client_password = client_password
|
||||
|
||||
self.screen_width = screen_size[0]
|
||||
self.screen_height = screen_size[1]
|
||||
|
||||
# Default
|
||||
self.server_port = 5000
|
||||
self.chromium_port = 9222
|
||||
self.vnc_port = 8006
|
||||
self.vlc_port = 8080
|
||||
|
||||
# Initialize with default (no proxy) provider
|
||||
self.current_use_proxy = False
|
||||
self.manager, self.provider = None, None
|
||||
self.os_type = os_type
|
||||
self.path_to_vm = path_to_vm
|
||||
# Track whether environment has been used (step/setup) to optimize snapshot revert
|
||||
# docker, aws, gcp, azure are always unused as the emulator starts from a clean state
|
||||
# vmware, virtualbox are always used as the emulator starts from a dirty state
|
||||
if self.provider_name in {"docker", "aws", "gcp", "azure", "aliyun", "volcengine"}:
|
||||
self.is_environment_used = False
|
||||
elif self.provider_name in {"vmware", "virtualbox"}:
|
||||
self.is_environment_used = True
|
||||
else:
|
||||
raise ValueError(f"Invalid provider name: {self.provider_name}")
|
||||
|
||||
self.snapshot_name = snapshot_name
|
||||
self.cache_dir_base: str = cache_dir
|
||||
self.headless = headless
|
||||
self.require_a11y_tree = require_a11y_tree
|
||||
self.require_terminal = require_terminal
|
||||
|
||||
# mode: human or machine
|
||||
self.instruction = None
|
||||
assert action_space in ["computer_13", "pyautogui", "claude_computer_use", "autoglm_computer_use"]
|
||||
self.action_space = action_space # todo: refactor it to the ActType
|
||||
|
||||
# episodic stuffs, like counters, will be updated or reset
|
||||
# when calling self.reset()
|
||||
self._traj_no: int = -1
|
||||
self._step_no: int = 0
|
||||
self.action_history: List[Dict[str, any]] = []
|
||||
|
||||
def start(self):
|
||||
# Initialize emulator and controller
|
||||
if not self.manager and not self.provider:
|
||||
logger.info("Initializing...")
|
||||
self.manager, self.provider = create_vm_manager_and_provider(self.provider_name, self.region, use_proxy=False)
|
||||
|
||||
if self.path_to_vm:
|
||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(self.path_to_vm))) \
|
||||
if self.provider_name in {"vmware", "virtualbox"} else self.path_to_vm
|
||||
else:
|
||||
self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=self.region, screen_size=(self.screen_width, self.screen_height))
|
||||
|
||||
self._start_emulator()
|
||||
|
||||
def _start_emulator(self):
|
||||
try:
|
||||
# Power on the virtual machine
|
||||
self.provider.start_emulator(self.path_to_vm, self.headless, self.os_type)
|
||||
|
||||
# Get the ip from the virtual machine, and setup the controller
|
||||
vm_ip_ports = self.provider.get_ip_address(self.path_to_vm).split(':')
|
||||
self.vm_ip = vm_ip_ports[0]
|
||||
# Get the ports from the virtual machine (for Docker provider only)
|
||||
if len(vm_ip_ports) > 1:
|
||||
self.server_port = int(vm_ip_ports[1])
|
||||
self.chromium_port = int(vm_ip_ports[2])
|
||||
self.vnc_port = int(vm_ip_ports[3])
|
||||
self.vlc_port = int(vm_ip_ports[4])
|
||||
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base, client_password=self.client_password, screen_width=self.screen_width, screen_height=self.screen_height)
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
self.provider.stop_emulator(self.path_to_vm)
|
||||
except Exception as stop_err:
|
||||
logger.warning(f"Cleanup after interrupt failed: {stop_err}")
|
||||
raise
|
||||
|
||||
def _revert_to_snapshot(self):
|
||||
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
|
||||
# due to the fact it could be changed when implemented by cloud services
|
||||
path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
|
||||
if path_to_vm and not path_to_vm == self.path_to_vm:
|
||||
# path_to_vm has to be a new path
|
||||
|
||||
self.manager.delete_vm(self.path_to_vm, self.region)
|
||||
self.manager.add_vm(path_to_vm, self.region)
|
||||
self.manager.occupy_vm(path_to_vm, os.getpid(), self.region)
|
||||
self.path_to_vm = path_to_vm
|
||||
|
||||
def _save_state(self, snapshot_name=None):
|
||||
# Save the current virtual machine state to a certain snapshot name
|
||||
self.provider.save_state(self.path_to_vm, snapshot_name)
|
||||
|
||||
def close(self):
|
||||
# Close (release) the virtual machine
|
||||
self.provider.stop_emulator(self.path_to_vm)
|
||||
|
||||
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
|
||||
|
||||
# Reset to certain task in OSWorld
|
||||
logger.info("Resetting environment...")
|
||||
logger.info("Switching task...")
|
||||
logger.info("Setting counters...")
|
||||
self._traj_no += 1
|
||||
self._step_no = 0
|
||||
self.action_history.clear()
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
# Only revert to snapshot if environment has been used (step/setup)
|
||||
# This optimization is especially important for cloud providers like AWS
|
||||
# where unnecessary snapshot operations are costly and time-consuming
|
||||
|
||||
if task_config is not None:
|
||||
# Only consider task proxy requirement if proxy is enabled at system level
|
||||
task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
|
||||
if not self.enable_proxy and task_config.get("proxy", False):
|
||||
logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")
|
||||
|
||||
if task_use_proxy != self.current_use_proxy:
|
||||
# keep because get_info_from_website depend on this
|
||||
self.current_use_proxy = task_use_proxy
|
||||
|
||||
if self.is_environment_used:
|
||||
logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
|
||||
self._revert_to_snapshot()
|
||||
logger.info("Starting emulator...")
|
||||
self._start_emulator()
|
||||
logger.info("Emulator started.")
|
||||
# Reset the usage flag after reverting
|
||||
self.is_environment_used = False
|
||||
else:
|
||||
logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
|
||||
|
||||
if task_config is not None:
|
||||
if task_config.get("proxy", False) and self.enable_proxy:
|
||||
# If using proxy and proxy is enabled, set up the proxy configuration
|
||||
self.setup_controller._proxy_setup(self.client_password)
|
||||
self._set_task_info(task_config)
|
||||
self.setup_controller.reset_cache_dir(self.cache_dir)
|
||||
logger.info("Setting up environment...")
|
||||
success = self.setup_controller.setup(self.config, task_config.get("proxy", False) and self.enable_proxy)
|
||||
if success:
|
||||
# Mark environment as used when setup is successfully executed
|
||||
if self.config: # Only mark as used if there were actual setup operations
|
||||
self.is_environment_used = True
|
||||
break
|
||||
else:
|
||||
logger.error(
|
||||
"Environment setup failed, retrying (%d/%d)...",
|
||||
attempt + 1,
|
||||
MAX_RETRIES,
|
||||
)
|
||||
time.sleep(5)
|
||||
else:
|
||||
break
|
||||
|
||||
logger.info("Environment setup complete.")
|
||||
|
||||
observation = self._get_obs()
|
||||
return observation
|
||||
|
||||
def _get_obs(self):
|
||||
# We provide screenshot, accessibility_tree (optional), terminal (optional), and instruction.
|
||||
# can be customized and scaled
|
||||
return {
|
||||
"screenshot": self.controller.get_screenshot(),
|
||||
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
|
||||
"terminal": self.controller.get_terminal_output() if self.require_terminal else None,
|
||||
"instruction": self.instruction
|
||||
}
|
||||
|
||||
@property
|
||||
def vm_platform(self):
|
||||
return self.controller.get_vm_platform()
|
||||
|
||||
@property
|
||||
def vm_screen_size(self):
|
||||
return self.controller.get_vm_screen_size()
|
||||
|
||||
def _set_task_info(self, task_config: Dict[str, Any]):
|
||||
"""Set task info (proxy logic is handled in reset method)"""
|
||||
self.task_id: str = task_config["id"]
|
||||
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
self.instruction = task_config["instruction"]
|
||||
self.config = task_config["config"] if "config" in task_config else []
|
||||
|
||||
self._set_evaluator_info(task_config)
|
||||
|
||||
def _set_evaluator_info(self, task_config: Dict[str, Any]):
|
||||
"""Set evaluator information from task config"""
|
||||
if "evaluator" not in task_config:
|
||||
return
|
||||
# evaluator dict
|
||||
# func -> metric function string, or list of metric function strings
|
||||
# conj -> conjunction of multiple metrics if func is a list with length > 1, "and"/"or"
|
||||
# result -> result getter config, or list of result getter configs
|
||||
# expected (optional) -> expected getter config, or list of expected getter configs
|
||||
# options (optional) -> metric options, or list of metric options
|
||||
# if func is a str list, then result, expected (if exists), options (if exists) should also be lists of the same length
|
||||
# even if one of the metrics does not need expected or options field, it should be included in the list with None
|
||||
self.evaluator = task_config["evaluator"]
|
||||
self.metric: Metric = [getattr(metrics, func) for func in self.evaluator["func"]] \
|
||||
if isinstance(self.evaluator["func"], list) \
|
||||
else getattr(metrics, self.evaluator["func"])
|
||||
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics
|
||||
if "result" in self.evaluator and len(self.evaluator["result"]) > 0:
|
||||
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
|
||||
self.evaluator["result"]] \
|
||||
if isinstance(self.evaluator["result"], list) \
|
||||
else getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
|
||||
else:
|
||||
self.result_getter = [None] * len(self.metric) \
|
||||
if isinstance(self.metric, list) \
|
||||
else None
|
||||
|
||||
if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0:
|
||||
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
|
||||
self.evaluator["expected"]] \
|
||||
if isinstance(self.evaluator["expected"], list) \
|
||||
else getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
|
||||
else:
|
||||
self.expected_getter = [None] * len(self.metric) \
|
||||
if isinstance(self.metric, list) \
|
||||
else None
|
||||
self.metric_options: Union[List[Dict[str, Any]], Dict[str, Any]] = [opt if opt else {} for opt in
|
||||
self.evaluator["options"]] \
|
||||
if isinstance(self.evaluator.get("options", {}), list) \
|
||||
else self.evaluator["options"] \
|
||||
if "options" in self.evaluator \
|
||||
else [{}] * len(self.metric) \
|
||||
if isinstance(self.metric, list) \
|
||||
else {}
|
||||
|
||||
assert (not isinstance(self.evaluator["func"], list)
|
||||
or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(
|
||||
self.metric_options)))
|
||||
|
||||
def step(self, action, pause=2):
|
||||
self._step_no += 1
|
||||
self.action_history.append(action)
|
||||
|
||||
# Mark environment as used when step is called
|
||||
self.is_environment_used = True
|
||||
|
||||
reward = 0 # todo: Define reward calculation for each example
|
||||
done = False # todo: Define episode termination condition for each example
|
||||
info = {}
|
||||
logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
|
||||
# handle the special actions
|
||||
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
|
||||
if action == 'WAIT' or (type(action) == dict and action.get('action_type') == 'WAIT'):
|
||||
time.sleep(pause)
|
||||
elif action == 'FAIL' or (type(action) == dict and action.get('action_type') == 'FAIL'):
|
||||
done = True
|
||||
info = {"fail": True}
|
||||
elif action == 'DONE' or (type(action) == dict and action.get('action_type') == 'DONE'):
|
||||
done = True
|
||||
info = {"done": True}
|
||||
|
||||
if self.action_space == "computer_13":
|
||||
# the set of all possible actions defined in the action representation
|
||||
self.controller.execute_action(action)
|
||||
elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
|
||||
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']):
|
||||
self.controller.execute_action(action)
|
||||
else:
|
||||
# the set of all possible python commands insides `pyautogui`
|
||||
if type(action) == str:
|
||||
# Fix PyAutoGUI '<' character bug before execution
|
||||
fixed_command = _fix_pyautogui_less_than_bug(action)
|
||||
self.controller.execute_python_command(fixed_command)
|
||||
elif type(action) == dict:
|
||||
# Fix PyAutoGUI '<' character bug before execution
|
||||
fixed_command = _fix_pyautogui_less_than_bug(action['command'])
|
||||
self.controller.execute_python_command(fixed_command)
|
||||
|
||||
time.sleep(pause)
|
||||
observation = self._get_obs()
|
||||
|
||||
return observation, reward, done, info
|
||||
|
||||
def evaluate(self):
|
||||
"""
|
||||
Evaluate whether the task is successfully completed.
|
||||
"""
|
||||
|
||||
postconfig = self.evaluator.get("postconfig", [])
|
||||
self.setup_controller.setup(postconfig, self.enable_proxy)
|
||||
# Mark environment as used if there were postconfig setup operations
|
||||
if postconfig:
|
||||
self.is_environment_used = True
|
||||
|
||||
if self.evaluator['func'] == "infeasible":
|
||||
if len(self.action_history) > 0:
|
||||
last_action = self.action_history[-1]
|
||||
if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
|
||||
return 1
|
||||
return 0
|
||||
else:
|
||||
if len(self.action_history) > 0:
|
||||
last_action = self.action_history[-1]
|
||||
if last_action == "FAIL" or (type(last_action) == dict and last_action.get('action_type') == 'FAIL'):
|
||||
return 0
|
||||
|
||||
if type(self.metric) == list:
|
||||
# Multiple metrics to evaluate whether the task is successfully completed
|
||||
results = []
|
||||
assert len(self.metric) == len(self.result_getter), "The number of metrics and result getters must be the same"
|
||||
if "expected" in self.evaluator:
|
||||
assert len(self.metric) == len(self.expected_getter), "The number of metrics and expected getters must be the same"
|
||||
for idx, metric in enumerate(self.metric):
|
||||
try:
|
||||
config = self.evaluator["result"][idx]
|
||||
result_state = self.result_getter[idx](self, config)
|
||||
except FileNotFoundError:
|
||||
logger.error("File not found!")
|
||||
if self.metric_conj == 'and':
|
||||
return 0
|
||||
|
||||
if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
|
||||
expected_state = self.expected_getter[idx](self, self.evaluator["expected"][idx])
|
||||
metric: int = metric(result_state, expected_state, **self.metric_options[idx])
|
||||
else:
|
||||
metric: int = metric(result_state, **self.metric_options[idx])
|
||||
|
||||
if self.metric_conj == 'and' and float(metric) == 0.0:
|
||||
return 0
|
||||
elif self.metric_conj == 'or' and float(metric) == 1.0:
|
||||
return 1
|
||||
else:
|
||||
results.append(metric)
|
||||
|
||||
return sum(results) / len(results) if self.metric_conj == 'and' else max(results)
|
||||
else:
|
||||
# Single metric to evaluate whether the task is successfully completed
|
||||
try:
|
||||
result_state = self.result_getter(self, self.evaluator["result"])
|
||||
except FileNotFoundError:
|
||||
logger.error("File not found!")
|
||||
return 0
|
||||
|
||||
if "expected" in self.evaluator and self.expected_getter and self.evaluator["expected"]:
|
||||
expected_state = self.expected_getter(self, self.evaluator["expected"])
|
||||
metric: float = self.metric(result_state, expected_state, **self.metric_options)
|
||||
else:
|
||||
metric: float = self.metric(result_state, **self.metric_options)
|
||||
|
||||
return metric
|
||||
|
||||
def render(self, mode='rgb_array'):
|
||||
if mode == 'rgb_array':
|
||||
return self.controller.get_screenshot()
|
||||
else:
|
||||
raise ValueError('Unsupported render mode: {}'.format(mode))
|
||||
|
|
@ -16,7 +16,6 @@ from .chrome import (
|
|||
get_active_tab_info,
|
||||
get_enable_do_not_track,
|
||||
get_enable_enhanced_safety_browsing,
|
||||
get_enable_safe_browsing,
|
||||
get_new_startup_page,
|
||||
get_find_unpacked_extension_path,
|
||||
get_data_delete_automacally,
|
||||
|
|
|
|||
|
|
@ -827,8 +827,8 @@ def get_active_tab_info(env, config: Dict[str, str]):
|
|||
|
||||
try:
|
||||
logger.info(f"[ACTIVE_TAB_INFO] Navigating to URL: {active_tab_url}")
|
||||
page.goto(active_tab_url, wait_until='load', timeout=timeout_ms)
|
||||
page.wait_for_load_state('load', timeout=timeout_ms) # Wait for the 'load' event to complete
|
||||
page.goto(active_tab_url, wait_until='networkidle', timeout=timeout_ms)
|
||||
page.wait_for_load_state('networkidle', timeout=timeout_ms) # Wait for the 'load' event to complete
|
||||
|
||||
active_tab_info = {
|
||||
'title': page.title(),
|
||||
|
|
@ -1304,40 +1304,6 @@ def get_enable_enhanced_safety_browsing(env, config: Dict[str, str]):
|
|||
return "Google"
|
||||
|
||||
|
||||
def get_enable_safe_browsing(env, config: Dict[str, str]):
|
||||
os_type = env.vm_platform
|
||||
if os_type == 'Windows':
|
||||
preference_file_path = env.controller.execute_python_command("""import os; print(os.path.join(os.getenv('LOCALAPPDATA'),
|
||||
'Google\\Chrome\\User Data\\Default\\Preferences'))""")['output'].strip()
|
||||
elif os_type == 'Darwin':
|
||||
preference_file_path = env.controller.execute_python_command(
|
||||
"import os; print(os.path.join(os.getenv('HOME'), 'Library/Application Support/Google/Chrome/Default/Preferences'))")[
|
||||
'output'].strip()
|
||||
elif os_type == 'Linux':
|
||||
if "arm" in platform.machine():
|
||||
preference_file_path = env.controller.execute_python_command(
|
||||
"import os; print(os.path.join(os.getenv('HOME'), 'snap/chromium/common/chromium/Default/Preferences'))")[
|
||||
'output'].strip()
|
||||
else:
|
||||
preference_file_path = env.controller.execute_python_command(
|
||||
"import os; print(os.path.join(os.getenv('HOME'), '.config/google-chrome/Default/Preferences'))")[
|
||||
'output'].strip()
|
||||
|
||||
else:
|
||||
raise Exception('Unsupported operating system')
|
||||
|
||||
try:
|
||||
content = env.controller.get_file(preference_file_path)
|
||||
data = json.loads(content)
|
||||
|
||||
safebrowsing = data.get('safebrowsing', {})
|
||||
is_enhanced = bool(safebrowsing.get('enhanced', False))
|
||||
is_enabled = bool(safebrowsing.get('enabled', False))
|
||||
return "true" if (is_enhanced or is_enabled) else "false"
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
return "false"
|
||||
|
||||
def get_new_startup_page(env, config: Dict[str, str]):
|
||||
os_type = env.vm_platform
|
||||
if os_type == 'Windows':
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import functools
|
|||
import itertools
|
||||
import logging
|
||||
import os.path
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
# import operator
|
||||
from numbers import Number
|
||||
|
|
@ -746,18 +744,6 @@ def compare_table(result: str, expected: str = None, **options) -> float:
|
|||
# }}} function compare_table #
|
||||
|
||||
|
||||
def _normalize_city_string(value: Any) -> str:
|
||||
"""Lowercase, strip punctuation, and remove accents for tolerant matching."""
|
||||
if value is None:
|
||||
return ""
|
||||
if not isinstance(value, str):
|
||||
value = str(value)
|
||||
normalized = unicodedata.normalize("NFKD", value)
|
||||
normalized = "".join(ch for ch in normalized if not unicodedata.combining(ch))
|
||||
normalized = re.sub(r"[^a-z0-9]+", " ", normalized.lower())
|
||||
return normalized.strip()
|
||||
|
||||
|
||||
def compare_conference_city_in_order(actual_city_list_path, expected_city):
|
||||
expected_city_list = expected_city["expected"]
|
||||
wb = openpyxl.load_workbook(actual_city_list_path)
|
||||
|
|
@ -766,35 +752,38 @@ def compare_conference_city_in_order(actual_city_list_path, expected_city):
|
|||
for row in sheet["C2:C22"]:
|
||||
for cell in row:
|
||||
actual_city_list.append(cell.value)
|
||||
|
||||
# expected_city is the city that we want to compare with the actual city list
|
||||
# must in order index
|
||||
# debug
|
||||
try:
|
||||
for i, actual_city in enumerate(actual_city_list):
|
||||
actual_normalized = _normalize_city_string(actual_city)
|
||||
expected_entry = expected_city_list[i]
|
||||
for i in range(len(actual_city_list)):
|
||||
if isinstance(expected_city_list[i], str):
|
||||
if expected_city_list[i] not in actual_city_list[i]:
|
||||
logger.debug(
|
||||
f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
|
||||
)
|
||||
print(
|
||||
f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
|
||||
)
|
||||
return 0.0
|
||||
|
||||
elif isinstance(expected_city_list[i], List):
|
||||
if not any(
|
||||
possible_str in actual_city_list[i]
|
||||
for possible_str in expected_city_list[i]
|
||||
):
|
||||
logger.debug(
|
||||
f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
|
||||
)
|
||||
print(
|
||||
f"Expected city {expected_city_list[i]}; Actual city {actual_city_list[i]}"
|
||||
)
|
||||
return 0.0
|
||||
|
||||
if isinstance(expected_entry, str):
|
||||
expected_candidates = [expected_entry]
|
||||
elif isinstance(expected_entry, List):
|
||||
expected_candidates = expected_entry
|
||||
else:
|
||||
raise TypeError("Expected city should be a string or a list of strings")
|
||||
|
||||
matched = False
|
||||
for candidate in expected_candidates:
|
||||
normalized_candidate = _normalize_city_string(candidate)
|
||||
if normalized_candidate and normalized_candidate in actual_normalized:
|
||||
matched = True
|
||||
break
|
||||
|
||||
if not matched:
|
||||
logger.debug(
|
||||
f"Expected city {expected_entry}; Actual city {actual_city}"
|
||||
)
|
||||
print(f"Expected city {expected_entry}; Actual city {actual_city}")
|
||||
return 0.0
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error comparing conference cities: {exc}")
|
||||
except:
|
||||
return 0.0
|
||||
|
||||
return 1.0
|
||||
|
|
|
|||
|
|
@ -31,13 +31,5 @@ def create_vm_manager_and_provider(provider_name: str, region: str, use_proxy: b
|
|||
from desktop_env.providers.docker.manager import DockerVMManager
|
||||
from desktop_env.providers.docker.provider import DockerProvider
|
||||
return DockerVMManager(), DockerProvider(region)
|
||||
elif provider_name == "aliyun":
|
||||
from desktop_env.providers.aliyun.manager import AliyunVMManager
|
||||
from desktop_env.providers.aliyun.provider import AliyunProvider
|
||||
return AliyunVMManager(), AliyunProvider()
|
||||
elif provider_name == "volcengine":
|
||||
from desktop_env.providers.volcengine.manager import VolcengineVMManager
|
||||
from desktop_env.providers.volcengine.provider import VolcengineProvider
|
||||
return VolcengineVMManager(), VolcengineProvider()
|
||||
else:
|
||||
raise NotImplementedError(f"{provider_name} not implemented!")
|
||||
|
|
|
|||
|
|
@ -1,80 +0,0 @@
|
|||
# Aliyun ECS Provider Configuration Guide
|
||||
|
||||
This guide explains how to configure and use the Aliyun ECS provider for OSWorld desktop environments.
|
||||
|
||||
## Configuration Process
|
||||
|
||||
1. **Aliyun Account**: You need an active Aliyun Cloud account. This script uses pay-as-you-go billing by default, so ensure your account balance is above 100.
|
||||
2. **Access Keys**: Create AccessKey ID and AccessKey Secret in Aliyun RAM Access Control Console and grant ECS control permissions
|
||||
3. **VPC Setup**: Create a VPC, VSwitch, and Security Group in your target region
|
||||
4. **Custom Images**: Create OSWorld custom images
|
||||
5. It is recommended to manually complete the ECS creation process once to record all required environment variable information.
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Set the following environment variables in your `.env` file:
|
||||
|
||||
```bash
|
||||
# Aliyun Access Credentials
|
||||
ALIYUN_ACCESS_KEY_ID=your_access_key_id
|
||||
ALIYUN_ACCESS_KEY_SECRET=your_access_key_secret
|
||||
|
||||
# ECS Configuration Information
|
||||
ALIYUN_REGION=eu-central-1
|
||||
ALIYUN_IMAGE_ID=your_image_id
|
||||
ALIYUN_INSTANCE_TYPE=ecs.e-c1m2.large
|
||||
ALIYUN_VSWITCH_ID=vsw-xxxxxxxxx
|
||||
ALIYUN_SECURITY_GROUP_ID=sg-xxxxxxxxx
|
||||
```
|
||||
|
||||
## Required Aliyun Resources
|
||||
|
||||
### 1. VPC and VSwitch
|
||||
- Create a VPC in your target region
|
||||
- Create a VSwitch within the VPC
|
||||
- Ensure the VSwitch has internet access for VNC connectivity
|
||||
|
||||
### 2. Security Group
|
||||
**⚠️ Important**: Please strictly follow the port settings below to prevent OSWorld tasks from failing due to connection issues:
|
||||
|
||||
#### Inbound Rules (8 rules required)
|
||||
|
||||
| Type | Protocol | Port Range | Source | Description |
|
||||
|------|----------|------------|--------|-------------|
|
||||
| SSH | TCP | 22 | 0.0.0.0/0 | SSH access |
|
||||
| HTTP | TCP | 80 | 172.31.0.0/16 | HTTP traffic |
|
||||
| Custom TCP | TCP | 5000 | 172.31.0.0/16 | OSWorld backend service |
|
||||
| Custom TCP | TCP | 5910 | 0.0.0.0/0 | NoVNC visualization port |
|
||||
| Custom TCP | TCP | 8006 | 172.31.0.0/16 | VNC service port |
|
||||
| Custom TCP | TCP | 8080 | 172.31.0.0/16 | VLC service port |
|
||||
| Custom TCP | TCP | 8081 | 172.31.0.0/16 | Additional service port |
|
||||
| Custom TCP | TCP | 9222 | 172.31.0.0/16 | Chrome control port |
|
||||
|
||||
#### Outbound Rules (1 rule required)
|
||||
|
||||
| Type | Protocol | Port Range | Destination | Description |
|
||||
|------|----------|------------|-------------|-------------|
|
||||
| All traffic | All | All | 0.0.0.0/0 | Allow all outbound traffic |
|
||||
|
||||
### 3. Custom Images
|
||||
You need to create a custom OSWorld image for Aliyun ECS. Please follow the instructions in the "Creating Custom ECS Images for OSWorld" section.
|
||||
|
||||
## Creating Custom ECS Images for OSWorld
|
||||
|
||||
This section provides guidance on how to create the custom ECS images required for OSWorld desktop environments. The process involves setting up a base instance with desktop environment and VNC server, then creating a custom image from it.
|
||||
|
||||
### Step-by-Step Image Creation Process
|
||||
#### Step 1: Upload existing qcow2 image to Aliyun
|
||||
- Download the provided qcow2 image from the link in `desktop_env/providers/docker/manager.py`: https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2.zip
|
||||
- Unzip the downloaded file and upload it to Aliyun Object Storage Service (OSS). Make sure the OSS is in the same region as your target region to launch ECS instance.
|
||||
- In your ECS dashboard, go to "Images" and You will see the "Import Image" button. Click it and follow the instructions to import the qcow2 image from OSS.
|
||||
- After the import is complete, you will see the imported image in the "Images" list.
|
||||
#### Step 2: Create a new image
|
||||
Note that the image you created in Step 1 will have a different resolution than the one you want to use for OSWorld (1920x1080). We need to customize the image to have the correct resolution and setup noVNC.
|
||||
- Go to `Instances` tab and create a new instance with the imported image.
|
||||
- Connect to the running instance via VNC.
|
||||
- After connecting to the instance, please open the terminal and download this configuration script: `https://gist.githubusercontent.com/qykong/bea58ff98f20057d3a69921276dd4553/raw/cd1a91a0840c4192d793f43cfb90553370343b08/config.sh`.
|
||||
- If you want ssh and vnc password also be setup, use this `https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/aliyun_config.sh?download=true`.
|
||||
- Run the script and reboot your instance.
|
||||
- After rebooting, the instance will have the correct resolution and noVNC setup. You can connect to the instance via "http://<your_instance_public_ip>:5910/vnc.html" (make sure your security group allows port 5910).
|
||||
- Save the running instance as a new image. The new image will be used as the OSWorld image.
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
# 阿里云ECS提供商配置指南
|
||||
|
||||
本指南介绍如何为OSWorld桌面环境配置和使用阿里云ECS。
|
||||
|
||||
## 配置流程
|
||||
|
||||
1. **阿里云账户**:您需要一个有效的阿里云账户,本脚本默认ECS通过按量付费方式拉起,需保证账户余额在100以上。
|
||||
2. **访问密钥**:在阿里云RAM访问控制控制台中创建AccessKey ID和AccessKey Secret,并授权ECS控制权限
|
||||
3. **VPC设置**:在目标地域创建VPC、交换机和安全组
|
||||
4. **自定义镜像**:创建OSWorld自定义镜像。
|
||||
5. 建议手动完成一次ECS创建流程后,记录所有需要的环境变量信息。
|
||||
|
||||
## 环境变量
|
||||
|
||||
在您的`.env`文件中设置以下环境变量:
|
||||
|
||||
```bash
|
||||
# 阿里云访问凭证
|
||||
ALIYUN_ACCESS_KEY_ID=your_access_key_id
|
||||
ALIYUN_ACCESS_KEY_SECRET=your_access_key_secret
|
||||
|
||||
# ECS配置信息
|
||||
ALIYUN_REGION=eu-central-1
|
||||
ALIYUN_IMAGE_ID=your_image_id
|
||||
ALIYUN_INSTANCE_TYPE=ecs.e-c1m2.large
|
||||
ALIYUN_VSWITCH_ID=vsw-xxxxxxxxx
|
||||
ALIYUN_SECURITY_GROUP_ID=sg-xxxxxxxxx
|
||||
```
|
||||
|
||||
## 所需阿里云资源
|
||||
|
||||
### 1. VPC和交换机
|
||||
- 在目标地域创建VPC
|
||||
- 在VPC内创建交换机
|
||||
- 确保交换机具有互联网访问能力以支持VNC连接
|
||||
|
||||
### 2. 安全组
|
||||
**⚠️ 重要提示**:请严格按照以下端口设置,以防止OSWorld任务因连接问题而失败:
|
||||
|
||||
#### 入方向规则(需要8条规则)
|
||||
|
||||
| 类型 | 协议 | 端口范围 | 源地址 | 描述 |
|
||||
|------|------|----------|--------|------|
|
||||
| SSH | TCP | 22 | 0.0.0.0/0 | SSH访问 |
|
||||
| HTTP | TCP | 80 | 172.31.0.0/16 | HTTP流量 |
|
||||
| 自定义TCP | TCP | 5000 | 172.31.0.0/16 | OSWorld后端服务 |
|
||||
| 自定义TCP | TCP | 5910 | 0.0.0.0/0 | NoVNC可视化端口 |
|
||||
| 自定义TCP | TCP | 8006 | 172.31.0.0/16 | VNC服务端口 |
|
||||
| 自定义TCP | TCP | 8080 | 172.31.0.0/16 | VLC服务端口 |
|
||||
| 自定义TCP | TCP | 8081 | 172.31.0.0/16 | 附加服务端口 |
|
||||
| 自定义TCP | TCP | 9222 | 172.31.0.0/16 | Chrome控制端口 |
|
||||
|
||||
#### 出方向规则(需要1条规则)
|
||||
|
||||
| 类型 | 协议 | 端口范围 | 目标地址 | 描述 |
|
||||
|------|------|----------|----------|------|
|
||||
| 全部流量 | 全部 | 全部 | 0.0.0.0/0 | 允许所有出站流量 |
|
||||
|
||||
### 3. 自定义镜像
|
||||
您需要为阿里云ECS创建自定义OSWorld镜像。请按照"为OSWorld创建自定义ECS镜像"部分的说明进行操作。
|
||||
|
||||
|
||||
## 为OSWorld创建自定义ECS镜像
|
||||
|
||||
本部分提供如何创建OSWorld桌面环境所需的自定义ECS镜像的指导。该过程包括设置带有桌面环境和VNC服务器的基础实例,然后从中创建自定义镜像。
|
||||
|
||||
### 分步镜像创建过程
|
||||
#### 步骤1:上传现有qcow2镜像到阿里云
|
||||
- 从`desktop_env/providers/docker/manager.py`中的链接下载提供的qcow2镜像:https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2.zip
|
||||
- 解压下载的文件并上传到阿里云对象存储服务(OSS)。确保OSS与您要启动ECS实例的目标地域在同一地域。
|
||||
- 在您的ECS控制台中,转到"镜像"页面,您将看到"导入镜像"按钮。点击它并按照说明从OSS导入qcow2镜像。
|
||||
- 导入完成后,您将在"镜像"列表中看到导入的镜像。
|
||||
|
||||
#### 步骤2:创建新镜像
|
||||
请注意,您在步骤1中创建的镜像分辨率与您想要用于OSWorld的分辨率(1920x1080)不同。我们需要自定义镜像以具有正确的分辨率并设置noVNC。
|
||||
- 转到"实例"选项卡,使用导入的镜像创建新实例。
|
||||
- 通过VNC连接到正在运行的实例。
|
||||
- 连接到实例后,请打开终端并下载此配置脚本:`https://gist.githubusercontent.com/qykong/bea58ff98f20057d3a69921276dd4553/raw/cd1a91a0840c4192d793f43cfb90553370343b08/config.sh`。
|
||||
- 如果您还想设置ssh和vnc密码,请使用此脚本 `https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/aliyun_config.sh?download=true`。
|
||||
- 运行脚本并重启您的实例。
|
||||
- 重启后,实例将具有正确的分辨率和noVNC设置。您可以通过"http://<your_instance_public_ip>:5910/vnc.html"连接到实例(确保您的安全组允许端口5910)。
|
||||
- 将正在运行的实例保存为新镜像。新镜像将用作OSWorld镜像。
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
import os


# Default TTL minutes for instance auto-release (Aliyun-side).
# Can be overridden via environment variable DEFAULT_TTL_MINUTES.
# ATTENTION: ECS requires TTL to be at least 30 minutes (if TTL > 0).
MIN_TTL_MINUTES: int = 30

# Fallback used when DEFAULT_TTL_MINUTES is unset or not a valid integer.
_FALLBACK_TTL_MINUTES: int = 60

try:
    _ttl_env_val = int(os.getenv("DEFAULT_TTL_MINUTES", str(_FALLBACK_TTL_MINUTES)))
except (TypeError, ValueError):
    # Narrowed from a bare `except Exception`: only a malformed value can
    # fail here, and import-time config parsing must never crash the process.
    _ttl_env_val = _FALLBACK_TTL_MINUTES

# If TTL is positive but less than the Aliyun minimum, clamp up to 30 minutes.
# Zero and negative values pass through unchanged (interpreted as "no TTL").
if 0 < _ttl_env_val < MIN_TTL_MINUTES:
    DEFAULT_TTL_MINUTES: int = MIN_TTL_MINUTES
else:
    DEFAULT_TTL_MINUTES: int = _ttl_env_val

# Master switch for the TTL feature.
ENABLE_TTL: bool = os.getenv("ENABLE_TTL", "true").lower() == "true"


def compute_ttl_seconds(ttl_minutes: int) -> int:
    """Convert a TTL in minutes to seconds, never returning a negative value.

    Invalid (non-numeric) input yields 0 — i.e. "no TTL" — rather than
    raising, so callers can pass unvalidated configuration through safely.
    """
    try:
        return max(0, int(ttl_minutes) * 60)
    except (TypeError, ValueError):
        return 0
|
|
@ -1,325 +0,0 @@
|
|||
import os
|
||||
import logging
|
||||
import dotenv
|
||||
import time
|
||||
import signal
|
||||
import requests
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from alibabacloud_ecs20140526.client import Client as ECSClient
|
||||
from alibabacloud_tea_openapi import models as open_api_models
|
||||
from alibabacloud_ecs20140526 import models as ecs_models
|
||||
from alibabacloud_tea_util.client import Client as UtilClient
|
||||
from desktop_env.providers.base import VMManager
|
||||
from desktop_env.providers.aliyun.config import ENABLE_TTL, DEFAULT_TTL_MINUTES
|
||||
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Fail fast at import time: every credential / network setting below is
# mandatory for any Aliyun API call, so surface missing configuration
# immediately instead of with a confusing SDK error later.
for env_name in [
    "ALIYUN_REGION",
    "ALIYUN_VSWITCH_ID",
    "ALIYUN_SECURITY_GROUP_ID",
    "ALIYUN_IMAGE_ID",
    "ALIYUN_ACCESS_KEY_ID",
    "ALIYUN_ACCESS_KEY_SECRET",
    "ALIYUN_INSTANCE_TYPE",
]:
    if not os.getenv(env_name):
        raise EnvironmentError(f"{env_name} must be set in the environment variables.")


logger = logging.getLogger("desktopenv.providers.aliyun.AliyunVMManager")
logger.setLevel(logging.INFO)

# Required settings (all validated non-empty by the loop above).
ALIYUN_INSTANCE_TYPE = os.getenv("ALIYUN_INSTANCE_TYPE")
ALIYUN_ACCESS_KEY_ID = os.getenv("ALIYUN_ACCESS_KEY_ID")
ALIYUN_ACCESS_KEY_SECRET = os.getenv("ALIYUN_ACCESS_KEY_SECRET")
ALIYUN_REGION = os.getenv("ALIYUN_REGION")
ALIYUN_IMAGE_ID = os.getenv("ALIYUN_IMAGE_ID")
ALIYUN_SECURITY_GROUP_ID = os.getenv("ALIYUN_SECURITY_GROUP_ID")
ALIYUN_VSWITCH_ID = os.getenv("ALIYUN_VSWITCH_ID")
# Optional: resource group for billing/organization; None when unset.
ALIYUN_RESOURCE_GROUP_ID = os.getenv("ALIYUN_RESOURCE_GROUP_ID")

# Polling cadence shared by the wait helpers below: up to MAX_ATTEMPTS
# polls, WAIT_DELAY seconds apart (i.e. 5 minutes total).
WAIT_DELAY = 20
MAX_ATTEMPTS = 15
|
||||
|
||||
def _allocate_vm(screen_size=(1920, 1080)):
    """
    Allocate a new Aliyun ECS instance and return its instance ID.

    Creates a PostPaid instance from the configured image, optionally with a
    creation-time auto-release TTL, installs temporary SIGINT/SIGTERM handlers
    so that a half-created instance is force-deleted if the process is
    interrupted, and blocks until the instance reaches the Running state.

    Args:
        screen_size: Desired desktop resolution; only (1920, 1080) is
            supported because the custom image is built at that resolution.

    Returns:
        str: The ID of the newly created, running ECS instance.

    Raises:
        RuntimeError: if RunInstances returns no instance ID.
        TimeoutError: if the instance never reaches Running state.
        KeyboardInterrupt / SystemExit: re-raised after cleanup on SIGINT/SIGTERM.
    """
    assert screen_size == (1920, 1080), "Only 1920x1080 screen size is supported"

    config = open_api_models.Config(
        access_key_id=ALIYUN_ACCESS_KEY_ID,
        access_key_secret=ALIYUN_ACCESS_KEY_SECRET,
        region_id=ALIYUN_REGION,
    )
    client = ECSClient(config)
    # `instance_id` is a closure variable: the signal handler below reads it,
    # so once it is assigned further down, an interrupt can clean up the VM.
    instance_id = None
    original_sigint_handler = signal.getsignal(signal.SIGINT)
    original_sigterm_handler = signal.getsignal(signal.SIGTERM)

    def signal_handler(sig, frame):
        # Best-effort force-delete of the instance on interruption, then
        # restore the previous handlers and re-raise/exit appropriately.
        if instance_id:
            signal_name = "SIGINT" if sig == signal.SIGINT else "SIGTERM"
            logger.warning(
                f"Received {signal_name} signal, terminating instance {instance_id}..."
            )
            try:
                delete_request = ecs_models.DeleteInstancesRequest(
                    region_id=ALIYUN_REGION,
                    instance_ids=UtilClient.to_jsonstring([instance_id]),
                    force=True,
                )
                client.delete_instances(delete_request)
                logger.info(
                    f"Successfully terminated instance {instance_id} after {signal_name}."
                )
            except Exception as cleanup_error:
                # Cleanup is best-effort; never mask the interruption itself.
                logger.error(
                    f"Failed to terminate instance {instance_id} after {signal_name}: {str(cleanup_error)}"
                )

        # Restore original signal handlers
        signal.signal(signal.SIGINT, original_sigint_handler)
        signal.signal(signal.SIGTERM, original_sigterm_handler)

        # Raise appropriate exception based on signal type
        if sig == signal.SIGINT:
            raise KeyboardInterrupt
        else:
            # For SIGTERM, exit gracefully
            import sys

            sys.exit(0)

    try:
        # Set up signal handlers for both SIGINT and SIGTERM
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        logger.info(
            f"Creating new ECS instance in region {ALIYUN_REGION} with image {ALIYUN_IMAGE_ID}"
        )

        # TTL configuration
        ttl_enabled = ENABLE_TTL
        ttl_minutes = DEFAULT_TTL_MINUTES
        ttl_seconds = max(0, int(ttl_minutes) * 60)

        # Aliyun constraints: at least 30 minutes in the future, ISO8601 UTC, seconds must be 00
        now_utc = datetime.now(timezone.utc)
        min_eta = now_utc + timedelta(minutes=30)
        raw_eta = now_utc + timedelta(seconds=ttl_seconds)
        effective_eta = raw_eta if raw_eta > min_eta else min_eta
        # round up to the next full minute, zero seconds
        effective_eta = (effective_eta + timedelta(seconds=59)).replace(second=0, microsecond=0)
        auto_release_str = effective_eta.strftime('%Y-%m-%dT%H:%M:%SZ')
        logger.info(
            f"TTL config: enabled={ttl_enabled}, minutes={ttl_minutes}, seconds={ttl_seconds}, ETA(UTC)={auto_release_str}"
        )

        # Create instance request (attempt with auto_release_time first when TTL enabled)
        def _build_request(with_ttl: bool) -> ecs_models.RunInstancesRequest:
            # Built as a kwargs dict so optional fields (resource group, TTL)
            # can be added conditionally before constructing the request.
            kwargs = dict(
                region_id=ALIYUN_REGION,
                image_id=ALIYUN_IMAGE_ID,
                instance_type=ALIYUN_INSTANCE_TYPE,
                security_group_id=ALIYUN_SECURITY_GROUP_ID,
                v_switch_id=ALIYUN_VSWITCH_ID,
                instance_name=f"OSWorld-Desktop-{int(time.time())}",
                description="OSWorld Desktop Environment Instance",
                internet_max_bandwidth_out=10,
                internet_charge_type="PayByTraffic",
                instance_charge_type="PostPaid",
                system_disk=ecs_models.RunInstancesRequestSystemDisk(
                    size="50",
                    category="cloud_essd",
                ),
                deletion_protection=False,
            )

            if ALIYUN_RESOURCE_GROUP_ID:
                kwargs["resource_group_id"] = ALIYUN_RESOURCE_GROUP_ID

            if with_ttl and ttl_enabled and ttl_seconds > 0:
                kwargs["auto_release_time"] = auto_release_str
            return ecs_models.RunInstancesRequest(**kwargs)

        try:
            request = _build_request(with_ttl=True)
            response = client.run_instances(request)
        except Exception as create_err:
            # Retry without auto_release_time if creation-time TTL is rejected
            logger.warning(
                f"RunInstances with auto_release_time failed: {create_err}. Retrying without TTL field..."
            )
            request = _build_request(with_ttl=False)
            response = client.run_instances(request)
        instance_ids = response.body.instance_id_sets.instance_id_set

        if not instance_ids:
            raise RuntimeError(
                "Failed to create ECS instance - no instance ID returned"
            )

        # From here on the signal handler can see the instance and delete it.
        instance_id = instance_ids[0]
        logger.info(f"ECS instance {instance_id} created successfully")

        # Wait for the instance to be running
        logger.info(f"Waiting for instance {instance_id} to be running...")
        _wait_for_instance_running(client, instance_id)

        logger.info(f"Instance {instance_id} is now running and ready")

    except KeyboardInterrupt:
        # Reached when SIGINT arrives outside the handler's cleanup window,
        # or re-raised by the handler itself; delete the instance if created.
        logger.warning("VM allocation interrupted by user (SIGINT).")
        if instance_id:
            logger.info(f"Terminating instance {instance_id} due to interruption.")
            try:
                delete_request = ecs_models.DeleteInstancesRequest(
                    region_id=ALIYUN_REGION,
                    instance_ids=UtilClient.to_jsonstring([instance_id]),
                    force=True,
                )
                client.delete_instances(delete_request)
            except Exception as cleanup_error:
                logger.error(
                    f"Failed to cleanup instance {instance_id}: {str(cleanup_error)}"
                )
        raise
    except Exception as e:
        # Any other failure after creation: tear the instance down so a
        # half-provisioned VM does not keep billing.
        logger.error(f"Failed to allocate ECS instance: {str(e)}")
        if instance_id:
            logger.info(f"Terminating instance {instance_id} due to an error.")
            try:
                delete_request = ecs_models.DeleteInstancesRequest(
                    region_id=ALIYUN_REGION,
                    instance_ids=UtilClient.to_jsonstring([instance_id]),
                    force=True,
                )
                client.delete_instances(delete_request)
            except Exception as cleanup_error:
                logger.error(
                    f"Failed to cleanup instance {instance_id}: {str(cleanup_error)}"
                )
        raise
    finally:
        # Restore original signal handlers
        signal.signal(signal.SIGINT, original_sigint_handler)
        signal.signal(signal.SIGTERM, original_sigterm_handler)

    return instance_id
|
||||
|
||||
def _wait_for_instance_running(
    client: ECSClient, instance_id: str, max_attempts: int = MAX_ATTEMPTS
):
    """Poll an ECS instance until it reports the "Running" state.

    A Stopped/Stopping instance is (re)started; transient API errors are
    logged and retried on the next poll.

    Raises:
        TimeoutError: if the state is not reached within
            max_attempts * WAIT_DELAY seconds.
    """
    remaining = max_attempts
    while remaining > 0:
        remaining -= 1
        try:
            describe_req = ecs_models.DescribeInstancesRequest(
                region_id=ALIYUN_REGION,
                instance_ids=UtilClient.to_jsonstring([instance_id]),
            )
            matched = client.describe_instances(describe_req).body.instances.instance

            if matched:
                current_status = matched[0].status
                logger.info(f"Instance {instance_id} status: {current_status}")

                if current_status == "Running":
                    return
                if current_status in ("Stopped", "Stopping"):
                    # Kick a stopped instance back into life, then keep polling.
                    client.start_instance(
                        ecs_models.StartInstanceRequest(instance_id=instance_id)
                    )
                    logger.info(f"Started instance {instance_id}")

            time.sleep(WAIT_DELAY)

        except Exception as poll_error:
            logger.warning(f"Error checking instance status: {poll_error}")
            time.sleep(WAIT_DELAY)

    raise TimeoutError(
        f"Instance {instance_id} did not reach Running state within {max_attempts * WAIT_DELAY} seconds"
    )
|
||||
|
||||
def _wait_until_server_ready(public_ip: str):
    """Block until the OSWorld backend on ``public_ip`` answers HTTP requests.

    The probe hits the root path on port 5000; a 404 response proves the HTTP
    server itself is up (the root route is not expected to exist), so 404 is
    treated as "ready".

    Raises:
        TimeoutError: if the server does not answer within
            MAX_ATTEMPTS * WAIT_DELAY seconds.
    """
    for _ in range(MAX_ATTEMPTS):
        try:
            logger.info(f"Checking server status on {public_ip}...")
            response = requests.get(f"http://{public_ip}:5000/", timeout=2)
            if response.status_code == 404:
                logger.info(f"Server {public_ip} is ready")
                return
        except Exception:
            # Connection refused / timeout while the server is still booting.
            pass
        # Fix: previously only the exception path slept, so any non-404 HTTP
        # response caused a tight retry loop with no delay between probes.
        time.sleep(WAIT_DELAY)

    raise TimeoutError(
        f"Server {public_ip} did not respond within {MAX_ATTEMPTS * WAIT_DELAY} seconds"
    )
|
||||
|
||||
class AliyunVMManager(VMManager):
    """VM manager backed by Aliyun ECS.

    Unlike file-based providers, ECS instances are created and destroyed on
    demand, so no local registry of VMs is kept: every registry-related hook
    inherited from :class:`VMManager` is a deliberate no-op, and
    :meth:`get_vm_path` simply allocates a fresh instance.
    """

    def __init__(self, **kwargs):
        self.initialize_registry()

    # --- Registry hooks: intentionally no-ops for the ECS backend ---

    def initialize_registry(self, **kwargs):
        pass

    def add_vm(self, vm_path, lock_needed=True, **kwargs):
        pass

    def _add_vm(self, vm_path):
        pass

    def delete_vm(self, vm_path, lock_needed=True, **kwargs):
        pass

    def _delete_vm(self, vm_path):
        pass

    def occupy_vm(self, vm_path, pid, lock_needed=True, **kwargs):
        pass

    def _occupy_vm(self, vm_path, pid):
        pass

    def check_and_clean(self, lock_needed=True, **kwargs):
        pass

    def _check_and_clean(self):
        pass

    def list_free_vms(self, lock_needed=True, **kwargs):
        pass

    def _list_free_vms(self):
        pass

    def get_vm_path(self, screen_size=(1920, 1080), **kwargs):
        """Allocate a fresh ECS instance and return its instance ID as the "path"."""
        logger.info(
            f"Allocating new ECS instance in region {ALIYUN_REGION} with screen size {screen_size}"
        )

        try:
            allocated_id = _allocate_vm(screen_size)
        except Exception as e:
            logger.error(f"Failed to allocate instance: {str(e)}")
            raise

        logger.info(f"Successfully allocated instance {allocated_id}")
        return allocated_id
||||
|
|
@ -1,224 +0,0 @@
|
|||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from alibabacloud_ecs20140526.client import Client as ECSClient
|
||||
from alibabacloud_tea_openapi import models as open_api_models
|
||||
from alibabacloud_ecs20140526 import models as ecs_models
|
||||
from alibabacloud_tea_util.client import Client as UtilClient
|
||||
|
||||
from desktop_env.providers.base import Provider
|
||||
from desktop_env.providers.aliyun.manager import (
|
||||
_allocate_vm,
|
||||
_wait_for_instance_running,
|
||||
_wait_until_server_ready,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.aliyun.AliyunProvider")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class AliyunProvider(Provider):
    """Desktop-environment provider backed by Aliyun ECS instances.

    A "VM path" for this provider is an ECS instance ID. Instances are
    started, described, imaged and deleted through the Aliyun ECS SDK.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.region = os.getenv("ALIYUN_REGION", "eu-central-1")
        self.client = self._create_client()
        # Whether to use private IP instead of public IP. Default: enabled.
        # Priority: explicit kwarg > env var ALIYUN_USE_PRIVATE_IP > default True
        env_use_private = os.getenv("ALIYUN_USE_PRIVATE_IP", "1").lower() in {"1", "true", "yes", "on"}
        kw_flag = kwargs.get("use_private_ip", None)
        self.use_private_ip = env_use_private if kw_flag is None else bool(kw_flag)

    def _create_client(self) -> ECSClient:
        """Build an ECS SDK client from environment credentials."""
        config = open_api_models.Config(
            access_key_id=os.getenv("ALIYUN_ACCESS_KEY_ID"),
            access_key_secret=os.getenv("ALIYUN_ACCESS_KEY_SECRET"),
            region_id=self.region,
        )
        return ECSClient(config)

    def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs):
        """Ensure the ECS instance ``path_to_vm`` is running.

        Running instances are left untouched; Stopped instances are started
        and waited on. Other transient states only produce a warning.
        Note: ``headless`` is accepted for interface compatibility but unused.
        """
        logger.info("Starting Aliyun ECS instance...")

        try:
            # Check the current state of the instance
            response = self._describe_instance(path_to_vm)
            if not response.body.instances.instance:
                logger.error(f"Instance {path_to_vm} not found")
                return

            instance = response.body.instances.instance[0]
            state = instance.status
            logger.info(f"Instance {path_to_vm} current state: {state}")

            if state == "Running":
                # If the instance is already running, skip starting it
                logger.info(
                    f"Instance {path_to_vm} is already running. Skipping start."
                )
                return

            if state == "Stopped":
                # Start the instance if it's currently stopped
                req = ecs_models.StartInstanceRequest(instance_id=path_to_vm)
                self.client.start_instance(req)
                logger.info(f"Instance {path_to_vm} is starting...")

                # Wait until the instance reaches 'Running' state
                _wait_for_instance_running(self.client, path_to_vm)
                logger.info(f"Instance {path_to_vm} is now running.")
            else:
                # For all other states (Pending, Starting, etc.), log a warning
                logger.warning(
                    f"Instance {path_to_vm} is in state '{state}' and cannot be started."
                )

        except Exception as e:
            logger.error(
                f"Failed to start the Aliyun ECS instance {path_to_vm}: {str(e)}"
            )
            raise

    def get_ip_address(self, path_to_vm: str) -> str:
        """Return the IP address to reach the instance's OSWorld backend.

        Resolves both the VPC private IP and any public/EIP address, picks
        one according to ``self.use_private_ip``, waits until the backend on
        that IP responds, and logs VNC access details when a public IP exists.
        Returns "" when the instance or a usable IP cannot be found.
        """
        logger.info("Getting Aliyun ECS instance IP address...")

        try:
            response = self._describe_instance(path_to_vm)
            if not response.body.instances.instance:
                logger.error(f"Instance {path_to_vm} not found")
                return ""

            instance = response.body.instances.instance[0]

            # Get private and public IP addresses
            private_ip = ""
            public_ip = ""

            if hasattr(instance, "vpc_attributes") and instance.vpc_attributes:
                private_ip = (
                    instance.vpc_attributes.private_ip_address.ip_address[0]
                    if instance.vpc_attributes.private_ip_address.ip_address
                    else ""
                )

            if hasattr(instance, "public_ip_address") and instance.public_ip_address:
                public_ip = (
                    instance.public_ip_address.ip_address[0]
                    if instance.public_ip_address.ip_address
                    else ""
                )

            # An attached EIP takes precedence over the auto-assigned public IP.
            if hasattr(instance, "eip_address") and instance.eip_address:
                public_ip = instance.eip_address.ip_address or public_ip

            # Select which IP to use based on configuration
            ip_to_use = private_ip if (self.use_private_ip and private_ip) else public_ip

            if not ip_to_use:
                logger.warning("No usable IP address available (private/public both missing)")
                return ""

            # Blocks until the backend on port 5000 answers (see manager module).
            _wait_until_server_ready(ip_to_use)
            if public_ip:
                vnc_url = f"http://{public_ip}:5910/vnc.html"
                logger.info(f"🖥️ VNC Web Access URL: {vnc_url}")
                logger.info("=" * 80)
                logger.info(f"📡 Public IP: {public_ip}")
                logger.info(f"🏠 Private IP: {private_ip}")
                logger.info(f"🔧 Using IP: {'Private' if ip_to_use == private_ip else 'Public'} -> {ip_to_use}")
                logger.info("=" * 80)
                print(f"\n🌐 VNC Web Access URL: {vnc_url}")
                print(
                    "📍 Please open the above address in the browser "
                    "for remote desktop access\n"
                )

            return ip_to_use

        except Exception as e:
            logger.error(
                f"Failed to retrieve IP address for the instance {path_to_vm}: {str(e)}"
            )
            raise

    def save_state(self, path_to_vm: str, snapshot_name: str):
        """Create a custom image from the instance and return the image ID."""
        logger.info("Saving Aliyun ECS instance state...")

        try:
            req = ecs_models.CreateImageRequest(
                region_id=self.region,
                instance_id=path_to_vm,
                image_name=snapshot_name,
                description=f"Snapshot created at {datetime.now().isoformat()}",
            )
            response = self.client.create_image(req)
            image_id = response.body.image_id
            logger.info(
                f"Image {image_id} created successfully from instance {path_to_vm}."
            )
            return image_id

        except Exception as e:
            logger.error(
                f"Failed to create image from the instance {path_to_vm}: {str(e)}"
            )
            raise

    def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
        """Replace the instance with a freshly allocated one.

        Deletes ``path_to_vm`` and allocates a new instance via
        ``_allocate_vm``. NOTE(review): ``snapshot_name`` is not passed to the
        allocation — the new instance is created from the default configured
        image, not from the named snapshot; confirm whether this is intended.
        Returns the new instance ID.
        """
        logger.info(
            f"Reverting Aliyun ECS instance to snapshot image: {snapshot_name}..."
        )

        try:
            # Step 1: Retrieve the original instance details
            response = self._describe_instance(path_to_vm)
            if not response.body.instances.instance:
                logger.error(f"Instance {path_to_vm} not found")
                return
            # Step 2: Delete the old instance
            # NOTE(review): built with instance_id=[...] here, whereas the
            # manager module uses instance_ids=UtilClient.to_jsonstring([...])
            # for the same request type — verify against the ECS SDK which
            # field DeleteInstancesRequest actually expects.
            req = ecs_models.DeleteInstancesRequest(
                region_id=self.region, instance_id=[path_to_vm], force=True
            )
            self.client.delete_instances(req)
            logger.info(f"Old instance {path_to_vm} has been deleted.")

            # Step 3: Launch a new instance from the snapshot image
            new_instance_id = _allocate_vm()
            logger.info(f"Instance {new_instance_id} is ready.")

            # Get VNC access information
            self.get_ip_address(new_instance_id)

            return new_instance_id

        except Exception as e:
            logger.error(
                f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}"
            )
            raise

    def stop_emulator(self, path_to_vm: str, region: str = None):
        """Force-delete the ECS instance (there is no stop-only path here).

        NOTE(review): the ``region`` parameter is accepted but unused;
        ``self.region`` is always used instead.
        """
        logger.info(f"Stopping Aliyun ECS instance {path_to_vm}...")

        try:
            # NOTE(review): same instance_id=[...] vs instance_ids question
            # as in revert_to_snapshot above.
            req = ecs_models.DeleteInstancesRequest(
                region_id=self.region, instance_id=[path_to_vm], force=True
            )
            self.client.delete_instances(req)
            logger.info(f"Instance {path_to_vm} has been deleted.")

        except Exception as e:
            logger.error(
                f"Failed to stop the Aliyun ECS instance {path_to_vm}: {str(e)}"
            )
            raise

    def _describe_instance(
        self, instance_id: str
    ) -> ecs_models.DescribeInstancesResponse:
        """Get instance details"""
        req = ecs_models.DescribeInstancesRequest(
            region_id=self.region, instance_ids=UtilClient.to_jsonstring([instance_id])
        )
        return self.client.describe_instances(req)
||||
|
|
@ -1,21 +0,0 @@
|
|||
import os


# Default TTL minutes for instance auto-termination (cloud-side scheduler).
# Can be overridden via environment variable DEFAULT_TTL_MINUTES.
_FALLBACK_TTL_MINUTES: int = 180

try:
    DEFAULT_TTL_MINUTES: int = int(os.getenv("DEFAULT_TTL_MINUTES", str(_FALLBACK_TTL_MINUTES)))
except (TypeError, ValueError):
    # Fix: a malformed env value used to raise ValueError at import time and
    # crash the whole process; fall back to the default instead, matching the
    # defensive parsing used by the Aliyun provider config.
    DEFAULT_TTL_MINUTES: int = _FALLBACK_TTL_MINUTES

# Master switch for the TTL feature.
ENABLE_TTL: bool = os.getenv("ENABLE_TTL", "true").lower() == "true"

# EventBridge Scheduler role ARN for scheduling EC2 termination.
AWS_SCHEDULER_ROLE_ARN: str = os.getenv("AWS_SCHEDULER_ROLE_ARN", "").strip()


def compute_ttl_seconds(ttl_minutes: int) -> int:
    """Convert a TTL in minutes to seconds, never returning a negative value.

    Invalid (non-numeric) input yields 0 — i.e. "no TTL" — rather than raising.
    """
    try:
        return max(0, int(ttl_minutes) * 60)
    except (TypeError, ValueError):
        return 0
|
@ -1,16 +1,12 @@
|
|||
import os
|
||||
from filelock import FileLock
|
||||
import boto3
|
||||
import logging
|
||||
import dotenv
|
||||
import signal
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
# TTL configuration
|
||||
from desktop_env.providers.aws.config import ENABLE_TTL, DEFAULT_TTL_MINUTES, AWS_SCHEDULER_ROLE_ARN
|
||||
from desktop_env.providers.aws.scheduler_utils import schedule_instance_termination
|
||||
|
||||
|
||||
INSTANCE_TYPE = "t3.xlarge"
|
||||
INSTANCE_TYPE = "t3.medium"
|
||||
|
||||
# Load environment variables from .env file
|
||||
dotenv.load_dotenv()
|
||||
|
|
@ -40,13 +36,10 @@ DEFAULT_REGION = "us-east-1"
|
|||
# todo: public the AMI images
|
||||
IMAGE_ID_MAP = {
|
||||
"us-east-1": {
|
||||
(1920, 1080): "ami-0d23263edb96951d8",
|
||||
# For CoACT-1, uncomment to use the following AMI
|
||||
# (1920, 1080): "ami-0b505e9d0d99ba88c"
|
||||
(1920, 1080): "ami-0d23263edb96951d8"
|
||||
},
|
||||
"ap-east-1": {
|
||||
(1920, 1080): "ami-06850864d18fad836"
|
||||
# Please transfer AMI by yourself from AWS us-east-1 for CoACT-1
|
||||
(1920, 1080): "ami-0c092a5b8be4116f5"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -96,20 +89,12 @@ def _allocate_vm(region=DEFAULT_REGION, screen_size=(1920, 1080)):
|
|||
if not os.getenv('AWS_SUBNET_ID'):
|
||||
raise ValueError("AWS_SUBNET_ID is not set in the environment variables.")
|
||||
|
||||
# TTL configuration (cloud-init removed; use cloud-side scheduler only)
|
||||
ttl_enabled = ENABLE_TTL
|
||||
ttl_minutes = DEFAULT_TTL_MINUTES
|
||||
ttl_seconds = max(0, int(ttl_minutes) * 60)
|
||||
eta_utc = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
|
||||
logger.info(f"TTL config: minutes={ttl_minutes}, seconds={ttl_seconds}, ETA(UTC)={eta_utc.isoformat()}")
|
||||
|
||||
run_instances_params = {
|
||||
"MaxCount": 1,
|
||||
"MinCount": 1,
|
||||
"ImageId": ami_id,
|
||||
"InstanceType": INSTANCE_TYPE,
|
||||
"EbsOptimized": True,
|
||||
"InstanceInitiatedShutdownBehavior": "terminate",
|
||||
"NetworkInterfaces": [
|
||||
{
|
||||
"SubnetId": os.getenv('AWS_SUBNET_ID'),
|
||||
|
|
@ -136,20 +121,12 @@ def _allocate_vm(region=DEFAULT_REGION, screen_size=(1920, 1080)):
|
|||
|
||||
response = ec2_client.run_instances(**run_instances_params)
|
||||
instance_id = response['Instances'][0]['InstanceId']
|
||||
|
||||
# Create TTL schedule immediately after instance is created, to survive early interruptions
|
||||
try:
|
||||
# Always attempt; helper resolves ARN via env or role name
|
||||
if ttl_enabled:
|
||||
schedule_instance_termination(region, instance_id, ttl_seconds, AWS_SCHEDULER_ROLE_ARN, logger)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create EventBridge Scheduler for {instance_id}: {e}")
|
||||
|
||||
|
||||
waiter = ec2_client.get_waiter('instance_running')
|
||||
logger.info(f"Waiting for instance {instance_id} to be running...")
|
||||
waiter.wait(InstanceIds=[instance_id])
|
||||
logger.info(f"Instance {instance_id} is ready.")
|
||||
|
||||
|
||||
try:
|
||||
instance_details = ec2_client.describe_instances(InstanceIds=[instance_id])
|
||||
instance = instance_details['Reservations'][0]['Instances'][0]
|
||||
|
|
|
|||
|
|
@ -2,14 +2,11 @@ import boto3
|
|||
from botocore.exceptions import ClientError
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from desktop_env.providers.base import Provider
|
||||
|
||||
# TTL configuration
|
||||
from desktop_env.providers.aws.config import ENABLE_TTL, DEFAULT_TTL_MINUTES, AWS_SCHEDULER_ROLE_ARN
|
||||
from desktop_env.providers.aws.scheduler_utils import schedule_instance_termination
|
||||
from desktop_env.providers.base import Provider
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.aws.AWSProvider")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -81,7 +78,6 @@ class AWSProvider(Provider):
|
|||
logger.warning("No public IP address available for VNC access")
|
||||
|
||||
return private_ip_address
|
||||
# return public_ip_address
|
||||
return '' # Return an empty string if no IP address is found
|
||||
except ClientError as e:
|
||||
logger.error(f"Failed to retrieve IP address for the instance {path_to_vm}: {str(e)}")
|
||||
|
|
@ -108,68 +104,23 @@ class AWSProvider(Provider):
|
|||
# Step 1: Retrieve the original instance details
|
||||
instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm])
|
||||
instance = instance_details['Reservations'][0]['Instances'][0]
|
||||
# Resolve security groups with fallbacks
|
||||
security_groups = [sg['GroupId'] for sg in instance.get('SecurityGroups', []) if 'GroupId' in sg]
|
||||
if not security_groups:
|
||||
env_sg = os.getenv('AWS_SECURITY_GROUP_ID')
|
||||
if env_sg:
|
||||
security_groups = [env_sg]
|
||||
logger.info("SecurityGroups missing on instance; using AWS_SECURITY_GROUP_ID from env")
|
||||
else:
|
||||
raise ValueError("No security groups found on instance and AWS_SECURITY_GROUP_ID not set")
|
||||
|
||||
# Resolve subnet with fallbacks
|
||||
subnet_id = instance.get('SubnetId')
|
||||
if not subnet_id:
|
||||
nis = instance.get('NetworkInterfaces', []) or []
|
||||
if nis and isinstance(nis, list):
|
||||
for ni in nis:
|
||||
if isinstance(ni, dict) and ni.get('SubnetId'):
|
||||
subnet_id = ni.get('SubnetId')
|
||||
break
|
||||
if not subnet_id:
|
||||
env_subnet = os.getenv('AWS_SUBNET_ID')
|
||||
if env_subnet:
|
||||
subnet_id = env_subnet
|
||||
logger.info("SubnetId missing on instance; using AWS_SUBNET_ID from env")
|
||||
else:
|
||||
raise ValueError("SubnetId not available on instance, NetworkInterfaces, or environment")
|
||||
|
||||
# Resolve instance type with fallbacks
|
||||
instance_type = instance.get('InstanceType') or os.getenv('AWS_INSTANCE_TYPE') or 't3.large'
|
||||
if instance.get('InstanceType') is None:
|
||||
logger.info(f"InstanceType missing on instance; using '{instance_type}' from env/default")
|
||||
security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']]
|
||||
subnet_id = instance['SubnetId']
|
||||
instance_type = instance['InstanceType']
|
||||
|
||||
# Step 2: Terminate the old instance (skip if already terminated/shutting-down)
|
||||
state = (instance.get('State') or {}).get('Name')
|
||||
if state in ['shutting-down', 'terminated']:
|
||||
logger.info(f"Old instance {path_to_vm} is already in state '{state}', skipping termination.")
|
||||
else:
|
||||
try:
|
||||
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
|
||||
logger.info(f"Old instance {path_to_vm} has been terminated.")
|
||||
except ClientError as e:
|
||||
error_code = getattr(getattr(e, 'response', {}), 'get', lambda *_: None)('Error', {}).get('Code') if hasattr(e, 'response') else None
|
||||
if error_code in ['InvalidInstanceID.NotFound', 'IncorrectInstanceState']:
|
||||
logger.info(f"Ignore termination error for {path_to_vm}: {error_code}")
|
||||
else:
|
||||
raise
|
||||
# Step 2: Terminate the old instance
|
||||
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
|
||||
logger.info(f"Old instance {path_to_vm} has been terminated.")
|
||||
|
||||
# Step 3: Launch a new instance from the snapshot(AMI) with performance optimization
|
||||
logger.info(f"Launching a new instance from AMI {snapshot_name}...")
|
||||
|
||||
# TTL configuration follows the same env flags as allocation (centralized)
|
||||
enable_ttl = ENABLE_TTL
|
||||
default_ttl_minutes = DEFAULT_TTL_MINUTES
|
||||
ttl_seconds = max(0, default_ttl_minutes * 60)
|
||||
|
||||
run_instances_params = {
|
||||
"MaxCount": 1,
|
||||
"MinCount": 1,
|
||||
"ImageId": snapshot_name,
|
||||
"InstanceType": instance_type,
|
||||
"EbsOptimized": True,
|
||||
"InstanceInitiatedShutdownBehavior": "terminate",
|
||||
"NetworkInterfaces": [
|
||||
{
|
||||
"SubnetId": subnet_id,
|
||||
|
|
@ -199,40 +150,7 @@ class AWSProvider(Provider):
|
|||
ec2_client.get_waiter('instance_running').wait(InstanceIds=[new_instance_id])
|
||||
|
||||
logger.info(f"Instance {new_instance_id} is ready.")
|
||||
# Schedule cloud-side termination via EventBridge Scheduler (auto-resolve role ARN)
|
||||
try:
|
||||
if enable_ttl:
|
||||
schedule_instance_termination(self.region, new_instance_id, ttl_seconds, AWS_SCHEDULER_ROLE_ARN, logger)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create EventBridge Scheduler for {new_instance_id}: {e}")
|
||||
|
||||
# Schedule cloud-side termination via EventBridge Scheduler (same as allocation path)
|
||||
try:
|
||||
if enable_ttl and os.getenv('AWS_SCHEDULER_ROLE_ARN'):
|
||||
scheduler_client = boto3.client('scheduler', region_name=self.region)
|
||||
schedule_name = f"osworld-ttl-{new_instance_id}-{int(time.time())}"
|
||||
eta_scheduler = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
|
||||
schedule_expression = f"at({eta_scheduler.strftime('%Y-%m-%dT%H:%M:%S')})"
|
||||
target_arn = "arn:aws:scheduler:::aws-sdk:ec2:terminateInstances"
|
||||
input_payload = '{"InstanceIds":["' + new_instance_id + '"]}'
|
||||
scheduler_client.create_schedule(
|
||||
Name=schedule_name,
|
||||
ScheduleExpression=schedule_expression,
|
||||
FlexibleTimeWindow={"Mode": "OFF"},
|
||||
Target={
|
||||
"Arn": target_arn,
|
||||
"RoleArn": os.getenv('AWS_SCHEDULER_ROLE_ARN'),
|
||||
"Input": input_payload
|
||||
},
|
||||
State='ENABLED',
|
||||
Description=f"OSWorld TTL terminate for {new_instance_id}"
|
||||
)
|
||||
logger.info(f"Scheduled EC2 termination via EventBridge Scheduler for snapshot revert: name={schedule_name}, when={eta_scheduler.isoformat()} (UTC)")
|
||||
else:
|
||||
logger.info("TTL enabled but AWS_SCHEDULER_ROLE_ARN not set; skipping scheduler for snapshot revert.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create EventBridge Scheduler for {new_instance_id}: {e}")
|
||||
|
||||
|
||||
try:
|
||||
instance_details = ec2_client.describe_instances(InstanceIds=[new_instance_id])
|
||||
instance = instance_details['Reservations'][0]['Instances'][0]
|
||||
|
|
|
|||
|
|
@ -1,153 +0,0 @@
|
|||
import os
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
def _resolve_scheduler_role_arn(logger) -> str:
|
||||
# 1) Explicit env takes precedence
|
||||
role_arn = os.getenv('AWS_SCHEDULER_ROLE_ARN', '').strip()
|
||||
if role_arn:
|
||||
return role_arn
|
||||
|
||||
# 2) Derive from role name + account id
|
||||
role_name = os.getenv('AWS_SCHEDULER_ROLE_NAME', 'osworld-scheduler-ec2-terminate').strip()
|
||||
try:
|
||||
sts = boto3.client('sts')
|
||||
account_id = sts.get_caller_identity()['Account']
|
||||
derived_arn = f"arn:aws:iam::{account_id}:role/{role_name}"
|
||||
iam = boto3.client('iam')
|
||||
try:
|
||||
role = iam.get_role(RoleName=role_name)["Role"]
|
||||
except ClientError:
|
||||
auto_create = os.getenv('AWS_AUTO_CREATE_SCHEDULER_ROLE', 'true').lower() == 'true'
|
||||
if not auto_create:
|
||||
logger.warning(f"Scheduler role '{role_name}' not found and auto-create disabled.")
|
||||
return ''
|
||||
try:
|
||||
trust_policy = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"Service": "scheduler.amazonaws.com"},
|
||||
"Action": "sts:AssumeRole"
|
||||
}
|
||||
]
|
||||
}
|
||||
iam.create_role(RoleName=role_name, AssumeRolePolicyDocument=json.dumps(trust_policy))
|
||||
role = iam.get_role(RoleName=role_name)["Role"]
|
||||
except ClientError as ce:
|
||||
# If another process created it, fetch again
|
||||
try:
|
||||
role = iam.get_role(RoleName=role_name)["Role"]
|
||||
except ClientError:
|
||||
logger.warning(f"Failed to auto-create scheduler role '{role_name}': {ce}")
|
||||
return ''
|
||||
|
||||
# Ensure trust policy allows scheduler.amazonaws.com
|
||||
assume_doc = role.get("AssumeRolePolicyDocument", {})
|
||||
principal_ok = False
|
||||
try:
|
||||
for stmt in assume_doc.get("Statement", []):
|
||||
principal = stmt.get("Principal", {})
|
||||
svc = principal.get("Service")
|
||||
if isinstance(svc, str) and svc == "scheduler.amazonaws.com":
|
||||
principal_ok = True
|
||||
break
|
||||
if isinstance(svc, list) and "scheduler.amazonaws.com" in svc:
|
||||
principal_ok = True
|
||||
break
|
||||
except Exception:
|
||||
principal_ok = False
|
||||
if not principal_ok:
|
||||
trust_policy = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"Service": "scheduler.amazonaws.com"},
|
||||
"Action": "sts:AssumeRole"
|
||||
}
|
||||
]
|
||||
}
|
||||
iam.update_assume_role_policy(RoleName=role_name, PolicyDocument=json.dumps(trust_policy))
|
||||
|
||||
# Ensure minimal inline policy exists
|
||||
inline_name = f"{role_name}-inline"
|
||||
need_policy = False
|
||||
try:
|
||||
iam.get_role_policy(RoleName=role_name, PolicyName=inline_name)
|
||||
except ClientError:
|
||||
need_policy = True
|
||||
if need_policy:
|
||||
inline_policy = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": ["ec2:TerminateInstances", "ec2:DescribeInstances"],
|
||||
"Resource": "*"
|
||||
}
|
||||
]
|
||||
}
|
||||
iam.put_role_policy(RoleName=role_name, PolicyName=inline_name, PolicyDocument=json.dumps(inline_policy))
|
||||
|
||||
# Wait for IAM propagation
|
||||
time.sleep(8)
|
||||
logger.info(f"Derived AWS_SCHEDULER_ROLE_ARN={derived_arn} from role name '{role_name}'")
|
||||
return derived_arn
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resolve Scheduler Role ARN: {e}")
|
||||
return ''
|
||||
|
||||
|
||||
def schedule_instance_termination(region: str, instance_id: str, ttl_seconds: int, role_arn: str, logger) -> None:
|
||||
if not role_arn:
|
||||
role_arn = _resolve_scheduler_role_arn(logger)
|
||||
if not role_arn:
|
||||
logger.info("Scheduler role ARN not available; skipping TTL schedule creation.")
|
||||
return
|
||||
scheduler_client = boto3.client('scheduler', region_name=region)
|
||||
schedule_name = f"osworld-ttl-{instance_id}-{int(time.time())}"
|
||||
eta_scheduler = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
|
||||
schedule_expression = f"at({eta_scheduler.strftime('%Y-%m-%dT%H:%M:%S')})"
|
||||
target_arn = "arn:aws:scheduler:::aws-sdk:ec2:terminateInstances"
|
||||
input_payload = '{"InstanceIds":["' + instance_id + '"]}'
|
||||
|
||||
# Retry to tolerate IAM eventual consistency
|
||||
last_err = None
|
||||
for attempt in range(1, 7): # ~ up to ~60s
|
||||
try:
|
||||
scheduler_client.create_schedule(
|
||||
Name=schedule_name,
|
||||
ScheduleExpression=schedule_expression,
|
||||
FlexibleTimeWindow={"Mode": "OFF"},
|
||||
ActionAfterCompletion='DELETE',
|
||||
Target={
|
||||
"Arn": target_arn,
|
||||
"RoleArn": role_arn,
|
||||
"Input": input_payload
|
||||
},
|
||||
State='ENABLED',
|
||||
Description=f"OSWorld TTL terminate for {instance_id}"
|
||||
)
|
||||
logger.info(f"Scheduled EC2 termination via EventBridge Scheduler: name={schedule_name}, when={eta_scheduler.isoformat()} (UTC)")
|
||||
last_err = None
|
||||
break
|
||||
except ClientError as e:
|
||||
last_err = e
|
||||
code = e.response.get('Error', {}).get('Code')
|
||||
msg = e.response.get('Error', {}).get('Message', '')
|
||||
if code == 'ValidationException' and 'must allow AWS EventBridge Scheduler to assume the role' in msg:
|
||||
time.sleep(10)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
if last_err is not None:
|
||||
# If we exhausted retries, re-raise to surface warning upstream
|
||||
raise last_err
|
||||
|
||||
|
||||
|
|
@ -29,11 +29,13 @@ UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve
|
|||
WINDOWS_X86_URL = "https://huggingface.co/datasets/xlangai/windows_osworld/resolve/main/Windows-x86.zip"
|
||||
|
||||
# Determine the platform and CPU architecture to decide the correct VM image to download
|
||||
# sometimes the system is 'Darwin' but the machine is x86-based.
|
||||
if platform.machine().lower() in ['amd64', 'x86_64']:
|
||||
URL = UBUNTU_X86_URL
|
||||
elif platform.system() == 'Darwin': # macOS
|
||||
if platform.system() == 'Darwin': # macOS
|
||||
# if os.uname().machine == 'arm64': # Apple Silicon
|
||||
URL = UBUNTU_ARM_URL
|
||||
# else:
|
||||
# url = UBUNTU_X86_URL
|
||||
elif platform.machine().lower() in ['amd64', 'x86_64']:
|
||||
URL = UBUNTU_X86_URL
|
||||
else:
|
||||
raise Exception("Unsupported platform or architecture")
|
||||
|
||||
|
|
@ -123,12 +125,12 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, os_type, original_vm_nam
|
|||
# Download the virtual machine image
|
||||
logger.info("Downloading the virtual machine image...")
|
||||
downloaded_size = 0
|
||||
# sometimes the system is 'Darwin' but the machine is x86-based.
|
||||
|
||||
if os_type == "Ubuntu":
|
||||
if platform.machine().lower() in ['amd64', 'x86_64']:
|
||||
URL = UBUNTU_X86_URL
|
||||
elif platform.system() == 'Darwin':
|
||||
if platform.system() == 'Darwin':
|
||||
URL = UBUNTU_ARM_URL
|
||||
elif platform.machine().lower() in ['amd64', 'x86_64']:
|
||||
URL = UBUNTU_X86_URL
|
||||
elif os_type == "Windows":
|
||||
if platform.machine().lower() in ['amd64', 'x86_64']:
|
||||
URL = WINDOWS_X86_URL
|
||||
|
|
|
|||
|
|
@ -1,72 +0,0 @@
|
|||
# 火山引擎ECS提供商配置指南
|
||||
|
||||
本指南介绍如何为OSWorld桌面环境配置和使用火山引擎ECS。
|
||||
|
||||
## 配置流程
|
||||
|
||||
1. **火山引擎账户**:您需要一个有效的火山引擎账户,本脚本默认ECS通过按量付费方式拉起,需保证账户余额在100以上。
|
||||
2. **访问密钥**:在火山引擎IAM控制台中创建AccessKey ID和SecretAccessKey,并授权ECS控制权限
|
||||
3. **VPC设置**:在目标地域创建VPC、子网和安全组
|
||||
4. **自定义镜像**:创建OSWorld自定义镜像
|
||||
5. 建议手动完成一次ECS创建流程后,记录所有需要的环境变量信息。
|
||||
|
||||
## 环境变量
|
||||
|
||||
在您的`.env`文件中设置以下环境变量:
|
||||
|
||||
```bash
|
||||
# 火山引擎访问凭证
|
||||
VOLCENGINE_ACCESS_KEY_ID=your_access_key_id
|
||||
VOLCENGINE_SECRET_ACCESS_KEY=your_secret_access_key
|
||||
|
||||
# ECS配置信息
|
||||
VOLCENGINE_REGION=ap-southeast-1
|
||||
VOLCENGINE_IMAGE_ID=image-xxxxxxxxx
|
||||
VOLCENGINE_INSTANCE_TYPE=ecs.e-c1m2.large
|
||||
VOLCENGINE_SUBNET_ID=subnet-xxxxxxxxx
|
||||
VOLCENGINE_SECURITY_GROUP_ID=sg-xxxxxxxxx
|
||||
VOLCENGINE_ZONE_ID=zone-xxxxxxxxx
|
||||
VOLCENGINE_DEFAULT_PASSWORD=your_default_password
|
||||
```
|
||||
|
||||
## 所需火山引擎资源
|
||||
|
||||
### 1. VPC和子网
|
||||
- 在目标地域创建VPC
|
||||
- 在VPC内创建子网
|
||||
- 确保子网具有互联网访问能力以支持VNC连接
|
||||
|
||||
### 2. 安全组
|
||||
**⚠️ 重要提示**:请严格按照以下端口设置,以防止OSWorld任务因连接问题而失败:
|
||||
|
||||
#### 入方向规则(需要8条规则)
|
||||
|
||||
| 类型 | 协议 | 端口范围 | 源地址 | 描述 |
|
||||
|------|------|----------|--------|------|
|
||||
| SSH | TCP | 22 | 0.0.0.0/0 | SSH访问 |
|
||||
| HTTP | TCP | 80 | 172.31.0.0/16 | HTTP流量 |
|
||||
| 自定义TCP | TCP | 5000 | 172.31.0.0/16 | OSWorld后端服务 |
|
||||
| 自定义TCP | TCP | 5910 | 0.0.0.0/0 | NoVNC可视化端口 |
|
||||
| 自定义TCP | TCP | 8006 | 172.31.0.0/16 | VNC服务端口 |
|
||||
| 自定义TCP | TCP | 8080 | 172.31.0.0/16 | VLC服务端口 |
|
||||
| 自定义TCP | TCP | 8081 | 172.31.0.0/16 | 附加服务端口 |
|
||||
| 自定义TCP | TCP | 9222 | 172.31.0.0/16 | Chrome控制端口 |
|
||||
|
||||
#### 出方向规则(需要1条规则)
|
||||
|
||||
| 类型 | 协议 | 端口范围 | 目标地址 | 描述 |
|
||||
|------|------|----------|----------|------|
|
||||
| 全部流量 | 全部 | 全部 | 0.0.0.0/0 | 允许所有出站流量 |
|
||||
|
||||
### 3. 自定义镜像
|
||||
您需要为火山引擎ECS创建自定义OSWorld镜像。请按照"为OSWorld创建自定义ECS镜像"部分的说明进行操作。
|
||||
|
||||
## 为OSWorld创建自定义ECS镜像
|
||||
|
||||
本部分提供如何创建OSWorld桌面环境所需的自定义ECS镜像的指导。该过程包括设置带有桌面环境和VNC服务器的基础实例,然后从中创建自定义镜像。
|
||||
|
||||
### 镜像创建过程
|
||||
- 从`desktop_env/providers/docker/manager.py`中的链接下载提供的qcow2镜像:https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu.qcow2.zip
|
||||
- 解压下载的文件并上传到火山引擎对象存储服务(TOS)。确保TOS与您要启动ECS实例的目标地域在同一地域。
|
||||
- 在您的ECS控制台中,转到"镜像"页面,您将看到"导入镜像"按钮。点击它并按照说明从TOS导入qcow2镜像。
|
||||
- 导入完成后,您将在"镜像"列表中看到导入的镜像。
|
||||
|
|
@ -1,221 +0,0 @@
|
|||
import os
|
||||
import logging
|
||||
import signal
|
||||
import dotenv
|
||||
import time
|
||||
import volcenginesdkcore
|
||||
import volcenginesdkecs.models as ecs_models
|
||||
from volcenginesdkecs.api import ECSApi
|
||||
|
||||
from desktop_env.providers.base import VMManager
|
||||
|
||||
# Load environment variables from .env file
|
||||
dotenv.load_dotenv()
|
||||
|
||||
for env_name in [
|
||||
"VOLCENGINE_ACCESS_KEY_ID",
|
||||
"VOLCENGINE_SECRET_ACCESS_KEY",
|
||||
"VOLCENGINE_REGION",
|
||||
"VOLCENGINE_SUBNET_ID",
|
||||
"VOLCENGINE_SECURITY_GROUP_ID",
|
||||
"VOLCENGINE_INSTANCE_TYPE",
|
||||
"VOLCENGINE_IMAGE_ID",
|
||||
"VOLCENGINE_ZONE_ID",
|
||||
"VOLCENGINE_DEFAULT_PASSWORD",
|
||||
]:
|
||||
if not os.getenv(env_name):
|
||||
raise EnvironmentError(f"{env_name} must be set in the environment variables.")
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.volcengine.VolcengineVMManager")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
VOLCENGINE_ACCESS_KEY_ID = os.getenv("VOLCENGINE_ACCESS_KEY_ID")
|
||||
VOLCENGINE_SECRET_ACCESS_KEY = os.getenv("VOLCENGINE_SECRET_ACCESS_KEY")
|
||||
VOLCENGINE_REGION = os.getenv("VOLCENGINE_REGION")
|
||||
VOLCENGINE_SUBNET_ID = os.getenv("VOLCENGINE_SUBNET_ID")
|
||||
VOLCENGINE_SECURITY_GROUP_ID = os.getenv("VOLCENGINE_SECURITY_GROUP_ID")
|
||||
VOLCENGINE_INSTANCE_TYPE = os.getenv("VOLCENGINE_INSTANCE_TYPE")
|
||||
VOLCENGINE_IMAGE_ID = os.getenv("VOLCENGINE_IMAGE_ID")
|
||||
VOLCENGINE_ZONE_ID = os.getenv("VOLCENGINE_ZONE_ID")
|
||||
VOLCENGINE_DEFAULT_PASSWORD = os.getenv("VOLCENGINE_DEFAULT_PASSWORD")
|
||||
|
||||
def _allocate_vm(screen_size=(1920, 1080)):
|
||||
"""分配火山引擎虚拟机"""
|
||||
|
||||
# 初始化火山引擎客户端
|
||||
configuration = volcenginesdkcore.Configuration()
|
||||
configuration.region = VOLCENGINE_REGION
|
||||
configuration.ak = VOLCENGINE_ACCESS_KEY_ID
|
||||
configuration.sk = VOLCENGINE_SECRET_ACCESS_KEY
|
||||
configuration.client_side_validation = True
|
||||
# set default configuration
|
||||
volcenginesdkcore.Configuration.set_default(configuration)
|
||||
|
||||
# use global default configuration
|
||||
api_instance = ECSApi()
|
||||
|
||||
instance_id = None
|
||||
original_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||
original_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
if instance_id:
|
||||
signal_name = "SIGINT" if sig == signal.SIGINT else "SIGTERM"
|
||||
logger.warning(f"Received {signal_name} signal, terminating instance {instance_id}...")
|
||||
try:
|
||||
api_instance.delete_instance(ecs_models.DeleteInstanceRequest(
|
||||
instance_id=instance_id,
|
||||
))
|
||||
logger.info(f"Successfully terminated instance {instance_id} after {signal_name}.")
|
||||
except Exception as cleanup_error:
|
||||
logger.error(f"Failed to terminate instance {instance_id} after {signal_name}: {str(cleanup_error)}")
|
||||
|
||||
# Restore original signal handlers
|
||||
signal.signal(signal.SIGINT, original_sigint_handler)
|
||||
signal.signal(signal.SIGTERM, original_sigterm_handler)
|
||||
|
||||
if sig == signal.SIGINT:
|
||||
raise KeyboardInterrupt
|
||||
else:
|
||||
import sys
|
||||
sys.exit(0)
|
||||
|
||||
try:
|
||||
# Set up signal handlers
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# 创建实例参数
|
||||
create_instance_params = ecs_models.RunInstancesRequest(
|
||||
image_id = VOLCENGINE_IMAGE_ID,
|
||||
instance_type = VOLCENGINE_INSTANCE_TYPE,
|
||||
network_interfaces=[ecs_models.NetworkInterfaceForRunInstancesInput(
|
||||
subnet_id=VOLCENGINE_SUBNET_ID,
|
||||
security_group_ids=[VOLCENGINE_SECURITY_GROUP_ID],
|
||||
)],
|
||||
eip_address=ecs_models.EipAddressForRunInstancesInput(
|
||||
bandwidth_mbps = 5,
|
||||
charge_type = "PayByTraffic",
|
||||
),
|
||||
instance_name = f"osworld-{os.getpid()}-{int(time.time())}",
|
||||
volumes=[ecs_models.VolumeForRunInstancesInput(
|
||||
volume_type="ESSD_PL0",
|
||||
size=30,
|
||||
)],
|
||||
zone_id=VOLCENGINE_ZONE_ID,
|
||||
password = VOLCENGINE_DEFAULT_PASSWORD, # 默认密码
|
||||
description = "OSWorld evaluation instance"
|
||||
)
|
||||
|
||||
# 创建实例
|
||||
response = api_instance.run_instances(create_instance_params)
|
||||
instance_id = response.instance_ids[0]
|
||||
|
||||
logger.info(f"Waiting for instance {instance_id} to be running...")
|
||||
|
||||
# 等待实例运行
|
||||
while True:
|
||||
instance_info = api_instance.describe_instances(ecs_models.DescribeInstancesRequest(
|
||||
instance_ids=[instance_id]
|
||||
))
|
||||
status = instance_info.instances[0].status
|
||||
if status == 'RUNNING':
|
||||
break
|
||||
elif status in ['STOPPED', 'ERROR']:
|
||||
raise Exception(f"Instance {instance_id} failed to start, status: {status}")
|
||||
time.sleep(5)
|
||||
|
||||
logger.info(f"Instance {instance_id} is ready.")
|
||||
|
||||
# 获取实例IP地址
|
||||
try:
|
||||
instance_info = api_instance.describe_instances(ecs_models.DescribeInstancesRequest(
|
||||
instance_ids=[instance_id]
|
||||
))
|
||||
print(instance_info)
|
||||
public_ip = instance_info.instances[0].eip_address.ip_address
|
||||
private_ip = instance_info.instances[0].network_interfaces[0].primary_ip_address
|
||||
|
||||
if public_ip:
|
||||
vnc_url = f"http://{public_ip}:5910/vnc.html"
|
||||
logger.info("="*80)
|
||||
logger.info(f"🖥️ VNC Web Access URL: {vnc_url}")
|
||||
logger.info(f"📡 Public IP: {public_ip}")
|
||||
logger.info(f"🏠 Private IP: {private_ip}")
|
||||
logger.info(f"🆔 Instance ID: {instance_id}")
|
||||
logger.info("="*80)
|
||||
print(f"\n🌐 VNC Web Access URL: {vnc_url}")
|
||||
print(f"📍 Please open the above address in the browser for remote desktop access\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get VNC address for instance {instance_id}: {e}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("VM allocation interrupted by user (SIGINT).")
|
||||
if instance_id:
|
||||
logger.info(f"Terminating instance {instance_id} due to interruption.")
|
||||
api_instance.delete_instance(ecs_models.DeleteInstanceRequest(
|
||||
instance_id=instance_id,
|
||||
))
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to allocate VM: {e}", exc_info=True)
|
||||
if instance_id:
|
||||
logger.info(f"Terminating instance {instance_id} due to an error.")
|
||||
api_instance.delete_instance(ecs_models.DeleteInstanceRequest(
|
||||
instance_id=instance_id,
|
||||
))
|
||||
raise
|
||||
finally:
|
||||
# Restore original signal handlers
|
||||
signal.signal(signal.SIGINT, original_sigint_handler)
|
||||
signal.signal(signal.SIGTERM, original_sigterm_handler)
|
||||
|
||||
return instance_id
|
||||
|
||||
|
||||
class VolcengineVMManager(VMManager):
|
||||
"""
|
||||
Volcengine VM Manager for managing virtual machines on Volcengine.
|
||||
|
||||
Volcengine does not need to maintain a registry of VMs, as it can dynamically allocate and deallocate VMs.
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
self.initialize_registry()
|
||||
|
||||
def initialize_registry(self, **kwargs):
|
||||
pass
|
||||
|
||||
def add_vm(self, vm_path, lock_needed=True, **kwargs):
|
||||
pass
|
||||
|
||||
def _add_vm(self, vm_path):
|
||||
pass
|
||||
|
||||
def delete_vm(self, vm_path, lock_needed=True, **kwargs):
|
||||
pass
|
||||
|
||||
def _delete_vm(self, vm_path):
|
||||
pass
|
||||
|
||||
def occupy_vm(self, vm_path, pid, lock_needed=True, **kwargs):
|
||||
pass
|
||||
|
||||
def _occupy_vm(self, vm_path, pid):
|
||||
pass
|
||||
|
||||
def check_and_clean(self, lock_needed=True, **kwargs):
|
||||
pass
|
||||
|
||||
def _check_and_clean(self):
|
||||
pass
|
||||
|
||||
def list_free_vms(self, lock_needed=True, **kwargs):
|
||||
pass
|
||||
|
||||
def _list_free_vms(self):
|
||||
pass
|
||||
|
||||
def get_vm_path(self, screen_size=(1920, 1080), **kwargs):
|
||||
logger.info("Allocating a new VM in region: {region}".format(region=VOLCENGINE_REGION))
|
||||
new_vm_path = _allocate_vm(screen_size=screen_size)
|
||||
return new_vm_path
|
||||
|
|
@ -1,188 +0,0 @@
|
|||
import os
|
||||
import time
|
||||
import logging
|
||||
import volcenginesdkcore
|
||||
import volcenginesdkautoscaling
|
||||
import volcenginesdkecs.models as ecs_models
|
||||
from volcenginesdkcore.rest import ApiException
|
||||
from volcenginesdkecs.api import ECSApi
|
||||
|
||||
from desktop_env.providers.base import Provider
|
||||
from desktop_env.providers.volcengine.manager import _allocate_vm
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.volcengine.VolcengineProvider")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
WAIT_DELAY = 15
|
||||
MAX_ATTEMPTS = 10
|
||||
|
||||
|
||||
class VolcengineProvider(Provider):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.region = os.getenv("VOLCENGINE_REGION", "eu-central-1")
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self) -> ECSApi:
|
||||
configuration = volcenginesdkcore.Configuration()
|
||||
configuration.ak = os.getenv('VOLCENGINE_ACCESS_KEY_ID')
|
||||
configuration.sk = os.getenv('VOLCENGINE_SECRET_ACCESS_KEY')
|
||||
configuration.region = os.getenv('VOLCENGINE_REGION')
|
||||
configuration.client_side_validation = True
|
||||
# set default configuration
|
||||
volcenginesdkcore.Configuration.set_default(configuration)
|
||||
return ECSApi()
|
||||
|
||||
def start_emulator(self, path_to_vm: str, headless: bool, *args, **kwargs):
|
||||
logger.info("Starting Volcengine VM...")
|
||||
|
||||
try:
|
||||
# 检查实例状态
|
||||
instance_info = self.client.describe_instances(ecs_models.DescribeInstancesRequest(
|
||||
instance_ids=[path_to_vm]
|
||||
))
|
||||
status = instance_info.instances[0].status
|
||||
logger.info(f"Instance {path_to_vm} current status: {status}")
|
||||
|
||||
if status == 'RUNNING':
|
||||
logger.info(f"Instance {path_to_vm} is already running. Skipping start.")
|
||||
return
|
||||
|
||||
if status == 'STOPPED':
|
||||
# 启动实例
|
||||
self.client.start_instance(ecs_models.StartInstancesRequest(instance_ids=[path_to_vm]))
|
||||
logger.info(f"Instance {path_to_vm} is starting...")
|
||||
|
||||
# 等待实例运行
|
||||
for attempt in range(MAX_ATTEMPTS):
|
||||
time.sleep(WAIT_DELAY)
|
||||
instance_info = self.client.describe_instances(ecs_models.DescribeInstancesRequest(
|
||||
instance_ids=[path_to_vm]
|
||||
))
|
||||
status = instance_info.instances[0].status
|
||||
|
||||
if status == 'RUNNING':
|
||||
logger.info(f"Instance {path_to_vm} is now running.")
|
||||
break
|
||||
elif status == 'ERROR':
|
||||
raise Exception(f"Instance {path_to_vm} failed to start")
|
||||
elif attempt == MAX_ATTEMPTS - 1:
|
||||
raise Exception(f"Instance {path_to_vm} failed to start within timeout")
|
||||
else:
|
||||
logger.warning(f"Instance {path_to_vm} is in status '{status}' and cannot be started.")
|
||||
|
||||
except ApiException as e:
|
||||
logger.error(f"Failed to start the Volcengine VM {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_ip_address(self, path_to_vm: str) -> str:
|
||||
logger.info("Getting Volcengine VM IP address...")
|
||||
|
||||
try:
|
||||
instance_info = self.client.describe_instances(ecs_models.DescribeInstancesRequest(
|
||||
instance_ids=[path_to_vm]
|
||||
))
|
||||
|
||||
public_ip = instance_info.instances[0].eip_address.ip_address
|
||||
private_ip = instance_info.instances[0].network_interfaces[0].primary_ip_address
|
||||
|
||||
if public_ip:
|
||||
vnc_url = f"http://{public_ip}:5910/vnc.html"
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"🖥️ VNC Web Access URL: {vnc_url}")
|
||||
logger.info(f"📡 Public IP: {public_ip}")
|
||||
logger.info(f"🏠 Private IP: {private_ip}")
|
||||
logger.info("=" * 80)
|
||||
print(f"\n🌐 VNC Web Access URL: {vnc_url}")
|
||||
print(f"📍 Please open the above address in the browser for remote desktop access\n")
|
||||
else:
|
||||
logger.warning("No public IP address available for VNC access")
|
||||
|
||||
return private_ip
|
||||
|
||||
except ApiException as e:
|
||||
logger.error(f"Failed to retrieve IP address for the instance {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
|
||||
def save_state(self, path_to_vm: str, snapshot_name: str):
|
||||
logger.info("Saving Volcengine VM state...")
|
||||
|
||||
try:
|
||||
# 创建镜像
|
||||
response = self.client.create_image(ecs_models.CreateImageRequest(
|
||||
snapshot_id=snapshot_name,
|
||||
instance_id=path_to_vm,
|
||||
description=f"OSWorld snapshot: {snapshot_name}"
|
||||
))
|
||||
image_id = response['image_id']
|
||||
logger.info(f"Image {image_id} created successfully from instance {path_to_vm}.")
|
||||
return image_id
|
||||
except ApiException as e:
|
||||
logger.error(f"Failed to create image from the instance {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
|
||||
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
|
||||
logger.info(f"Reverting Volcengine VM to snapshot: {snapshot_name}...")
|
||||
|
||||
try:
|
||||
# 删除原实例
|
||||
self.client.delete_instance(ecs_models.DeleteInstanceRequest(
|
||||
instance_id=path_to_vm,
|
||||
))
|
||||
logger.info(f"Old instance {path_to_vm} has been deleted.")
|
||||
|
||||
# 创建实例
|
||||
new_instance_id = _allocate_vm()
|
||||
|
||||
logger.info(f"New instance {new_instance_id} launched from image {snapshot_name}.")
|
||||
logger.info(f"Waiting for instance {new_instance_id} to be running...")
|
||||
|
||||
# 等待新实例运行
|
||||
while True:
|
||||
instance_info = self.client.describe_instances(ecs_models.DescribeInstancesRequest(
|
||||
instance_ids=[new_instance_id]
|
||||
))
|
||||
status = instance_info.instances[0].status
|
||||
if status == 'RUNNING':
|
||||
break
|
||||
elif status in ['STOPPED', 'ERROR']:
|
||||
raise Exception(f"New instance {new_instance_id} failed to start, status: {status}")
|
||||
time.sleep(5)
|
||||
|
||||
logger.info(f"Instance {new_instance_id} is ready.")
|
||||
|
||||
# 获取新实例的IP地址
|
||||
try:
|
||||
instance_info = self.client.describe_instances(ecs_models.DescribeInstancesRequest(
|
||||
instance_ids=[new_instance_id]
|
||||
))
|
||||
public_ip = instance_info.instances[0].eip_address.ip_address
|
||||
if public_ip:
|
||||
vnc_url = f"http://{public_ip}:5910/vnc.html"
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"🖥️ New Instance VNC Web Access URL: {vnc_url}")
|
||||
logger.info(f"📡 Public IP: {public_ip}")
|
||||
logger.info(f"🆔 New Instance ID: {new_instance_id}")
|
||||
logger.info("=" * 80)
|
||||
print(f"\n🌐 New Instance VNC Web Access URL: {vnc_url}")
|
||||
print(f"📍 Please open the above address in the browser for remote desktop access\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get VNC address for new instance {new_instance_id}: {e}")
|
||||
|
||||
return new_instance_id
|
||||
|
||||
except ApiException as e:
|
||||
logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
|
||||
def stop_emulator(self, path_to_vm, region=None):
|
||||
logger.info(f"Stopping Volcengine VM {path_to_vm}...")
|
||||
|
||||
try:
|
||||
self.client.delete_instance(ecs_models.DeleteInstanceRequest(
|
||||
instance_id=path_to_vm,
|
||||
))
|
||||
logger.info(f"Instance {path_to_vm} has been terminated.")
|
||||
except ApiException as e:
|
||||
logger.error(f"Failed to stop the Volcengine VM {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
|
|
@ -1370,7 +1370,7 @@ def open_file():
|
|||
if window_found:
|
||||
return "File opened and window activated successfully"
|
||||
else:
|
||||
return f"Failed to find window for {file_name} within {TIMEOUT} seconds.", 500
|
||||
return f"Failed to find window for {file_name} within {timeout} seconds.", 500
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to open {path}. Error: {e}", 500
|
||||
|
|
@ -1568,230 +1568,5 @@ def end_recording():
|
|||
return abort(500, description=f"Recording failed. The output file is missing or empty. ffmpeg stderr: {error_output}")
|
||||
|
||||
|
||||
@app.route("/run_python", methods=['POST'])
|
||||
def run_python():
|
||||
data = request.json
|
||||
code = data.get('code', None)
|
||||
|
||||
if not code:
|
||||
return jsonify({'status': 'error', 'message': 'Code not supplied!'}), 400
|
||||
|
||||
# Create a temporary file to save the Python code
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
# Generate unique filename
|
||||
temp_filename = f"/tmp/python_exec_{uuid.uuid4().hex}.py"
|
||||
|
||||
try:
|
||||
# Write code to temporary file
|
||||
with open(temp_filename, 'w') as f:
|
||||
f.write(code)
|
||||
|
||||
# Execute the file using subprocess to capture all output
|
||||
result = subprocess.run(
|
||||
['/usr/bin/python3', temp_filename],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
timeout=30 # 30 second timeout
|
||||
)
|
||||
|
||||
# Clean up the temporary file
|
||||
try:
|
||||
os.remove(temp_filename)
|
||||
except:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
# Prepare response
|
||||
output = result.stdout
|
||||
error_output = result.stderr
|
||||
|
||||
# Combine output and errors if both exist
|
||||
combined_message = output
|
||||
if error_output:
|
||||
combined_message += ('\n' + error_output) if output else error_output
|
||||
|
||||
# Determine status based on return code and errors
|
||||
if result.returncode != 0:
|
||||
status = 'error'
|
||||
if not error_output:
|
||||
# If no stderr but non-zero return code, add a generic error message
|
||||
error_output = f"Process exited with code {result.returncode}"
|
||||
combined_message = combined_message + '\n' + error_output if combined_message else error_output
|
||||
else:
|
||||
status = 'success'
|
||||
|
||||
return jsonify({
|
||||
'status': status,
|
||||
'message': combined_message,
|
||||
'need_more': False, # Not applicable for file execution
|
||||
'output': output, # stdout only
|
||||
'error': error_output, # stderr only
|
||||
'return_code': result.returncode
|
||||
})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
# Clean up the temporary file on timeout
|
||||
try:
|
||||
os.remove(temp_filename)
|
||||
except:
|
||||
pass
|
||||
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'message': 'Execution timeout: Code took too long to execute',
|
||||
'error': 'TimeoutExpired',
|
||||
'need_more': False,
|
||||
'output': None,
|
||||
}), 500
|
||||
|
||||
except Exception as e:
|
||||
# Clean up the temporary file on error
|
||||
try:
|
||||
os.remove(temp_filename)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Capture the exception details
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'message': f'Execution error: {str(e)}',
|
||||
'error': traceback.format_exc(),
|
||||
'need_more': False,
|
||||
'output': None,
|
||||
}), 500
|
||||
|
||||
|
||||
@app.route("/run_bash_script", methods=['POST'])
|
||||
def run_bash_script():
|
||||
data = request.json
|
||||
script = data.get('script', None)
|
||||
timeout = data.get('timeout', 100) # Default timeout of 30 seconds
|
||||
working_dir = data.get('working_dir', None)
|
||||
|
||||
if not script:
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'output': 'Script not supplied!',
|
||||
'error': "", # Always empty as requested
|
||||
'returncode': -1
|
||||
}), 400
|
||||
|
||||
# Expand user directory if provided
|
||||
if working_dir:
|
||||
working_dir = os.path.expanduser(working_dir)
|
||||
if not os.path.exists(working_dir):
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'output': f'Working directory does not exist: {working_dir}',
|
||||
'error': "", # Always empty as requested
|
||||
'returncode': -1
|
||||
}), 400
|
||||
|
||||
# Create a temporary script file
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as tmp_file:
|
||||
if "#!/bin/bash" not in script:
|
||||
script = "#!/bin/bash\n\n" + script
|
||||
tmp_file.write(script)
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
try:
|
||||
# Make the script executable
|
||||
os.chmod(tmp_file_path, 0o755)
|
||||
|
||||
# Execute the script
|
||||
if platform_name == "Windows":
|
||||
# On Windows, use Git Bash or WSL if available, otherwise cmd
|
||||
flags = subprocess.CREATE_NO_WINDOW
|
||||
# Try to use bash if available (Git Bash, WSL, etc.)
|
||||
result = subprocess.run(
|
||||
['bash', tmp_file_path],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # Merge stderr into stdout
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
cwd=working_dir,
|
||||
creationflags=flags,
|
||||
shell=False
|
||||
)
|
||||
else:
|
||||
# On Unix-like systems, use bash directly
|
||||
flags = 0
|
||||
result = subprocess.run(
|
||||
['/bin/bash', tmp_file_path],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # Merge stderr into stdout
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
cwd=working_dir,
|
||||
creationflags=flags,
|
||||
shell=False
|
||||
)
|
||||
|
||||
# Log the command execution for trajectory recording
|
||||
_append_event("BashScript",
|
||||
{"script": script, "output": result.stdout, "error": "", "returncode": result.returncode},
|
||||
ts=time.time())
|
||||
|
||||
return jsonify({
|
||||
'status': 'success' if result.returncode == 0 else 'error',
|
||||
'output': result.stdout, # Contains both stdout and stderr merged
|
||||
'error': "", # Always empty as requested
|
||||
'returncode': result.returncode
|
||||
})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'output': f'Script execution timed out after {timeout} seconds',
|
||||
'error': "", # Always empty as requested
|
||||
'returncode': -1
|
||||
}), 500
|
||||
except FileNotFoundError:
|
||||
# Bash not found, try with sh
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['sh', tmp_file_path],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # Merge stderr into stdout
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
cwd=working_dir,
|
||||
shell=False
|
||||
)
|
||||
|
||||
_append_event("BashScript",
|
||||
{"script": script, "output": result.stdout, "error": "", "returncode": result.returncode},
|
||||
ts=time.time())
|
||||
|
||||
return jsonify({
|
||||
'status': 'success' if result.returncode == 0 else 'error',
|
||||
'output': result.stdout, # Contains both stdout and stderr merged
|
||||
'error': "", # Always empty as requested
|
||||
'returncode': result.returncode,
|
||||
})
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'output': f'Failed to execute script: {str(e)}',
|
||||
'error': "", # Always empty as requested
|
||||
'returncode': -1
|
||||
}), 500
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'output': f'Failed to execute script: {str(e)}',
|
||||
'error': "", # Always empty as requested
|
||||
'returncode': -1
|
||||
}), 500
|
||||
finally:
|
||||
# Clean up the temporary file
|
||||
try:
|
||||
os.unlink(tmp_file_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True, host="0.0.0.0")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "35253b65-1c19-4304-8aa4-6884b8218fc0",
|
||||
"snapshot": "chrome",
|
||||
"instruction": "Hey, I need a quick way back to this site. Could you whip up a shortcut on my desktop for me using Chrome's built-in feature?",
|
||||
"instruction": "Hey, I need a quick way back to this site. Could you whip up a shortcut on my desktop for me?",
|
||||
"source": "https://www.laptopmag.com/articles/how-to-create-desktop-shortcuts-for-web-pages-using-chrome",
|
||||
"config": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -4,20 +4,6 @@
|
|||
"instruction": "I want Chrome to warn me whenever I visit a potentially harmful or unsafe website. Can you enable this safety feature?",
|
||||
"source": "https://www.quora.com/How-do-I-set-the-security-settings-for-the-Google-Chrome-browser-for-the-best-security#:~:text=Enable%20Safe%20Browsing:%20Chrome%20has%20a%20built%2Din,Security%20%3E%20Security%20%3E%20Enable%20Safe%20Browsing.",
|
||||
"config": [
|
||||
{
|
||||
"type": "execute",
|
||||
"parameters": {
|
||||
"command": "echo {CLIENT_PASSWORD} | sudo -S apt update -y && echo {CLIENT_PASSWORD} | sudo -S apt install jq -y",
|
||||
"shell": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "execute",
|
||||
"parameters": {
|
||||
"command": "mkdir -p /home/user/.config/google-chrome/Default && if [ ! -f /home/user/.config/google-chrome/Default/Preferences ]; then echo '{}' > /home/user/.config/google-chrome/Default/Preferences; fi && cd /home/user/.config/google-chrome/Default && jq '. + {\"safebrowsing\":{\"enabled\":false,\"enhanced\":false}}' Preferences > temp && mv temp Preferences",
|
||||
"shell": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "launch",
|
||||
"parameters": {
|
||||
|
|
@ -71,7 +57,7 @@
|
|||
],
|
||||
"func": "exact_match",
|
||||
"result": {
|
||||
"type": "enable_safe_browsing"
|
||||
"type": "enable_enhanced_safety_browsing"
|
||||
},
|
||||
"expected": {
|
||||
"type": "rule",
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@
|
|||
"chrome"
|
||||
],
|
||||
"evaluator": {
|
||||
"func": "is_expected_url_pattern_match",
|
||||
"func": "is_expected_active_tab",
|
||||
"result": {
|
||||
"type": "active_url_from_accessTree",
|
||||
"goto_prefix": "https://www."
|
||||
|
|
@ -51,9 +51,8 @@
|
|||
"expected": {
|
||||
"type": "rule",
|
||||
"rules": {
|
||||
"expected": [
|
||||
"^https://(www\\.)?dmv\\.virginia\\.gov/licenses-ids/license/applying/eligibility"
|
||||
]
|
||||
"type": "url",
|
||||
"url": "https://www.dmv.virginia.gov/licenses-ids/license/applying/eligibility"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -44,35 +44,37 @@
|
|||
],
|
||||
"evaluator": {
|
||||
"func": [
|
||||
"is_expected_url_pattern_match",
|
||||
"is_expected_url_pattern_match"
|
||||
"exact_match",
|
||||
"exact_match"
|
||||
],
|
||||
"conj": "or",
|
||||
"result": [
|
||||
{
|
||||
"type": "active_url_from_accessTree",
|
||||
"goto_prefix": "https://www."
|
||||
"type": "url_dashPart",
|
||||
"goto_prefix": "https://www.",
|
||||
"partIndex": -1,
|
||||
"needDeleteId": false,
|
||||
"returnType": "string"
|
||||
},
|
||||
{
|
||||
"type": "active_url_from_accessTree",
|
||||
"goto_prefix": "https://www."
|
||||
"type": "url_dashPart",
|
||||
"goto_prefix": "https://www.",
|
||||
"partIndex": -1,
|
||||
"needDeleteId": false,
|
||||
"returnType": "string"
|
||||
}
|
||||
],
|
||||
"expected": [
|
||||
{
|
||||
"type": "rule",
|
||||
"rules": {
|
||||
"expected": [
|
||||
"^https://(www\\.)?drugs\\.com/tamiflu\\.html#side-effects"
|
||||
]
|
||||
"expected": "tamiflu.html#side-effects"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "rule",
|
||||
"rules": {
|
||||
"expected": [
|
||||
"^https://(www\\.)?drugs\\.com/sfx/tamiflu-side-effects\\.html"
|
||||
]
|
||||
"expected": "tamiflu-side-effects.html"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@
|
|||
"type": "rule",
|
||||
"rules": {
|
||||
"expected": [
|
||||
"united\\.com/en/us/checked-bag-fee-calculator(/.*)?"
|
||||
"united.com/en/us/checked-bag-fee-calculator"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,8 +82,8 @@
|
|||
],
|
||||
"func": "check_palette_and_structure_sim",
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/06ca5602-62ca-47f6-ad4f-da151cde54cc/computer.png",
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/Desktop/computer.png",
|
||||
"dest": "computer.png"
|
||||
},
|
||||
"result": {
|
||||
|
|
|
|||
|
|
@ -11,6 +11,10 @@
|
|||
{
|
||||
"url": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/2a729ded-3296-423d-aec4-7dd55ed5fbb3/dog_with_background.png",
|
||||
"path": "/home/user/Desktop/dog_with_background.png"
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/2a729ded-3296-423d-aec4-7dd55ed5fbb3/dog_cutout_gold.png",
|
||||
"path": "/home/user/Desktop/dog_cutout_gold.png"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -82,8 +86,8 @@
|
|||
],
|
||||
"func": "check_structure_sim",
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/2a729ded-3296-423d-aec4-7dd55ed5fbb3/dog_cutout_gold.png",
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/Desktop/dog_cutout_gold.png",
|
||||
"dest": "dog_cutout_gold.png"
|
||||
},
|
||||
"result": {
|
||||
|
|
|
|||
|
|
@ -82,8 +82,8 @@
|
|||
],
|
||||
"func": "check_saturation_increase_and_structure_sim",
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/554785e9-4523-4e7a-b8e1-8016f565f56a/woman_sitting_by_the_tree2.png",
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/Desktop/woman_sitting_by_the_tree2.png",
|
||||
"dest": "woman_sitting_by_the_tree2.png"
|
||||
},
|
||||
"result": {
|
||||
|
|
|
|||
|
|
@ -88,8 +88,8 @@
|
|||
],
|
||||
"func": "check_image_mirror",
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/72f83cdc-bf76-4531-9a1b-eb893a13f8aa/berry.jpeg",
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/Desktop/berry.png",
|
||||
"dest": "berry.png"
|
||||
},
|
||||
"result": {
|
||||
|
|
|
|||
|
|
@ -86,8 +86,8 @@
|
|||
],
|
||||
"func": "check_green_background",
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/734d6579-c07d-47a8-9ae2-13339795476b/white_background_with_object.png",
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/Desktop/white_background_with_object.png",
|
||||
"dest": "white_background_with_object.png"
|
||||
},
|
||||
"result": {
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@
|
|||
"evaluator": {
|
||||
"func": "check_file_exists_and_structure_sim",
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/77b8ab4d-994f-43ac-8930-8ca087d7c4b4/The_Lost_River_Of_Dreams.jpg",
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/Desktop/The_Lost_River_Of_Dreams.jpg",
|
||||
"dest": "The_Lost_River_Of_Dreams.jpg"
|
||||
},
|
||||
"result": {
|
||||
|
|
|
|||
|
|
@ -82,8 +82,8 @@
|
|||
],
|
||||
"func": "check_brightness_decrease_and_structure_sim",
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/7a4deb26-d57d-4ea9-9a73-630f66a7b568/woman_sitting_by_the_tree.png",
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/Desktop/woman_sitting_by_the_tree.png",
|
||||
"dest": "woman_sitting_by_the_tree.png"
|
||||
},
|
||||
"result": {
|
||||
|
|
|
|||
|
|
@ -104,8 +104,8 @@
|
|||
}
|
||||
},
|
||||
{
|
||||
"type": "cloud_file",
|
||||
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/d16c99dc-2a1e-46f2-b350-d97c86c85c15/dog_with_background.png",
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/Desktop/dog_with_background.png",
|
||||
"dest": "dog_with_background.png"
|
||||
}
|
||||
],
|
||||
|
|
|
|||
|
|
@ -88,8 +88,8 @@
|
|||
],
|
||||
"func": "check_contrast_increase_and_structure_sim",
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://huggingface.co/datasets/xlangai/ubuntu_osworld_file_cache/resolve/main/gimp/f723c744-e62c-4ae6-98d1-750d3cd7d79d/file_1X42_kOanL74vu_p6QdcZuiyzDQi3kA7F.jpg",
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/Desktop/berries.png",
|
||||
"dest": "berries.png"
|
||||
},
|
||||
"result": {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "01b269ae-2111-4a07-81fd-3fcd711993b0",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Fill all the blank cells in B1:E30 with the value in the cell above it. Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "Fill all the blank cells in B1:E30 with the value in the cell above it.",
|
||||
"source": "https://www.youtube.com/shorts/VrUzPTIwQ04",
|
||||
"config": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "0bf05a7d-b28b-44d2-955a-50b41e24012a",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "I would like to copy all the numbers in the 'Old ID' column to the 'New 7 Digit Id' column, and pad them with zeros in front, to fill them up to seven digits. Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "I would like to copy all the numbers in the 'Old ID' column to the 'New 7 Digit Id' column, and pad them with zeros in front, to fill them up to seven digits.",
|
||||
"source": "https://www.youtube.com/shorts/FPAQaDTS8VY",
|
||||
"config": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "1273e544-688f-496b-8d89-3e0f40aa0606",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Copy the \"Revenue\" column along with the header to a new sheet named \"Sheet2\".",
|
||||
"instruction": "Copy the \"Revenue\" column along with the header to a new sheet.",
|
||||
"source": "SheetCopilot@123",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -82,4 +82,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "21df9241-f8d7-4509-b7f1-37e501a823f7",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Change the representation of column \"Parameter\" to show in Millions (M) in Column B and Billions (B) in Column C. The numbers should be rounded to one decimal place, and half should be rounded up. Then remember to place a white space between the digits and the unit.",
|
||||
"instruction": "Change the representation of column \"Parameter\" to show in Millions (M) in Column B and Billions (B) in Column C. Keep one decimal and place a white space between the digits and the unit.",
|
||||
"source": "https://www.youtube.com/watch?v=p5C4V_AO1UU",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -116,4 +116,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "26a8440e-c166-4c50-aef4-bfb77314b46b",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Create a table with two column headers (\"Month\" and \"Total\") in a new sheet named \"Sheet2\" to show the total sales for all months.",
|
||||
"instruction": "Create a table with two column headers (\"Month\" and \"Total\") in a new sheet to show the total sales for all months.",
|
||||
"source": "SheetCopilot@152",
|
||||
"config": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "30e3e107-1cfb-46ee-a755-2cd080d7ba6a",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Please create a new sheet. Keep its sheet name as \"Sheet2\". Merge cells A1:C1 in the new sheet and write \"Demographic Profile\" with blue (#0000ff) fill and bold white text. Then I want to create three pivot tables to show the percentage of Sex, Civil Status, and Highest Educational Attainment. They should be stacked one by one in Sheet2, each separated with a blank line.",
|
||||
"instruction": "Please create a new sheet. Merge cells A1:C1 in the new sheet and write \"Demographic Profile\" with blue (#0000ff) fill and bold white text. Then I want to create three pivot tables to show the percentage of Sex, Civil Status , and Highest Educational Attainment. They should be stacked one by one in the new sheet, each separated with a blank line.",
|
||||
"source": "SheetCopilot@9",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -111,4 +111,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "I have calculated the total work hours from the everday hours. And I have an hourly rate. Now I want to multiply the total hours with the hourly rate to get a total earned amount. However, I can't get a correct answer by directly multiply the two cells. Here the \"total hours\" is of time and \"hourly rate\" is just a number. How can I get the correct product of them? Help me fill in the cell the correct answer. Don't touch irrelevant blank regions.",
|
||||
"instruction": "I have calculated the total work hours from the everday hours. And I have an hourly rate. Now I want to multiply the total hours with the hourly rate to get a total earned amount. However, I can't get a correct answer by directly multiply the two cells. Here the \"total hours\" is of time and \"hourly rate\" is just a number. How can I get the correct product of them?",
|
||||
"source": "https://www.reddit.com/r/excel/comments/17zny8u/calculating_total_amount_earned_from_total_hours/",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -83,4 +83,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "37608790-6147-45d0-9f20-1137bb35703d",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "The information are mixed in one field. Help me split them and fill in the columns of First Name, Last Name and Rank. Finish the work and don't touch the original data.",
|
||||
"instruction": "The information are mixed in one field. Help me split them and fill in the columns of First Name, Last Name and Rank",
|
||||
"source": "https://www.youtube.com/shorts/uzPo_CPCHH8",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -82,4 +82,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "4e6fcf72-daf3-439f-a232-c434ce416af6",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Please calculate the ages of the employees according to their birthday. Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "Please calculate the ages of the employees according to their birthday.",
|
||||
"source": "https://www.youtube.com/shorts/0uxJccNCKcE",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -134,4 +134,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "4f07fbe9-70de-4927-a4d5-bb28bc12c52c",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Here I want to use the numerical value from a cell in the text. I can set its number of decimal digits to 2 in the original value cell but don't know how to fix it in the text as well. Please help me to do this. Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "Here I want to use the numerical value from a cell in the text. I can set its number of decimal digits to 2 in the original value cell but don't know how to fix it in the text as well. Please help me to do this.",
|
||||
"source": "https://superuser.com/questions/1081048/libreoffice-calc-how-to-pad-number-to-fixed-decimals-when-used-within-formula",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -115,4 +115,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "6054afcb-5bab-4702-90a0-b259b5d3217c",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Some data are missed by now and are filled by 'N/A' temporarily. Please hide them in the table for now. Do not delete any cells and filter is not needed.",
|
||||
"instruction": "Some data are missed by now and are filled by 'N/A' temporarily. Please hide them in the table for now. Do not delete them and filter is no needed.",
|
||||
"source": "https://www.youtube.com/shorts/JTbZ8sRxkdU",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -90,4 +90,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "7a4e4bc8-922c-4c84-865c-25ba34136be1",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Reorder the columns to be \"Date\", \"First Name\", \"Last Name\", \"Order ID\", \"Sales\". Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "Reorder the columns to be \"Date\", \"First Name\", \"Last Name\", \"Order ID\", \"Sales\"",
|
||||
"source": "https://www.youtube.com/shorts/bvUhr1AHs44",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -82,4 +82,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "7efeb4b1-3d19-4762-b163-63328d66303b",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Fill the Sequence Numbers as \"No. #\" in the \"Seq No.\" column. Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "Fill the Sequence Numbers as \"No. #\" in the \"Seq No.\" column",
|
||||
"source": "https://www.youtube.com/shorts/4jzXfZNhfmk",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -82,4 +82,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Given a partial calendar, please highlight all the weekends (Satureday & Sunday) by setting the cell background as red (#ff0000). Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "Given a partial calendar, please highlight all the weekends (Satureday & Sunday) by setting the cell background as red (#ff0000).",
|
||||
"source": "https://www.youtube.com/shorts/Hbcwu6IQ1ns",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -90,4 +90,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "a9f325aa-8c05-4e4f-8341-9e4358565f4f",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "I want to copy the movie titles in 'Garbage Movie Titles' column to the 'Clean Movie Titles' column. But please remove the adundant whitespaces and canonicalize the letter cases by capitalizing the first letter of each words and leave other letters as lower case. Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "I want to copy the movie titles in 'Garbage Movie Titles' column to the 'Clean Movie Titles' column. But please remove the adundant whitespaces and canonicalize the letter cases by capitalizing the first letter of each words and leave other letters as lower case.",
|
||||
"source": "https://www.youtube.com/shorts/A0gmEBRKXWs",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -82,4 +82,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "abed40dc-063f-4598-8ba5-9fe749c0615d",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "Check the names in column \"Names with duplicates\" and put the unique ones in column \"Unique Names\". Keep the original order of the first occurrences. Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "Check the names in column \"Names with duplicates\" and put the unique ones in column \"Unique Names\". Keep the original order.",
|
||||
"source": "https://help.libreoffice.org/7.6/ro/text/scalc/guide/remove_duplicates.html?&DbPAR=SHARED&System=UNIX",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -82,4 +82,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "d681960f-7bc3-4286-9913-a8812ba3261a",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "According to the scale table shown above, calculate and give each student a grade in the table below. Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "According to the scale table shown above, calculate and give each student a grade in the table below",
|
||||
"source": "https://www.youtube.com/shorts/d7U1S_IsTVM",
|
||||
"config": [
|
||||
{
|
||||
|
|
@ -82,4 +82,4 @@
|
|||
"proxy": false,
|
||||
"fixed_ip": false,
|
||||
"possibility_of_env_change": "low"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
|
||||
"snapshot": "libreoffice_calc",
|
||||
"instruction": "In the column \"Pass/Fail/Held\", one from the texts \"Pass\", \"Fail\", and \"Held\" should be filled. For convinience, enable data validation for the cells in this column so that the texts to fill can be directly selected from a drop down list. Finish the work and don't touch irrelevant regions, even if they are blank.",
|
||||
"instruction": "In the column \"Pass/Fail/Held\", one from the texts \"Pass\", \"Fail\", and \"Held\" should be filled. For convinience, enable data validation for the cells in this column so that the texts to fill can be directly selected from a drop down list.",
|
||||
"source": "https://www.youtube.com/shorts/tXOovKn0H68",
|
||||
"config": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"id": "04578141-1d42-4146-b9cf-6fab4ce5fd74",
|
||||
"snapshot": "libreoffice_impress",
|
||||
"instruction": "Change the text color in the textboxes to on slide 1 yellow, red, and green, respectively, in top-to-bottom order. Use exactly these colors—no variations (e.g., no dark red, light green, etc.).",
|
||||
"instruction": "Color the first three textboxes on slide 1 yellow, red, and green, respectively, in top-to-bottom order. Use exactly these colors—no variations (e.g., no dark red, light green, etc.).",
|
||||
"source": "https://arxiv.org/pdf/2311.01767.pdf",
|
||||
"config": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
"parameters": {
|
||||
"command": [
|
||||
"google-chrome",
|
||||
"--proxy-server=http://127.0.0.1:18888",
|
||||
"--remote-debugging-port=1337"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@
|
|||
"rules": {
|
||||
"expected": [
|
||||
"Zoom Chrome Extension",
|
||||
"Speechify — Voice AI Assistant",
|
||||
"Speechify Text to Speech Voice Reader",
|
||||
"React Developer Tools",
|
||||
"Momentum",
|
||||
"Google Translate"
|
||||
|
|
|
|||
|
|
@ -40,8 +40,8 @@
|
|||
},
|
||||
"result": {
|
||||
"type": "vm_file",
|
||||
"path": "/home/user/essay_submission.zip",
|
||||
"dest": "essay_submission.zip"
|
||||
"path": "/home/user/Recruitment_and_retention_of_health_professionals_across_Europe.zip",
|
||||
"dest": "Recruitment_and_retention_of_health_professionals_across_Europe.zip"
|
||||
}
|
||||
},
|
||||
"proxy": false,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
{
|
||||
"type": "execute",
|
||||
"parameters": {
|
||||
"command": "echo {CLIENT_PASSWORD} | sudo -S su - charles",
|
||||
"command": "echo password | sudo -S su - charles",
|
||||
"shell": true
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,135 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Thread-safe results logging for OSWorld evaluations.
|
||||
Appends task completion results to results.json in real-time.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import fcntl
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
def extract_domain_from_path(result_path: str) -> str:
|
||||
"""
|
||||
Extract domain/application from result directory path.
|
||||
Expected structure: results/{action_space}/{observation_type}/{model}/{domain}/{task_id}/
|
||||
"""
|
||||
path_parts = Path(result_path).parts
|
||||
if len(path_parts) >= 2:
|
||||
return path_parts[-2] # Second to last part should be domain
|
||||
return "unknown"
|
||||
|
||||
|
||||
def append_task_result(
|
||||
task_id: str,
|
||||
domain: str,
|
||||
score: float,
|
||||
result_dir: str,
|
||||
args: Any,
|
||||
error_message: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Thread-safely append a task result to results.json.
|
||||
|
||||
Args:
|
||||
task_id: UUID of the task
|
||||
domain: Application domain (chrome, vlc, etc.)
|
||||
score: Task score (0.0 or 1.0)
|
||||
result_dir: Full path to the task result directory
|
||||
args: Command line arguments object
|
||||
error_message: Error message if task failed
|
||||
"""
|
||||
# Create result entry
|
||||
result_entry = {
|
||||
"application": domain,
|
||||
"task_id": task_id,
|
||||
"status": "error" if error_message else "success",
|
||||
"score": score,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
if error_message:
|
||||
result_entry["err_message"] = error_message
|
||||
|
||||
# Determine summary directory and results file path
|
||||
# Extract base result directory from args
|
||||
base_result_dir = Path(args.result_dir)
|
||||
summary_dir = base_result_dir / "summary"
|
||||
results_file = summary_dir / "results.json"
|
||||
|
||||
# Ensure summary directory exists
|
||||
summary_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Thread-safe JSON append with file locking
|
||||
try:
|
||||
with open(results_file, 'a+') as f:
|
||||
# Lock the file for exclusive access
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
|
||||
|
||||
try:
|
||||
# Move to beginning to read existing content
|
||||
f.seek(0)
|
||||
content = f.read().strip()
|
||||
|
||||
# Parse existing JSON array or create new one
|
||||
if content:
|
||||
try:
|
||||
existing_results = json.loads(content)
|
||||
if not isinstance(existing_results, list):
|
||||
existing_results = []
|
||||
except json.JSONDecodeError:
|
||||
existing_results = []
|
||||
else:
|
||||
existing_results = []
|
||||
|
||||
# Add new result
|
||||
existing_results.append(result_entry)
|
||||
|
||||
# Write back the complete JSON array
|
||||
f.seek(0)
|
||||
f.truncate()
|
||||
json.dump(existing_results, f, indent=2)
|
||||
f.write('\n') # Add newline for readability
|
||||
|
||||
finally:
|
||||
# Always unlock the file
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
print(f"📝 Logged result: {domain}/{task_id} -> {result_entry['status']} (score: {score})")
|
||||
|
||||
except Exception as e:
|
||||
# Don't let logging errors break the main evaluation
|
||||
print(f"⚠️ Failed to log result for {task_id}: {e}")
|
||||
|
||||
|
||||
def log_task_completion(example: Dict, result: float, result_dir: str, args: Any) -> None:
|
||||
"""
|
||||
Convenience wrapper for logging successful task completion.
|
||||
|
||||
Args:
|
||||
example: Task configuration dictionary
|
||||
result: Task score
|
||||
result_dir: Path to task result directory
|
||||
args: Command line arguments
|
||||
"""
|
||||
task_id = example.get('id', 'unknown')
|
||||
domain = extract_domain_from_path(result_dir)
|
||||
append_task_result(task_id, domain, result, result_dir, args)
|
||||
|
||||
|
||||
def log_task_error(example: Dict, error_msg: str, result_dir: str, args: Any) -> None:
|
||||
"""
|
||||
Convenience wrapper for logging task errors.
|
||||
|
||||
Args:
|
||||
example: Task configuration dictionary
|
||||
error_msg: Error message
|
||||
result_dir: Path to task result directory
|
||||
args: Command line arguments
|
||||
"""
|
||||
task_id = example.get('id', 'unknown')
|
||||
domain = extract_domain_from_path(result_dir)
|
||||
append_task_result(task_id, domain, 0.0, result_dir, args, error_msg)
|
||||
|
|
@ -4,22 +4,18 @@ import logging
|
|||
import os
|
||||
import time
|
||||
from wrapt_timeout_decorator import *
|
||||
from lib_results_logger import log_task_completion
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
|
||||
# Reset environment first to get fresh VM IP
|
||||
env.reset(task_config=example)
|
||||
|
||||
# Reset agent with fresh VM IP (for snapshot reverts)
|
||||
try:
|
||||
agent.reset(runtime_logger, vm_ip=env.vm_ip)
|
||||
agent.reset(runtime_logger)
|
||||
except Exception as e:
|
||||
agent.reset(vm_ip=env.vm_ip)
|
||||
agent.reset()
|
||||
|
||||
env.reset(task_config=example)
|
||||
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
|
|
@ -33,7 +29,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
|||
)
|
||||
for action in actions:
|
||||
# Capture the timestamp before executing the action
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
|
||||
|
|
@ -48,7 +44,6 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
|||
"step_num": step_idx + 1,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
|
|
@ -59,16 +54,11 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
|||
logger.info("The episode is done.")
|
||||
break
|
||||
step_idx += 1
|
||||
time.sleep(20) # Wait for the environment to settle
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
|
||||
# Log task completion to results.json
|
||||
log_task_completion(example, result, example_result_dir, args)
|
||||
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
|
||||
|
|
@ -105,67 +95,6 @@ def run_single_example_human(env, example, max_steps, instruction, args, example
|
|||
|
||||
|
||||
|
||||
def run_single_example_agi(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
agent.reset(runtime_logger)
|
||||
env.reset(task_config=example)
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
done = False
|
||||
step_idx = 0
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs
|
||||
)
|
||||
|
||||
done = not response.get('state_correct', False)
|
||||
|
||||
for action in actions:
|
||||
# Capture the timestamp before executing the action
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info, step_info = agent.step(action)
|
||||
|
||||
if not done:
|
||||
if not response.get('state_correct', False):
|
||||
done = True
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
# Save screenshot and trajectory information
|
||||
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
|
||||
# Remove pending checks if they exist which will cause issues with json serialization
|
||||
if action.get('pending_checks', None):
|
||||
del action['pending_checks']
|
||||
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(json.dumps({
|
||||
"step_num": step_idx + 1,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||
}))
|
||||
f.write("\n")
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
step_idx += 1
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
|
||||
def run_single_example_openaicua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
agent.reset(runtime_logger)
|
||||
|
|
@ -256,189 +185,6 @@ def run_single_example_opencua(agent, env, example, max_steps, instruction, args
|
|||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps({
|
||||
"step_num": step_idx + 1,
|
||||
"action": action,
|
||||
"natural_language_action": info_dict.get("action"),
|
||||
"action_timestamp": action_timestamp,
|
||||
"response": response,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||
}, ensure_ascii=False))
|
||||
f.write("\n")
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
step_idx += 1
|
||||
|
||||
time.sleep(20) # Wait for the environment to settle
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
def run_single_example_autoglm(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
try:
|
||||
agent.reset(runtime_logger)
|
||||
except Exception as e:
|
||||
agent.reset()
|
||||
|
||||
env.reset(task_config=example)
|
||||
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
done = False
|
||||
step_idx = 0
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs
|
||||
)
|
||||
for action in actions:
|
||||
# Capture the timestamp before executing the action
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
# Save screenshot and trajectory information
|
||||
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(json.dumps({
|
||||
"step_num": step_idx + 1,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||
}))
|
||||
f.write("\n")
|
||||
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
|
||||
# Invalid Action
|
||||
if not actions:
|
||||
obs = env._get_obs() # update observation
|
||||
|
||||
step_idx += 1
|
||||
|
||||
if not done: # not completed the task yet
|
||||
env.action_history.append('FAIL')
|
||||
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
def run_single_example_mano(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
agent.reset(runtime_logger)
|
||||
env.reset(task_config=example)
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
done = False
|
||||
step_idx = 0
|
||||
env.controller.start_recording()
|
||||
|
||||
with open(os.path.join(example_result_dir, f"step_0.png"),
|
||||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs
|
||||
)
|
||||
if len(actions) > 1:
|
||||
if (("pyautogui.hotkey('shift')" in actions[0] or "pyautogui.hotkey('ctrl')" in actions[0])
|
||||
and "pyautogui.click" in actions[1]):
|
||||
hotkey_type = 'shift' if "shift" in actions[0] else 'ctrl'
|
||||
action = f"pyautogui.keyDown('{hotkey_type}')\n{actions[1]}\npyautogui.keyUp('{hotkey_type}')"
|
||||
actions = [action]
|
||||
|
||||
for action in actions:
|
||||
# Capture the timestamp before executing the action
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
# Save screenshot and trajectory information
|
||||
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(json.dumps({
|
||||
"step_num": step_idx + 1,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png",
|
||||
"response":response
|
||||
}))
|
||||
f.write("\n")
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
step_idx += 1
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
def run_single_example_uipath(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
try:
|
||||
agent.reset(runtime_logger)
|
||||
except Exception as e:
|
||||
agent.reset()
|
||||
|
||||
env.reset(task_config=example)
|
||||
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
done = False
|
||||
step_idx = 0
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs,
|
||||
args,
|
||||
step_idx
|
||||
)
|
||||
for action in actions:
|
||||
# Capture the timestamp before executing the action
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
# Save screenshot and trajectory information
|
||||
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(json.dumps({
|
||||
"step_num": step_idx + 1,
|
||||
|
|
@ -455,191 +201,10 @@ def run_single_example_uipath(agent, env, example, max_steps, instruction, args,
|
|||
logger.info("The episode is done.")
|
||||
break
|
||||
step_idx += 1
|
||||
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
|
||||
from mm_agents.os_symphony.utils.common_utils import draw_coordinates
|
||||
from mm_agents.os_symphony.utils.process_context import set_current_result_dir
|
||||
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
def run_single_example_os_symphony(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
set_current_result_dir(example_result_dir)
|
||||
|
||||
agent.reset(result_dir=example_result_dir)
|
||||
env.reset(task_config=example)
|
||||
time.sleep(30) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
done = False
|
||||
step_idx = 0
|
||||
# env.controller.start_recording()
|
||||
start_time = time.time()
|
||||
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs,
|
||||
step_idx == max_steps - 1
|
||||
)
|
||||
for action in actions:
|
||||
# Save screenshot and trajectory information
|
||||
if "reflection" in response and response["reflection"].get("is_milestone"):
|
||||
img_name = f"step_{step_idx + 1}_milestone.png"
|
||||
else:
|
||||
img_name = f"step_{step_idx + 1}.png"
|
||||
|
||||
with open(os.path.join(example_result_dir, img_name),
|
||||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
if "coordinates" in response and response["coordinates"]:
|
||||
draw_coordinates(
|
||||
image_bytes=obs['screenshot'],
|
||||
coordinates=response["coordinates"],
|
||||
save_path=os.path.join(example_result_dir, img_name[:-4] + "_draw.png")
|
||||
)
|
||||
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
logger.info("Done: %s", done)
|
||||
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps({
|
||||
"instruction": instruction,
|
||||
"step_num": step_idx + 1,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": img_name
|
||||
}))
|
||||
f.write("\n")
|
||||
with open(os.path.join(example_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
|
||||
json.dump({
|
||||
"step_num": step_idx + 1,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": img_name
|
||||
}, f, indent=4, ensure_ascii=False)
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
time.sleep(60)
|
||||
break
|
||||
step_idx += 1
|
||||
end_time = time.time()
|
||||
result = float(env.evaluate())
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
|
||||
with open(os.path.join(example_result_dir, "time.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{end_time-start_time:.2f}\n")
|
||||
|
||||
|
||||
def run_single_example_evocua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
"""
|
||||
Unified run function for EvoCUAAgent (supporting both S1 and S2 modes).
|
||||
"""
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
|
||||
# Reset Environment
|
||||
env.reset(task_config=example)
|
||||
|
||||
# Reset Agent
|
||||
# Handle agent reset signature differences if any
|
||||
try:
|
||||
agent.reset(runtime_logger, vm_ip=env.vm_ip)
|
||||
except Exception:
|
||||
try:
|
||||
agent.reset(runtime_logger)
|
||||
except Exception:
|
||||
agent.reset()
|
||||
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
done = False
|
||||
step_idx = 0
|
||||
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
# EvoCUAAgent.predict unified signature: returns (response, actions)
|
||||
# It handles both modes internally.
|
||||
predict_res = agent.predict(instruction, obs)
|
||||
|
||||
# Check return signature logic
|
||||
if len(predict_res) == 3:
|
||||
# Compatibility with S1 original signature if agent was updated to match
|
||||
response, actions, info_dict = predict_res
|
||||
else:
|
||||
response, actions = predict_res
|
||||
info_dict = {}
|
||||
|
||||
logger.info(f"Step {step_idx + 1} Actions: {actions}")
|
||||
|
||||
# Break if no actions (fail-safe)
|
||||
if not actions or (len(actions) == 1 and (actions[0] == "" or "error" in actions[0].lower())):
|
||||
# Allow "FAIL" or "DONE" to process through execution loop if agent outputs them as actions
|
||||
if not (actions and actions[0] in ["FAIL", "DONE"]):
|
||||
logger.warning("No valid actions returned. Breaking loop.")
|
||||
break
|
||||
|
||||
for action in actions:
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
|
||||
logger.info("Executing action: %s", action)
|
||||
|
||||
# Execute
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
|
||||
# Save screenshot
|
||||
screenshot_file = f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||
with open(os.path.join(example_result_dir, screenshot_file), "wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
|
||||
# Log Trajectory
|
||||
log_entry = {
|
||||
"step_num": step_idx + 1,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": screenshot_file
|
||||
}
|
||||
# Add natural language info if available (S1 style)
|
||||
if info_dict:
|
||||
log_entry["natural_language_action"] = info_dict.get("action")
|
||||
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(log_entry, ensure_ascii=False))
|
||||
f.write("\n")
|
||||
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
|
||||
step_idx += 1
|
||||
|
||||
time.sleep(20) # Wait for environment to settle
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
|
||||
log_task_completion(example, result, example_result_dir, args)
|
||||
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from wrapt_timeout_decorator import *
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
try:
|
||||
agent.reset(runtime_logger)
|
||||
except:
|
||||
agent.reset()
|
||||
env.reset(task_config=example)
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
|
||||
done = False
|
||||
step_idx = 0
|
||||
|
||||
# save the first step
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
with open(os.path.join(example_result_dir, f"step_{step_idx}_{action_timestamp}.png"), "wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
|
||||
eval_flag = True
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
global_state, action_code, step_status, reward, done = agent.step(instruction, env, args)
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
if step_status is False:
|
||||
eval_flag = False
|
||||
done = True
|
||||
reward = None
|
||||
else:
|
||||
obs = env._get_obs()
|
||||
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(json.dumps({
|
||||
"step_num": step_idx + 1,
|
||||
"step_status": step_status,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action_code,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png",
|
||||
}))
|
||||
f.write("\n")
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
step_idx += 1
|
||||
|
||||
if eval_flag:
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
|
||||
def setup_logger(example, example_result_dir):
|
||||
runtime_logger = logging.getLogger(f"desktopenv.example.{example['id']}")
|
||||
runtime_logger.setLevel(logging.DEBUG)
|
||||
runtime_logger.addHandler(logging.FileHandler(os.path.join(example_result_dir, "runtime.log")))
|
||||
return runtime_logger
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from wrapt_timeout_decorator import *
|
||||
from mm_agents.os_symphony.utils.common_utils import draw_coordinates
|
||||
from mm_agents.os_symphony.utils.process_context import set_current_result_dir
|
||||
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
set_current_result_dir(example_result_dir)
|
||||
|
||||
agent.reset(result_dir=example_result_dir)
|
||||
env.reset(task_config=example)
|
||||
time.sleep(30) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
done = False
|
||||
step_idx = 0
|
||||
# env.controller.start_recording()
|
||||
start_time = time.time()
|
||||
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs,
|
||||
step_idx == max_steps - 1
|
||||
)
|
||||
for action in actions:
|
||||
# Save screenshot and trajectory information
|
||||
if "reflection" in response and response["reflection"].get("is_milestone"):
|
||||
img_name = f"step_{step_idx + 1}_milestone.png"
|
||||
else:
|
||||
img_name = f"step_{step_idx + 1}.png"
|
||||
|
||||
with open(os.path.join(example_result_dir, img_name),
|
||||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
if "coordinates" in response and response["coordinates"]:
|
||||
draw_coordinates(
|
||||
image_bytes=obs['screenshot'],
|
||||
coordinates=response["coordinates"],
|
||||
save_path=os.path.join(example_result_dir, img_name[:-4] + "_draw.png")
|
||||
)
|
||||
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
logger.info("Done: %s", done)
|
||||
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps({
|
||||
"instruction": instruction,
|
||||
"step_num": step_idx + 1,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": img_name
|
||||
}))
|
||||
f.write("\n")
|
||||
with open(os.path.join(example_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
|
||||
json.dump({
|
||||
"step_num": step_idx + 1,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": img_name
|
||||
}, f, indent=4, ensure_ascii=False)
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
time.sleep(60)
|
||||
break
|
||||
step_idx += 1
|
||||
end_time = time.time()
|
||||
result = float(env.evaluate())
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
|
||||
with open(os.path.join(example_result_dir, "time.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{end_time-start_time:.2f}\n")
|
||||
|
|
@ -1134,12 +1134,10 @@ class PromptAgent:
|
|||
|
||||
return actions
|
||||
|
||||
def reset(self, _logger=None, vm_ip=None, **kwargs):
|
||||
def reset(self, _logger=None):
|
||||
global logger
|
||||
logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")
|
||||
|
||||
self.vm_ip = vm_ip
|
||||
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
self.observations = []
|
||||
|
|
|
|||
|
|
@ -1,219 +0,0 @@
|
|||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Tuple, Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Context manager for timing code blocks."""
|
||||
|
||||
def __enter__(self):
|
||||
self.start = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.duration = time.time() - self.start
|
||||
|
||||
|
||||
class AGIAgent:
|
||||
"""Agent that communicates with your private AGI server for decision-making."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env,
|
||||
server_url: str = "https://your-private-agi-endpoint", # Contact the authors for access to a private deployment endpoint.
|
||||
platform: str = "ubuntu",
|
||||
action_space: str = "pyautogui",
|
||||
observation_type: str = "screenshot",
|
||||
max_trajectory_length: int = 100,
|
||||
client_password: str = "",
|
||||
provider_name: str = "aws",
|
||||
screen_width: int = 1920,
|
||||
screen_height: int = 1080,
|
||||
timeout: int = 1800,
|
||||
):
|
||||
"""Initialize the AGI client.
|
||||
|
||||
Args:
|
||||
env: The desktop environment
|
||||
server_url: URL of your private AGI server
|
||||
"""
|
||||
self.env = env
|
||||
self.server_url = server_url.rstrip("/")
|
||||
self.platform = platform
|
||||
self.action_space = action_space
|
||||
self.observation_type = observation_type
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.client_password = client_password
|
||||
self.provider_name = provider_name
|
||||
self.screen_width = screen_width
|
||||
self.screen_height = screen_height
|
||||
|
||||
# Session management
|
||||
self.session_id: Optional[str] = None
|
||||
self.instruction: Optional[str] = None
|
||||
|
||||
# HTTP client
|
||||
self.client = httpx.Client(timeout=timeout)
|
||||
|
||||
# Tracking
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
|
||||
logger.info(f"Initialized AGIAgent with server URL: {self.server_url}")
|
||||
|
||||
def reset(self, runtime_logger=None):
|
||||
"""Reset the agent and create a new session on the server.
|
||||
|
||||
Args:
|
||||
runtime_logger: Optional logger for runtime information
|
||||
"""
|
||||
global logger
|
||||
logger = runtime_logger if runtime_logger is not None else logging.getLogger("desktopenv.agent")
|
||||
|
||||
# Clear local state
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
self.session_id = None
|
||||
|
||||
logger.info("AGIAgent reset complete")
|
||||
|
||||
def _create_session(self, instruction: str) -> str:
|
||||
"""Create a new session on the server.
|
||||
|
||||
Args:
|
||||
instruction: The task instruction
|
||||
|
||||
Returns:
|
||||
The session ID
|
||||
|
||||
Equivalent curl request:
|
||||
curl -X POST {server_url}/sessions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"task_description": "{instruction}"}'
|
||||
"""
|
||||
try:
|
||||
# print(f"Creating session with instruction: {instruction}")
|
||||
# print(f"Server URL: {self.server_url}")
|
||||
response = self.client.post(
|
||||
f"{self.server_url}/sessions",
|
||||
json={"task_description": instruction}
|
||||
)
|
||||
response.raise_for_status()
|
||||
session_id = response.json()["session_id"]
|
||||
logger.info(f"Created session: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session: {e}")
|
||||
raise
|
||||
|
||||
def predict(self, instruction: str, obs: Dict) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
"""Predict the next action based on the current observation.
|
||||
|
||||
Args:
|
||||
instruction: The task instruction
|
||||
obs: Observation dictionary containing 'screenshot' key with image bytes
|
||||
|
||||
Returns:
|
||||
Tuple of (predict_info dict, list of action dicts)
|
||||
"""
|
||||
# Create session on first prediction
|
||||
if self.session_id is None:
|
||||
self.instruction = instruction
|
||||
self.session_id = self._create_session(instruction)
|
||||
|
||||
# input("Session created, press Enter to continue")
|
||||
|
||||
# Encode screenshot to base64
|
||||
screenshot_bytes = obs["screenshot"]
|
||||
screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
|
||||
|
||||
# Call the server
|
||||
with Timer() as model_timer:
|
||||
try:
|
||||
response = self.client.post(
|
||||
f"{self.server_url}/sessions/{self.session_id}/step",
|
||||
json={
|
||||
"screenshot_base64_png": screenshot_b64,
|
||||
"error": None # Could be populated from previous step errors
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
parsed_action = result["parsed_response"]
|
||||
|
||||
logger.info(f"Server returned action: {parsed_action[:100]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling server: {e}")
|
||||
raise
|
||||
|
||||
# Format response as expected by lib_run_single
|
||||
actions = [{
|
||||
"action_space": "pyautogui",
|
||||
"action": parsed_action,
|
||||
"pending_checks": [],
|
||||
"call_id": ""
|
||||
}]
|
||||
|
||||
# Check if task is complete or failed
|
||||
state_correct = parsed_action not in ["FAIL", "DONE"]
|
||||
|
||||
predict_info = {
|
||||
"model_usage": {
|
||||
"model_time": model_timer.duration,
|
||||
"prompt_tokens": 0, # Server doesn't expose these
|
||||
"completion_tokens": 0,
|
||||
},
|
||||
"messages": [], # Server manages conversation history
|
||||
"response": parsed_action,
|
||||
"state_correct": state_correct,
|
||||
}
|
||||
|
||||
return predict_info, actions
|
||||
|
||||
def step(self, action: Dict[str, Any]) -> Tuple[Dict, float, bool, Dict, Dict]:
|
||||
"""Execute an action in the environment.
|
||||
|
||||
Args:
|
||||
action: Action dictionary with 'action' key containing PyAutoGUI command
|
||||
|
||||
Returns:
|
||||
Tuple of (observation, reward, done, info, step_info)
|
||||
"""
|
||||
try:
|
||||
if not action:
|
||||
logger.warning("Empty action received, terminating episode")
|
||||
# Get observation without executing action
|
||||
obs = self.env._get_obs()
|
||||
return obs, 0.0, True, {}, {"step_time": 0.0, "action": action}
|
||||
|
||||
action_str = action.get("action", "")
|
||||
logger.info(f"Executing action: {action_str[:100]}...")
|
||||
|
||||
with Timer() as step_timer:
|
||||
# Execute the action directly (it's already a PyAutoGUI command string)
|
||||
obs, reward, terminated, info = self.env.step(action_str)
|
||||
|
||||
logger.debug(f"Action completed in {step_timer.duration:.2f}s")
|
||||
if terminated:
|
||||
logger.info("Environment signaled termination")
|
||||
|
||||
return obs, reward, terminated, info, {
|
||||
"step_time": step_timer.duration,
|
||||
"action": action
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Environment step failed: {str(e)}")
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
"""Close the HTTP client."""
|
||||
self.client.close()
|
||||
|
|
@ -17,7 +17,7 @@ from anthropic.types.beta import (
|
|||
BetaMessageParam,
|
||||
BetaTextBlockParam,
|
||||
)
|
||||
from .utils import COMPUTER_USE_BETA_FLAG, PROMPT_CACHING_BETA_FLAG,SYSTEM_PROMPT, SYSTEM_PROMPT_WINDOWS, APIProvider, PROVIDER_TO_DEFAULT_MODEL_NAME, get_model_name
|
||||
from .utils import COMPUTER_USE_BETA_FLAG, PROMPT_CACHING_BETA_FLAG,SYSTEM_PROMPT, SYSTEM_PROMPT_WINDOWS, APIProvider, PROVIDER_TO_DEFAULT_MODEL_NAME
|
||||
from .utils import _response_to_params, _inject_prompt_caching, _maybe_filter_to_n_most_recent_images
|
||||
|
||||
import logging
|
||||
|
|
@ -30,18 +30,14 @@ API_RETRY_INTERVAL = 5
|
|||
class AnthropicAgent:
|
||||
def __init__(self,
|
||||
platform: str = "Ubuntu",
|
||||
model: str = "claude-sonnet-4-5-20250929",
|
||||
provider: APIProvider = APIProvider.ANTHROPIC,
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
provider: APIProvider = APIProvider.BEDROCK,
|
||||
max_tokens: int = 4096,
|
||||
api_key: str = os.environ.get("ANTHROPIC_API_KEY", None),
|
||||
system_prompt_suffix: str = "",
|
||||
only_n_most_recent_images: Optional[int] = 10,
|
||||
action_space: str = "claude_computer_use",
|
||||
screen_size: tuple[int, int] = (1920, 1080),
|
||||
no_thinking: bool = False,
|
||||
use_isp: bool = False,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
*args, **kwargs
|
||||
):
|
||||
self.platform = platform
|
||||
|
|
@ -56,24 +52,10 @@ class AnthropicAgent:
|
|||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
self.messages: list[BetaMessageParam] = []
|
||||
self.screen_size = screen_size
|
||||
self.no_thinking = no_thinking
|
||||
self.use_isp = use_isp
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
|
||||
self.resize_factor = (
|
||||
screen_size[0] / 1280, # Assuming 1280 is the base width
|
||||
screen_size[1] / 720 # Assuming 720 is the base height
|
||||
)
|
||||
|
||||
def _get_sampling_params(self):
|
||||
"""Get sampling parameters (temperature and/or top_p) - let API validate exclusivity"""
|
||||
params = {}
|
||||
if self.temperature is not None:
|
||||
params['temperature'] = self.temperature
|
||||
if self.top_p is not None:
|
||||
params['top_p'] = self.top_p
|
||||
return params
|
||||
|
||||
def add_tool_result(self, tool_call_id: str, result: str, screenshot: bytes = None):
|
||||
"""Add tool result to message history"""
|
||||
|
|
@ -102,21 +84,6 @@ class AnthropicAgent:
|
|||
"content": tool_result_content
|
||||
})
|
||||
|
||||
def _extract_raw_response_string(self, response) -> str:
|
||||
"""Extract and concatenate raw response content into a single string."""
|
||||
raw_response_str = ""
|
||||
if response.content:
|
||||
for block in response.content:
|
||||
if hasattr(block, 'text') and block.text:
|
||||
raw_response_str += f"[TEXT] {block.text}\n"
|
||||
elif hasattr(block, 'thinking') and block.thinking:
|
||||
raw_response_str += f"[THINKING] {block.thinking}\n"
|
||||
elif hasattr(block, 'name') and hasattr(block, 'input'):
|
||||
raw_response_str += f"[TOOL_USE] {block.name}: {block.input}\n"
|
||||
else:
|
||||
raw_response_str += f"[OTHER] {str(block)}\n"
|
||||
return raw_response_str.strip()
|
||||
|
||||
def parse_actions_from_tool_call(self, tool_call: Dict) -> str:
|
||||
result = ""
|
||||
function_args = (
|
||||
|
|
@ -134,7 +101,6 @@ class AnthropicAgent:
|
|||
|
||||
text = function_args.get("text")
|
||||
coordinate = function_args.get("coordinate")
|
||||
start_coordinate = function_args.get("start_coordinate")
|
||||
scroll_direction = function_args.get("scroll_direction")
|
||||
scroll_amount = function_args.get("scroll_amount")
|
||||
duration = function_args.get("duration")
|
||||
|
|
@ -145,11 +111,6 @@ class AnthropicAgent:
|
|||
int(coordinate[0] * self.resize_factor[0]),
|
||||
int(coordinate[1] * self.resize_factor[1])
|
||||
)
|
||||
if start_coordinate and self.resize_factor:
|
||||
start_coordinate = (
|
||||
int(start_coordinate[0] * self.resize_factor[0]),
|
||||
int(start_coordinate[1] * self.resize_factor[1])
|
||||
)
|
||||
|
||||
if action == "left_mouse_down":
|
||||
result += "pyautogui.mouseDown()\n"
|
||||
|
|
@ -184,16 +145,6 @@ class AnthropicAgent:
|
|||
)
|
||||
expected_outcome = f"Mouse moved to ({x},{y})."
|
||||
elif action == "left_click_drag":
|
||||
# If start_coordinate is provided, validate and move to start before dragging
|
||||
if start_coordinate:
|
||||
if not isinstance(start_coordinate, (list, tuple)) or len(start_coordinate) != 2:
|
||||
raise ValueError(f"{start_coordinate} must be a tuple of length 2")
|
||||
if not all(isinstance(i, int) for i in start_coordinate):
|
||||
raise ValueError(f"{start_coordinate} must be a tuple of ints")
|
||||
start_x, start_y = start_coordinate[0], start_coordinate[1]
|
||||
result += (
|
||||
f"pyautogui.moveTo({start_x}, {start_y}, duration={duration or 0.5})\n"
|
||||
)
|
||||
result += (
|
||||
f"pyautogui.dragTo({x}, {y}, duration={duration or 0.5})\n"
|
||||
)
|
||||
|
|
@ -227,23 +178,13 @@ class AnthropicAgent:
|
|||
result += (f"pyautogui.keyUp('{key}')\n")
|
||||
expected_outcome = f"Key {key} pressed."
|
||||
elif action == "type":
|
||||
for char in text:
|
||||
if char == '\n':
|
||||
result += "pyautogui.press('enter')\n"
|
||||
elif char == "'":
|
||||
result += 'pyautogui.press("\'")\n'
|
||||
elif char == '\\':
|
||||
result += "pyautogui.press('\\\\')\n"
|
||||
elif char == '"':
|
||||
result += "pyautogui.press('\"')\n"
|
||||
else:
|
||||
result += f"pyautogui.press('{char}')\n"
|
||||
result += (
|
||||
f"pyautogui.typewrite(\"\"\"{text}\"\"\", interval=0.01)\n"
|
||||
)
|
||||
expected_outcome = f"Text {text} written."
|
||||
|
||||
# Handle scroll actions
|
||||
elif action == "scroll":
|
||||
if text is not None:
|
||||
result += (f"pyautogui.keyDown('{text.lower()}')\n")
|
||||
if coordinate is None:
|
||||
if scroll_direction in ("up", "down"):
|
||||
result += (
|
||||
|
|
@ -264,18 +205,10 @@ class AnthropicAgent:
|
|||
result += (
|
||||
f"pyautogui.hscroll({scroll_amount if scroll_direction == 'right' else -scroll_amount}, {x}, {y})\n"
|
||||
)
|
||||
if text is not None:
|
||||
result += (f"pyautogui.keyUp('{text.lower()}')\n")
|
||||
expected_outcome = "Scroll action finished"
|
||||
|
||||
# Handle click actions
|
||||
elif action in ("left_click", "right_click", "double_click", "middle_click", "left_press", "triple_click"):
|
||||
# Handle modifier keys during click if specified
|
||||
if text:
|
||||
keys = text.split('+')
|
||||
for key in keys:
|
||||
key = key.strip().lower()
|
||||
result += f"pyautogui.keyDown('{key}')\n"
|
||||
if coordinate is not None:
|
||||
x, y = coordinate
|
||||
if action == "left_click":
|
||||
|
|
@ -308,12 +241,6 @@ class AnthropicAgent:
|
|||
result += ("pyautogui.mouseUp()\n")
|
||||
elif action == "triple_click":
|
||||
result += ("pyautogui.tripleClick()\n")
|
||||
# Release modifier keys after click
|
||||
if text:
|
||||
keys = text.split('+')
|
||||
for key in reversed(keys):
|
||||
key = key.strip().lower()
|
||||
result += f"pyautogui.keyUp('{key}')\n"
|
||||
expected_outcome = "Click action finished"
|
||||
|
||||
elif action == "wait":
|
||||
|
|
@ -330,12 +257,60 @@ class AnthropicAgent:
|
|||
expected_outcome = "Call user"
|
||||
elif action == "screenshot":
|
||||
result += "pyautogui.sleep(0.1)\n"
|
||||
expected_outcome = "Screenshot taken"
|
||||
expected_outcome = "Screenshot taken"
|
||||
else:
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
|
||||
return result
|
||||
|
||||
def _trim_history(self, max_rounds=4):
|
||||
|
||||
messages = self.messages
|
||||
if not messages or len(messages) <= 1:
|
||||
return
|
||||
|
||||
# 计算需要保留的最近轮次数
|
||||
actual_max_rounds = max_rounds * 2
|
||||
|
||||
# 如果消息数量不超过限制,不需要处理
|
||||
if len(messages) <= actual_max_rounds:
|
||||
return
|
||||
|
||||
# 保留前3条消息(初始消息)和最近的actual_max_rounds条消息 messages[0:1] + messages[-actual_max_rounds:]
|
||||
keep_messages = []
|
||||
|
||||
# 对于中间被删除的消息,只保留非图片内容
|
||||
for i in range(1, len(messages) - actual_max_rounds):
|
||||
old_message = messages[i]
|
||||
if old_message["role"] == "user" and "content" in old_message:
|
||||
# 过滤掉image类型的内容块,保留其他类型
|
||||
filtered_content = []
|
||||
for content_block in old_message["content"]:
|
||||
filtered_content_item = []
|
||||
if content_block.get("type") == "tool_result":
|
||||
for content_block_item in content_block["content"]:
|
||||
if content_block_item.get("type") != "image":
|
||||
filtered_content_item.append(content_block_item)
|
||||
filtered_content.append({
|
||||
"type": content_block.get("type"),
|
||||
"tool_use_id": content_block.get("tool_use_id"),
|
||||
"content": filtered_content_item
|
||||
})
|
||||
else:
|
||||
filtered_content.append(content_block)
|
||||
|
||||
# 如果过滤后还有内容,则保留这条消息
|
||||
if filtered_content:
|
||||
keep_messages.append({
|
||||
"role": old_message["role"],
|
||||
"content": filtered_content
|
||||
})
|
||||
else:
|
||||
# 非用户消息或没有content的消息直接保留
|
||||
keep_messages.append(old_message)
|
||||
|
||||
self.messages = messages[0:1] + keep_messages + messages[-actual_max_rounds:]
|
||||
|
||||
def predict(self, task_instruction: str, obs: Dict = None, system: Any = None):
|
||||
system = BetaTextBlockParam(
|
||||
type="text",
|
||||
|
|
@ -348,9 +323,6 @@ class AnthropicAgent:
|
|||
screenshot_bytes = obs["screenshot"]
|
||||
screenshot_image = Image.open(io.BytesIO(screenshot_bytes))
|
||||
|
||||
# Store original unresized screenshot for zoom processing
|
||||
obs["screenshot_original"] = screenshot_bytes
|
||||
|
||||
# Calculate new size based on resize factor
|
||||
new_width, new_height = 1280, 720
|
||||
|
||||
|
|
@ -382,45 +354,23 @@ class AnthropicAgent:
|
|||
]
|
||||
})
|
||||
|
||||
# Add tool_result for ALL tool_use blocks in the last message
|
||||
if self.messages:
|
||||
last_message_content = self.messages[-1]["content"]
|
||||
tool_use_blocks = [block for block in last_message_content if block.get("type") == "tool_use"]
|
||||
|
||||
for i, tool_block in enumerate(tool_use_blocks):
|
||||
tool_input = tool_block.get("input", {})
|
||||
action = tool_input.get("action")
|
||||
is_last_tool = i == len(tool_use_blocks) - 1
|
||||
|
||||
include_screenshot = None
|
||||
|
||||
if obs:
|
||||
if action == "screenshot":
|
||||
# Screenshot action always gets regular screenshot
|
||||
include_screenshot = obs.get("screenshot")
|
||||
elif is_last_tool:
|
||||
# Auto-screenshot: last tool gets regular screenshot (unless it's zoom, handled above)
|
||||
include_screenshot = obs.get("screenshot")
|
||||
|
||||
self.add_tool_result(
|
||||
tool_block["id"],
|
||||
f"Success",
|
||||
screenshot=include_screenshot
|
||||
)
|
||||
if self.messages and "tool_use" in [content_block["type"] for content_block in self.messages[-1]["content"]]:
|
||||
self.add_tool_result(
|
||||
self.messages[-1]["content"][-1]["id"],
|
||||
f"Success",
|
||||
screenshot=obs.get("screenshot") if obs else None
|
||||
)
|
||||
|
||||
enable_prompt_caching = False
|
||||
betas = [COMPUTER_USE_BETA_FLAG]
|
||||
|
||||
# Add interleaved thinking beta if ISP is requested
|
||||
if self.use_isp:
|
||||
betas.append("interleaved-thinking-2025-05-14")
|
||||
logger.info(f"Added interleaved thinking beta. Betas: {betas}")
|
||||
betas = ["computer-use-2025-01-24"]
|
||||
if self.model_name == "claude-3-7-sonnet-20250219" or self.model_name == "claude-4-opus-20250514" or self.model_name == "claude-4-sonnet-20250514":
|
||||
betas = ["computer-use-2025-01-24"]
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
betas = [COMPUTER_USE_BETA_FLAG]
|
||||
|
||||
image_truncation_threshold = 10
|
||||
if self.provider == APIProvider.ANTHROPIC:
|
||||
client = Anthropic(api_key=self.api_key, max_retries=4).with_options(
|
||||
default_headers={"anthropic-beta": COMPUTER_USE_BETA_FLAG}
|
||||
)
|
||||
client = Anthropic(api_key=self.api_key, max_retries=4)
|
||||
enable_prompt_caching = True
|
||||
elif self.provider == APIProvider.VERTEX:
|
||||
client = AnthropicVertex()
|
||||
|
|
@ -438,7 +388,7 @@ class AnthropicAgent:
|
|||
if enable_prompt_caching:
|
||||
betas.append(PROMPT_CACHING_BETA_FLAG)
|
||||
_inject_prompt_caching(self.messages)
|
||||
image_truncation_threshold = 20
|
||||
image_truncation_threshold = 50
|
||||
system["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
if self.only_n_most_recent_images:
|
||||
|
|
@ -448,128 +398,124 @@ class AnthropicAgent:
|
|||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
|
||||
# Configure tool settings - use modern computer tool for all models
|
||||
tool_config = {
|
||||
'name': 'computer',
|
||||
'type': 'computer_20250124',
|
||||
'display_width_px': 1280,
|
||||
'display_height_px': 720,
|
||||
'display_number': 1
|
||||
}
|
||||
|
||||
tools = [
|
||||
tool_config,
|
||||
] if self.platform == 'Ubuntu' else [
|
||||
tool_config,
|
||||
]
|
||||
|
||||
# Configure thinking mode based on user preferences
|
||||
if self.no_thinking:
|
||||
# Disable thinking mode - omit the thinking parameter
|
||||
extra_body = {}
|
||||
actual_max_tokens = self.max_tokens # Use default when no thinking
|
||||
logger.info("Thinking mode: DISABLED")
|
||||
else:
|
||||
# Enable thinking mode (regular or interleaved)
|
||||
# Use consistent 2048 budget for both regular and ISP thinking
|
||||
budget_tokens = 2048
|
||||
|
||||
# For regular thinking: max_tokens > budget_tokens (API requirement)
|
||||
# For ISP: budget_tokens can exceed max_tokens (represents total across all thinking blocks)
|
||||
if self.max_tokens <= budget_tokens:
|
||||
required_max_tokens = budget_tokens + 500 # Give some headroom
|
||||
logger.warning(f"Regular thinking requires max_tokens > budget_tokens. Increasing max_tokens from {self.max_tokens} to {required_max_tokens}")
|
||||
actual_max_tokens = required_max_tokens
|
||||
else:
|
||||
actual_max_tokens = self.max_tokens
|
||||
|
||||
extra_body = {
|
||||
"thinking": {"type": "enabled", "budget_tokens": budget_tokens}
|
||||
}
|
||||
if self.use_isp:
|
||||
logger.info("Thinking mode: INTERLEAVED SCRATCHPAD (ISP)")
|
||||
else:
|
||||
logger.info("Thinking mode: REGULAR SCRATCHPAD")
|
||||
#self._trim_history(max_rounds=MAX_HISTORY)
|
||||
|
||||
try:
|
||||
if self.model_name == "claude-3-5-sonnet-20241022":
|
||||
tools = [
|
||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
# {'type': 'bash_20241022', 'name': 'bash'},
|
||||
# {'name': 'str_replace_editor', 'type': 'text_editor_20241022'}
|
||||
] if self.platform == 'Ubuntu' else [
|
||||
{'name': 'computer', 'type': 'computer_20241022', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
]
|
||||
elif self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
tools = [
|
||||
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
# {'type': 'bash_20250124', 'name': 'bash'},
|
||||
# {'name': 'str_replace_editor', 'type': 'text_editor_20250124'}
|
||||
] if self.platform == 'Ubuntu' else [
|
||||
{'name': 'computer', 'type': 'computer_20250124', 'display_width_px': 1280, 'display_height_px': 720, 'display_number': 1},
|
||||
]
|
||||
extra_body = {
|
||||
"thinking": {"type": "enabled", "budget_tokens": 1024}
|
||||
}
|
||||
response = None
|
||||
|
||||
for attempt in range(API_RETRY_TIMES):
|
||||
try:
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=actual_max_tokens,
|
||||
messages=self.messages,
|
||||
model=get_model_name(self.provider, self.model_name),
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body,
|
||||
**self._get_sampling_params()
|
||||
)
|
||||
|
||||
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
logger.info(f"Response: {response}")
|
||||
break
|
||||
break # 成功则跳出重试循环
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
||||
error_msg = str(e)
|
||||
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
|
||||
|
||||
# 检查是否是25MB限制错误
|
||||
if "25000000" in error_msg or "Member must have length less than or equal to" in error_msg:
|
||||
logger.warning("Detected 25MB limit error, automatically reducing image count")
|
||||
logger.warning("检测到25MB限制错误,自动裁剪图片数量")
|
||||
# 将图片数量减半
|
||||
current_image_count = self.only_n_most_recent_images
|
||||
new_image_count = max(1, current_image_count // 2) # Keep at least 1 image
|
||||
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
|
||||
self.only_n_most_recent_images = new_image_count
|
||||
|
||||
# 重新应用图片过滤
|
||||
_maybe_filter_to_n_most_recent_images(
|
||||
self.messages,
|
||||
new_image_count,
|
||||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
logger.info(f"Image count reduced from {current_image_count} to {new_image_count}")
|
||||
logger.info(f"图片数量已从 {current_image_count} 减少到 {new_image_count}")
|
||||
|
||||
if attempt < API_RETRY_TIMES - 1:
|
||||
time.sleep(API_RETRY_INTERVAL)
|
||||
else:
|
||||
raise # All attempts failed, raise exception to enter existing except logic
|
||||
raise # 全部失败后抛出异常,进入原有except逻辑
|
||||
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e:
|
||||
logger.exception(f"Anthropic API error: {str(e)}")
|
||||
try:
|
||||
logger.warning("Retrying with backup API key...")
|
||||
|
||||
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4).with_options(
|
||||
default_headers={"anthropic-beta": COMPUTER_USE_BETA_FLAG}
|
||||
)
|
||||
response = backup_client.beta.messages.create(
|
||||
max_tokens=actual_max_tokens,
|
||||
messages=self.messages,
|
||||
model=get_model_name(self.provider, self.model_name),
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body,
|
||||
**self._get_sampling_params()
|
||||
)
|
||||
|
||||
backup_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY_BACKUP"), max_retries=4)
|
||||
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
response = backup_client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = backup_client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[APIProvider.ANTHROPIC, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
logger.info("Successfully used backup API key")
|
||||
except Exception as backup_e:
|
||||
backup_error_msg = str(backup_e)
|
||||
logger.exception(f"Backup API call also failed: {backup_error_msg}")
|
||||
|
||||
# Check if backup API also has 25MB limit error
|
||||
# 检查备用API是否也是25MB限制错误
|
||||
if "25000000" in backup_error_msg or "Member must have length less than or equal to" in backup_error_msg:
|
||||
logger.warning("Backup API also encountered 25MB limit error, further reducing image count")
|
||||
# Reduce image count by half again
|
||||
logger.warning("备用API也遇到25MB限制错误,进一步裁剪图片数量")
|
||||
# 将图片数量再减半
|
||||
current_image_count = self.only_n_most_recent_images
|
||||
new_image_count = max(1, current_image_count // 2) # Keep at least 1 image
|
||||
new_image_count = max(1, current_image_count // 2) # 至少保留1张图片
|
||||
self.only_n_most_recent_images = new_image_count
|
||||
|
||||
# Reapply image filtering
|
||||
# 重新应用图片过滤
|
||||
_maybe_filter_to_n_most_recent_images(
|
||||
self.messages,
|
||||
new_image_count,
|
||||
min_removal_threshold=image_truncation_threshold,
|
||||
)
|
||||
logger.info(f"Backup API image count reduced from {current_image_count} to {new_image_count}")
|
||||
logger.info(f"备用API图片数量已从 {current_image_count} 减少到 {new_image_count}")
|
||||
|
||||
return None, None
|
||||
|
||||
|
|
@ -577,16 +523,9 @@ class AnthropicAgent:
|
|||
logger.exception(f"Error in Anthropic API: {str(e)}")
|
||||
return None, None
|
||||
|
||||
if response is None:
|
||||
logger.error("Response is None after API call - this should not happen")
|
||||
return None, None
|
||||
|
||||
response_params = _response_to_params(response)
|
||||
logger.info(f"Received response params: {response_params}")
|
||||
|
||||
# Convert raw response to concatenated string for trajectory logging
|
||||
raw_response_str = self._extract_raw_response_string(response)
|
||||
|
||||
# Store response in message history
|
||||
self.messages.append({
|
||||
"role": "assistant",
|
||||
|
|
@ -605,8 +544,7 @@ class AnthropicAgent:
|
|||
"input": cast(dict[str, Any], content_block["input"]),
|
||||
"id": content_block["id"],
|
||||
"action_type": content_block.get("type"),
|
||||
"command": self.parse_actions_from_tool_call(content_block),
|
||||
"raw_response": raw_response_str # Add raw response to each action
|
||||
"command": self.parse_actions_from_tool_call(content_block)
|
||||
})
|
||||
elif content_block["type"] == "text":
|
||||
reasonings.append(content_block["text"])
|
||||
|
|
@ -614,45 +552,40 @@ class AnthropicAgent:
|
|||
reasonings = reasonings[0]
|
||||
else:
|
||||
reasonings = ""
|
||||
|
||||
# Check if the model indicated the task is infeasible
|
||||
if raw_response_str and "[INFEASIBLE]" in raw_response_str:
|
||||
logger.info("Detected [INFEASIBLE] pattern in response, triggering FAIL action")
|
||||
# Override actions with FAIL
|
||||
actions = [{
|
||||
"action_type": "FAIL",
|
||||
"raw_response": raw_response_str
|
||||
}]
|
||||
|
||||
logger.info(f"Received actions: {actions}")
|
||||
logger.info(f"Received reasonings: {reasonings}")
|
||||
if len(actions) == 0:
|
||||
actions = [{
|
||||
"action_type": "DONE",
|
||||
"raw_response": raw_response_str
|
||||
}]
|
||||
actions = ["DONE"]
|
||||
return reasonings, actions
|
||||
except Exception as e:
|
||||
logger.warning(f"parse_actions_from_tool_call parsing failed (attempt {parse_retry+1}/3), will retry API request: {e}")
|
||||
# Remove the recently appended assistant message to avoid polluting history
|
||||
logger.warning(f"parse_actions_from_tool_call解析失败(第{parse_retry+1}/3次),将重新请求API: {e}")
|
||||
# 删除刚刚append的assistant消息,避免污染history
|
||||
self.messages.pop()
|
||||
# Retry API request
|
||||
# 重新请求API
|
||||
response = None
|
||||
for attempt in range(API_RETRY_TIMES):
|
||||
try:
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=actual_max_tokens,
|
||||
messages=self.messages,
|
||||
model=get_model_name(self.provider, self.model_name),
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body,
|
||||
**self._get_sampling_params()
|
||||
)
|
||||
|
||||
if self.model_name in ["claude-3-7-sonnet-20250219", "claude-4-opus-20250514", "claude-4-sonnet-20250514"]:
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
extra_body=extra_body
|
||||
)
|
||||
elif self.model_name == "claude-3-5-sonnet-20241022":
|
||||
response = client.beta.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=self.messages,
|
||||
model=PROVIDER_TO_DEFAULT_MODEL_NAME[self.provider, self.model_name],
|
||||
system=[system],
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
logger.info(f"Response: {response}")
|
||||
break # Success, exit retry loop
|
||||
break # 成功则跳出重试循环
|
||||
except (APIError, APIStatusError, APIResponseValidationError) as e2:
|
||||
error_msg = str(e2)
|
||||
logger.warning(f"Anthropic API error (attempt {attempt+1}/{API_RETRY_TIMES}): {error_msg}")
|
||||
|
|
@ -662,20 +595,13 @@ class AnthropicAgent:
|
|||
raise
|
||||
response_params = _response_to_params(response)
|
||||
logger.info(f"Received response params: {response_params}")
|
||||
|
||||
# Update raw response string for retry case (will be used in next loop iteration)
|
||||
raw_response_str = self._extract_raw_response_string(response)
|
||||
|
||||
self.messages.append({
|
||||
"role": "assistant",
|
||||
"content": response_params
|
||||
})
|
||||
if parse_retry == max_parse_retry - 1:
|
||||
logger.error(f"parse_actions_from_tool_call parsing failed 3 times consecutively, terminating: {e}")
|
||||
actions = [{
|
||||
"action_type": "FAIL",
|
||||
"raw_response": f"Failed to parse actions from tool call after {max_parse_retry} attempts: {e}"
|
||||
}]
|
||||
logger.error(f"连续3次parse_actions_from_tool_call解析失败,终止: {e}")
|
||||
actions = ["FAIL"]
|
||||
return reasonings, actions
|
||||
def reset(self, _logger = None, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from datetime import datetime
|
|||
from .tools import ToolResult
|
||||
|
||||
|
||||
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
|
||||
COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
|
||||
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
|
||||
|
||||
|
||||
|
|
@ -47,25 +47,12 @@ PROVIDER_TO_DEFAULT_MODEL_NAME: dict[(APIProvider, str), str] = {
|
|||
(APIProvider.ANTHROPIC, "claude-4-opus-20250514"): "claude-4-opus-20250514",
|
||||
(APIProvider.BEDROCK, "claude-4-opus-20250514"): "us.anthropic.claude-opus-4-20250514-v1:0",
|
||||
(APIProvider.VERTEX, "claude-4-opus-20250514"): "claude-4-opus-v1@20250514",
|
||||
# Add mapping for the alternative model name format
|
||||
(APIProvider.ANTHROPIC, "claude-opus-4-20250514"): "claude-opus-4-20250514",
|
||||
(APIProvider.ANTHROPIC, "claude-opus-4-1-20250805"): "claude-opus-4-1-20250805",
|
||||
(APIProvider.ANTHROPIC, "claude-4-sonnet-20250514"): "claude-4-sonnet-20250514",
|
||||
(APIProvider.ANTHROPIC, "claude-sonnet-4-20250514"): "claude-sonnet-4-20250514",
|
||||
(APIProvider.BEDROCK, "claude-4-sonnet-20250514"): "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
(APIProvider.VERTEX, "claude-4-sonnet-20250514"): "claude-sonnet-4-v1@20250514",
|
||||
}
|
||||
|
||||
|
||||
def get_model_name(provider: APIProvider, model_name: str) -> str:
|
||||
"""
|
||||
Get the actual model name to use for API calls.
|
||||
|
||||
Simply returns the model name as-is for direct API usage.
|
||||
"""
|
||||
return model_name
|
||||
|
||||
|
||||
# This system prompt is optimized for the Docker environment in this repository and
|
||||
# specific tool combinations enabled.
|
||||
# We encourage modifying this system prompt to ensure the model has context for the
|
||||
|
|
@ -80,15 +67,8 @@ SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
|
|||
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* DO NOT ask users for clarification during task execution. DO NOT stop to request more information from users. Always take action using available tools.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* TASK FEASIBILITY: You can declare a task infeasible at any point during execution - whether at the beginning after taking a screenshot, or later after attempting some actions and discovering barriers. Carefully evaluate whether the task is feasible given the current system state, available applications, and task requirements. If you determine that a task cannot be completed due to:
|
||||
- Missing required applications or dependencies that cannot be installed
|
||||
- Insufficient permissions or system limitations
|
||||
- Contradictory or impossible requirements
|
||||
- Any other fundamental barriers that make completion impossible
|
||||
Then you MUST output exactly "[INFEASIBLE]" (including the square brackets) anywhere in your response to trigger the fail action. The system will automatically detect this pattern and terminate the task appropriately.
|
||||
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
|
||||
* Home directory of this Ubuntu system is '/home/user'.
|
||||
* If you need a password for sudo, the password of the computer is 'osworld-public-evaluation'.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<IMPORTANT>
|
||||
|
|
@ -102,7 +82,6 @@ SYSTEM_PROMPT_WINDOWS = f"""<SYSTEM_CAPABILITY>
|
|||
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
|
||||
* Home directory of this Windows system is 'C:\\Users\\user'.
|
||||
* When you want to open some applications on Windows, please use Double Click on it instead of clicking once.
|
||||
* If you need a password for sudo, The password of the computer is 'osworld-public-evaluation'.
|
||||
</SYSTEM_CAPABILITY>"""
|
||||
|
||||
|
||||
|
|
@ -175,30 +154,21 @@ def _inject_prompt_caching(
|
|||
one cache breakpoint is left for tools/system prompt, to be shared across sessions
|
||||
"""
|
||||
|
||||
breakpoints_remaining = 2 # Use full budget for recent messages
|
||||
messages_processed = 0
|
||||
|
||||
breakpoints_remaining = 3
|
||||
for message in reversed(messages):
|
||||
if message["role"] == "user" and isinstance(
|
||||
content := message["content"], list
|
||||
):
|
||||
messages_processed += 1
|
||||
# Check if this message would fit within the remaining budget
|
||||
if breakpoints_remaining >= len(content):
|
||||
# We have enough budget, spend it and add cache_control
|
||||
breakpoints_remaining -= len(content)
|
||||
if breakpoints_remaining:
|
||||
breakpoints_remaining -= 1
|
||||
# Use type ignore to bypass TypedDict check until SDK types are updated
|
||||
content[-1]["cache_control"] = BetaCacheControlEphemeralParam( # type: ignore
|
||||
{"type": "ephemeral"}
|
||||
)
|
||||
else:
|
||||
# Check if this is the first message (contains image + text with task description)
|
||||
is_first_message = messages_processed == len([msg for msg in messages if msg["role"] == "user"])
|
||||
|
||||
if not is_first_message:
|
||||
# Not enough budget, remove any existing cache_control from this message
|
||||
content[-1].pop("cache_control", None)
|
||||
# Continue to clean up older messages that might have cache_control from previous turns
|
||||
content[-1].pop("cache_control", None)
|
||||
# we'll only every have one extra turn per loop
|
||||
break
|
||||
|
||||
|
||||
def _maybe_filter_to_n_most_recent_images(
|
||||
|
|
@ -250,105 +220,6 @@ def _maybe_filter_to_n_most_recent_images(
|
|||
tool_result["content"] = new_content
|
||||
|
||||
|
||||
def validate_model_support(model_name: str, api_key: str = None, temperature: float = None, top_p: float = None, no_thinking: bool = False, use_isp: bool = False) -> bool:
    """
    Validate model support with the same API call pattern as the main agent.

    Sends a tiny computer-use request (same betas, tools, thinking and
    sampling configuration as AnthropicAgent) and reports whether it succeeds.

    Args:
        model_name: The model name to validate
        api_key: Optional API key, defaults to ANTHROPIC_API_KEY env var
        temperature: Optional temperature parameter for testing
        top_p: Optional top_p parameter for testing
        no_thinking: Disable thinking mode (matches AnthropicAgent)
        use_isp: Use interleaved scratchpad mode (matches AnthropicAgent)

    Returns:
        True if model is supported and API call succeeds, False otherwise
    """
    print(f"🔍 Validating model support: {model_name}")

    try:
        from anthropic import Anthropic
        import os
        import time

        # Same client setup as the main agent. Note the SDK itself retries each
        # request up to 4 times, on top of the manual 5-attempt loop below.
        client = Anthropic(
            api_key=api_key or os.environ.get("ANTHROPIC_API_KEY"),
            max_retries=4,
        ).with_options(default_headers={"anthropic-beta": COMPUTER_USE_BETA_FLAG})

        # Same message format as main agent - always use structured format with cache_control
        messages = [{"role": "user", "content": [{"type": "text", "text": "Respond with 'OK'", "cache_control": {"type": "ephemeral"}}]}]

        # Same betas configuration as main agent
        betas = [COMPUTER_USE_BETA_FLAG]
        if use_isp:
            betas.append("interleaved-thinking-2025-05-14")

        system = [{"type": "text", "text": "You are Claude. Respond with 'OK'."}]

        # Same tools configuration as main agent - use modern computer tool for all models
        tools = [{"name": "computer", "type": "computer_20250124",
                  "display_width_px": 1280, "display_height_px": 720, "display_number": 1}]

        # Same thinking configuration as main agent
        max_tokens = 50  # Base validation max_tokens
        if no_thinking:
            extra_body = {}
            actual_max_tokens = max_tokens
        else:
            budget_tokens = 2048
            # Same logic as main agent: max_tokens must exceed the thinking
            # budget, so bump it when the base value is too small.
            if max_tokens <= budget_tokens:
                actual_max_tokens = budget_tokens + 500
            else:
                actual_max_tokens = max_tokens
            extra_body = {
                "thinking": {"type": "enabled", "budget_tokens": budget_tokens}
            }

        # Sampling parameters (same logic as main agent)
        sampling_params = {}
        if temperature is not None:
            sampling_params['temperature'] = temperature
        if top_p is not None:
            sampling_params['top_p'] = top_p

        # Retry logic with 5 attempts, 5 second delays
        for attempt in range(5):
            try:
                # Same API call pattern as main agent
                client.beta.messages.create(
                    max_tokens=actual_max_tokens,
                    messages=messages,
                    model=get_model_name(APIProvider.ANTHROPIC, model_name),
                    system=system,
                    tools=tools,
                    betas=betas,
                    extra_body=extra_body,
                    **sampling_params
                )

                print(f"✅ Model {model_name} validated successfully")
                return True
            except Exception as e:
                if attempt < 4:  # Don't print error on final attempt
                    print(f"🔄 Validation attempt {attempt + 1}/5 failed: {e}")
                    print("⏳ Retrying in 5 seconds...")
                    time.sleep(5)
                else:
                    print(f"❌ All validation attempts failed. Final error: {e}")

        return False

    except ValueError:
        # Presumably raised by get_model_name for an unrecognized model name --
        # treated as "not supported" without extra logging. TODO confirm.
        return False
    except Exception as e:
        print(f"❌ API validation setup failed: {e}")
        return False
|
||||
|
||||
|
||||
def _response_to_params(
|
||||
response: BetaMessage,
|
||||
) -> list[BetaContentBlockParam]:
|
||||
|
|
|
|||
|
|
@ -1,7 +0,0 @@
|
|||
"""
|
||||
AutoGLM agent implementation
|
||||
"""
|
||||
|
||||
from .main import AutoGLMAgent
|
||||
|
||||
__all__ = ["AutoGLMAgent"]
|
||||
|
|
@ -1,241 +0,0 @@
|
|||
import logging
|
||||
import re
|
||||
from base64 import b64encode
|
||||
from typing import Dict, List
|
||||
|
||||
from .prompt.accessibility_tree_handle import linearize_accessibility_tree, trim_accessibility_tree
|
||||
from .prompt.grounding_agent import GroundingAgent as Agent
|
||||
from .tools.package.google_chrome import BrowserTools
|
||||
from .prompt.procedural_memory import Prompt
|
||||
|
||||
# Module-level logger; AutoGLMAgent.reset() may rebind this to a run-specific logger.
logger = logging.getLogger("desktopenv.agent")

# Observation types that consist of text only (no screenshot payload required).
pure_text_settings = ["a11y_tree"]
|
||||
|
||||
|
||||
def parse_code_from_string(input_string):
    """Extract executable snippets from a model response.

    A bare terminal keyword (WAIT / DONE / FAIL) is returned as-is. Otherwise
    every triple-backtick fence (with an optional language tag) is captured;
    a fence whose last line is a terminal keyword contributes both the code
    above it and the keyword as separate entries.
    """
    stripped = input_string.strip()
    if stripped in ("WAIT", "DONE", "FAIL"):
        return [stripped]

    # Capture the payload of each ```...``` fence; DOTALL lets a fence span lines.
    fenced = re.findall(r"```(?:\w+\s+)?(.*?)```", input_string, re.DOTALL)

    terminal = ["WAIT", "DONE", "FAIL"]  # fixme: updates this part when we have more commands
    snippets = []
    for block in fenced:
        block = block.strip()
        lines = block.split("\n")
        if block in terminal:
            snippets.append(block)
        elif lines[-1] in terminal:
            # Code followed by a terminal keyword on its own final line.
            if len(lines) > 1:
                snippets.append("\n".join(lines[:-1]))
            snippets.append(lines[-1])
        else:
            snippets.append(block)

    return snippets
|
||||
|
||||
|
||||
class AutoGLMAgent:
    """A11y-tree-driven agent: builds prompts from observations, calls an
    injected generation function, and grounds the reply into executable actions."""

    def __init__(
        self,
        action_space="autoglm_computer_use",
        observation_type="a11y_tree",
        max_trajectory_length=3,
        a11y_tree_max_items=300,
        with_image: bool = False,
        client_password="password",
        gen_func=None,
        tool_in_sys_msg: bool = True,
    ):
        """Configure the agent.

        Args:
            action_space: only "autoglm_computer_use" is supported (asserted).
            observation_type: only "a11y_tree" is supported (asserted).
            max_trajectory_length: stored; not read elsewhere in this class.
            a11y_tree_max_items: stored; prepare() currently trims with a
                hard-coded 300 instead of this value.
            with_image: attach the screenshot as a data-URL image when present.
            client_password: forwarded to prompt construction (sudo password).
            gen_func: callable(messages) -> response text; required by predict().
            tool_in_sys_msg: put the tool/function definitions in the system
                message (True) or in the per-turn user message (False).
        """
        self.action_space = action_space
        self.observation_type = observation_type
        assert action_space in ["autoglm_computer_use"], "Invalid action space"
        assert observation_type in ["a11y_tree"], "Invalid observation type"
        self.max_trajectory_length = max_trajectory_length
        self.a11y_tree_max_items = a11y_tree_max_items
        self.with_image = with_image
        self.client_password = client_password
        self.gen_func = gen_func
        self.tool_in_sys_msg = tool_in_sys_msg

        # Maps a normalized app name to the tool class used to ground actions.
        self.tool_list = {
            "libreoffice_calc": "CalcTools",
            "libreoffice_impress": "ImpressTools",
            "libreoffice_writer": "WriterTools",
            "code": "CodeTools",
            "vlc": "VLCTools",
            "google_chrome": "BrowserTools",
        }
        # Per-turn records: instruction, response, action, exe_result, plus the obs.
        self.contents = []

    @property
    def turn_number(self):
        # Number of completed turns recorded so far.
        return len(self.contents)

    def prepare(self, instruction: str, obs: Dict, history: List, last_result: str = "") -> List:
        """
        Predict the next action(s) based on the current observation.

        Builds the full chat message list: system prompt (+ optional tool
        definitions), prior history, and a user message summarizing apps,
        a11y tree, app info and the previous action result.
        """
        # Backfill the previous turn's execution result if the env supplied it.
        if "exe_result" in obs and not last_result:
            last_result = obs["exe_result"]
        if self.contents:
            self.contents[-1]["exe_result"] = last_result

        cur_app = obs["cur_app"]
        logger.info(f"current app is {cur_app}")

        # Normalize the app name to a tool_list key (e.g. "google-chrome" -> "google_chrome").
        if cur_app:
            tool_name = cur_app.strip().lower().replace("-", "_")
            tool_name = tool_name if tool_name in self.tool_list.keys() else None
        else:
            tool_name = None

        setup_prompt, func_def_prompt, note_prompt = Prompt.construct_procedural_memory(
            Agent, app_name=tool_name, client_password=self.client_password
        )
        # Tool definitions go either in the system message or in the user turn below.
        if self.tool_in_sys_msg:
            system_message = setup_prompt + "\n\n" + func_def_prompt + "\n\n" + note_prompt
        else:
            system_message = setup_prompt + "\n\n" + note_prompt
        system_message += "\n\n**IMPORTANT** You are asked to complete the following task: {}".format(instruction)

        messages = [
            {
                "role": "system",
                "content": system_message,
            }
        ]
        messages.extend(history)

        # Tabulate open windows: one "id name title" row per app.
        if obs["apps"]:
            app_str = "Window ID App Name Title\n"
            for window_id, app in obs["apps"].items():
                app_str += f"{window_id} {app['app_name']} {app['title']}\n"
        else:
            app_str = "None"

        # Truncate long payloads to keep the prompt bounded.
        last_result = last_result.strip() if last_result else "None"
        last_result = last_result[:2000] + "..." if len(last_result) > 2000 else last_result

        tree = linearize_accessibility_tree(obs["accessibility_tree"], "Ubuntu")
        # NOTE(review): trims with a literal 300 rather than self.a11y_tree_max_items.
        tree = trim_accessibility_tree(tree, 300)

        app_info = obs["app_info"].strip() if obs["app_info"] else "None"
        app_info = app_info[:5000] + "..." if len(app_info) > 5000 else app_info

        prompt = "* Apps: {}\n\n* Current App: {}\n\n* A11y Tree: {}\n\n* App Info: {}\n\n* Previous Action Result: {}".format(
            app_str.strip(),
            obs["cur_window_id"].strip() if obs["cur_window_id"] in app_str else "None",
            tree.strip(),
            app_info,
            last_result if last_result else "None",
        ) + (
            "\n\n" + func_def_prompt if not self.tool_in_sys_msg else ""
        )

        content = [{"type": "text", "text": prompt}]
        # Optionally attach the screenshot as an OpenAI-style image_url part.
        if self.with_image and obs.get('screenshot'):
            content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{b64encode(obs['screenshot']).decode('utf-8')}",
                        "detail": "high",
                    },
                }
            )

        messages.append({"role": "user", "content": content})

        return messages

    def execute(self, response, obs):
        """Ground the model's textual response into executable action strings.

        Returns a (possibly empty) list; empty means the response was unparseable.
        """
        try:
            actions = parse_code_from_string(response)
            action = actions[0]
            logger.info(f"The pesudo action is {action}")

            # NOTE(review): eval() runs model-produced text directly; this
            # trusts the model's output — consider a safer dispatch if this
            # ever handles untrusted responses.
            if "Agent." in action:
                actions = [
                    eval(action),
                ]
            elif "BrowserTools." in action:  # TODO: special check for BrowserTools
                actions = [
                    eval(action),
                ]
            else:
                actions = Agent.tool_commands(action, obs["cur_app"].strip().replace("-", "_").lower())
            logger.info(f"The grounded action is {actions[0]}")
        except Exception as e:
            print("Failed to parse action from response", e)
            actions = []

        return actions

    def format_history(self, max_turns=30):
        """Replay recorded turns as (user env-summary, assistant response) pairs,
        truncated per-message and capped at the most recent max_turns turns."""
        history = []
        for ix in range(self.turn_number):
            if ix == 0:
                env_input = "**Environment State (Omitted)**"
            else:
                env_input = (
                    f"**Environment State (Omitted)**\nPrevious Action Result: {self.contents[ix - 1]['exe_result']}"
                )

            env_input = env_input[:2000] + "..." if len(env_input) > 2000 else env_input
            response = (
                self.contents[ix]["response"][:1500] + "..."
                if len(self.contents[ix]["response"]) > 1500
                else self.contents[ix]["response"]
            )
            history.append({"role": "user", "content": [{"type": "text", "text": env_input}]})
            history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})

        # Two messages per turn, so keep the last max_turns * 2 entries.
        return history[-max_turns * 2:]

    def predict(self, instruction: str, obs: Dict) -> List:
        """Run one full turn: build messages, query gen_func, ground the reply,
        record the turn, and return (response_text, actions)."""
        history = self.format_history()
        messages = self.prepare(instruction, obs, history)

        assert self.gen_func is not None, "gen_func is not set"
        try:
            response = self.gen_func(messages)
        except Exception as e:
            logger.error("Failed to call gen_func, Error: " + str(e))
            response = ""

        logger.info("RESPONSE: %s", response)

        actions = self.execute(response, obs)

        # Record the turn; exe_result is filled in on the next prepare() call.
        self.contents.append(
            {
                "instruction": instruction,
                "index": len(self.contents),
                "response": response,
                "action": "Parse error" if not actions else actions[0],
                "exe_result": "Invalid action" if not actions else "",
                **obs,
            }
        )
        return response, actions

    def reset(self, _logger=None):
        """Clear recorded turns and optionally rebind the module logger."""
        global logger
        logger = _logger if _logger is not None else logging.getLogger("desktopenv.aguvis_agent")

        self.contents = []
|
||||
|
|
@ -1,329 +0,0 @@
|
|||
import io
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import List, Tuple
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from .deduplicate_node import filter_similar_nodes
|
||||
|
||||
# XML namespace URIs used by the accessibility (AT-SPI-style) dumps on each
# platform; element attributes are keyed as "{namespace}attr".
# NOTE(review): attributes_ns_ubuntu points at the *windows* host, unlike the
# other Ubuntu namespaces — looks like a copy/paste slip, but the dump producer
# must agree with these strings, so confirm before changing.
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
|
||||
|
||||
|
||||
def find_leaf_nodes(xlm_file_str):
    """Return every leaf element of the given XML string, in document order.

    An empty or falsy input yields an empty list.
    """
    if not xlm_file_str:
        return []

    root = ET.fromstring(xlm_file_str)

    # Iterative depth-first traversal; pushing children reversed keeps the
    # leaves in the same left-to-right order a recursive walk would produce.
    leaves = []
    stack = [root]
    while stack:
        element = stack.pop()
        children = list(element)
        if children:
            stack.extend(reversed(children))
        else:
            leaves.append(element)
    return leaves
|
||||
|
||||
|
||||
def judge_node(node: ET, platform="Ubuntu", check_image=False) -> bool:
    """Decide whether an accessibility-tree element is worth keeping.

    A node is kept when all of the following hold:
      (a) its tag is one of the interactive/readable kinds listed below;
      (b) it is visible (on Ubuntu it must also be "showing");
      (c) it carries a name, text, or — when check_image is set — an image flag;
      (d) its on-screen coordinates are non-negative and its size is positive.

    Args:
        node: element from the platform accessibility dump.
        platform: "Ubuntu" or "Windows"; selects the namespace URIs.
        check_image: also accept nodes whose only content is image="true".

    Returns:
        True when the node passes every filter above.

    Raises:
        ValueError: for an unknown platform, or when a geometry attribute is
            not a valid Python literal.
    """
    # Safe replacement for eval(): the geometry attributes are "(x, y)" tuple
    # literals, which literal_eval parses without executing arbitrary code.
    from ast import literal_eval

    if platform == "Ubuntu":
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
    elif platform == "Windows":
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'Ubuntu' or 'Windows'")

    # (a) tag filter: documents, common widgets, and a closed set of extras.
    keeps: bool = (
        node.tag.startswith("document")
        or node.tag.endswith("item")
        or node.tag.endswith("button")
        or node.tag.endswith("heading")
        or node.tag.endswith("label")
        or node.tag.endswith("scrollbar")
        or node.tag.endswith("searchbox")
        or node.tag.endswith("textbox")
        or node.tag.endswith("link")
        or node.tag.endswith("tabelement")
        or node.tag.endswith("textfield")
        or node.tag.endswith("textarea")
        or node.tag.endswith("menu")
        or node.tag
        in {
            "alert",
            "canvas",
            "check-box",
            "combo-box",
            "entry",
            "icon",
            "image",
            "paragraph",
            "scroll-bar",
            "section",
            "slider",
            "static",
            "table-cell",
            "terminal",
            "text",
            "netuiribbontab",
            "start",
            "trayclockwclass",
            "traydummysearchcontrol",
            "uiimage",
            "uiproperty",
            "uiribboncommandbar",
        }
    )
    # (b) visibility and (c) content. The inner expression deliberately relies
    # on `and` binding tighter than `or` and is kept exactly as written.
    keeps = (
        keeps
        and (
            platform == "Ubuntu"
            and node.get("{{{:}}}showing".format(_state_ns), "false") == "true"
            and node.get("{{{:}}}visible".format(_state_ns), "false") == "true"
            or platform == "Windows"
            and node.get("{{{:}}}visible".format(_state_ns), "false") == "true"
        )
        and (
            node.get("name", "") != ""
            or node.text is not None
            and len(node.text) > 0
            or check_image
            and node.get("image", "false") == "true"
        )
    )

    # (d) geometry sanity: on-screen position and a non-degenerate size.
    coordinates: Tuple[int, int] = literal_eval(node.get("{{{:}}}screencoord".format(_component_ns), "(-1, -1)"))
    sizes: Tuple[int, int] = literal_eval(node.get("{{{:}}}size".format(_component_ns), "(-1, -1)"))
    keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0
    return keeps
|
||||
|
||||
|
||||
def filter_nodes(root: ET, platform="Ubuntu", check_image=False):
    """Collect, in document order, every element under *root* that judge_node keeps."""
    return [element for element in root.iter() if judge_node(element, platform, check_image)]
|
||||
|
||||
|
||||
def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0, platform="Ubuntu"):
    """Draw indexed red boxes for each node onto the screenshot.

    Args:
        nodes: accessibility elements carrying screencoord/size attributes.
        image_file_content: raw screenshot bytes (any PIL-readable format).
        down_sampling_ratio: scale factor applied to both image and boxes.
        platform: "Ubuntu" or "Windows"; selects the namespace URIs.

    Returns:
        (marks, drew_nodes, text_info, image_bytes) where marks are
        [x, y, w, h] lists in *original* (un-downsampled) coordinates,
        drew_nodes are the nodes actually drawn, text_info is a TSV listing
        of index/tag/name/text, and image_bytes is the annotated PNG.
    """

    if platform == "Ubuntu":
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
        _value_ns = value_ns_ubuntu
    elif platform == "Windows":
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
        _value_ns = value_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'Ubuntu' or 'Windows'")

    # Load the screenshot image
    image_stream = io.BytesIO(image_file_content)
    image = Image.open(image_stream)
    if float(down_sampling_ratio) != 1.0:
        image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio)))
    draw = ImageDraw.Draw(image)
    marks = []
    drew_nodes = []
    text_informations: List[str] = ["index\ttag\tname\ttext"]

    try:
        # Adjust the path to the font file you have or use a default one
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        # Fallback to a basic font if the specified font can't be loaded
        font = ImageFont.load_default()

    # 1-based label index; only incremented for nodes actually drawn.
    index = 1

    # Loop over all the visible nodes and draw their bounding boxes
    for _node in nodes:
        coords_str = _node.attrib.get("{{{:}}}screencoord".format(_component_ns))
        size_str = _node.attrib.get("{{{:}}}size".format(_component_ns))

        if coords_str and size_str:
            try:
                # Parse the coordinates and size from the strings
                coords = tuple(map(int, coords_str.strip("()").split(", ")))
                size = tuple(map(int, size_str.strip("()").split(", ")))

                import copy

                # Keep the pre-downsampling geometry for the returned marks.
                original_coords = copy.deepcopy(coords)
                original_size = copy.deepcopy(size)

                if float(down_sampling_ratio) != 1.0:
                    # Downsample the coordinates and size
                    coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
                    size = tuple(int(s * down_sampling_ratio) for s in size)

                # Check for negative sizes
                if size[0] <= 0 or size[1] <= 0:
                    raise ValueError(f"Size must be positive, got: {size}")

                # Calculate the bottom-right corner of the bounding box
                bottom_right = (coords[0] + size[0], coords[1] + size[1])

                # Check that bottom_right > coords (x1 >= x0, y1 >= y0)
                if bottom_right[0] < coords[0] or bottom_right[1] < coords[1]:
                    raise ValueError(f"Invalid coordinates or size, coords: {coords}, size: {size}")

                # Check if the area only contains one color
                cropped_image = image.crop((*coords, *bottom_right))
                if len(set(list(cropped_image.getdata()))) == 1:
                    continue

                # Draw rectangle on image
                draw.rectangle([coords, bottom_right], outline="red", width=1)

                # Draw index number at the bottom left of the bounding box with black background
                text_position = (coords[0], bottom_right[1])  # Adjust Y to be above the bottom right
                text_bbox: Tuple[int, int, int, int] = draw.textbbox(text_position, str(index), font=font, anchor="lb")
                # offset: int = bottom_right[1]-text_bbox[3]
                # text_bbox = (text_bbox[0], text_bbox[1]+offset, text_bbox[2], text_bbox[3]+offset)

                # draw.rectangle([text_position, (text_position[0] + 25, text_position[1] + 18)], fill='black')
                draw.rectangle(text_bbox, fill="black")
                draw.text(text_position, str(index), font=font, anchor="lb", fill="white")

                # each mark is an x, y, w, h tuple
                marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
                drew_nodes.append(_node)

                # Build the TSV text column, CSV-style doubling embedded quotes.
                if _node.text:
                    node_text = _node.text if '"' not in _node.text else '"{:}"'.format(_node.text.replace('"', '""'))
                elif _node.get("{{{:}}}class".format(class_ns_windows), "").endswith("EditWrapper") and _node.get(
                    "{{{:}}}value".format(_value_ns)
                ):
                    node_text = _node.get("{{{:}}}value".format(_value_ns), "")
                    node_text = node_text if '"' not in node_text else '"{:}"'.format(node_text.replace('"', '""'))
                else:
                    node_text = '""'
                text_information: str = "{:d}\t{:}\t{:}\t{:}".format(index, _node.tag, _node.get("name", ""), node_text)
                text_informations.append(text_information)

                index += 1

            except ValueError:
                # Degenerate geometry: skip this node and keep drawing the rest.
                pass

    output_image_stream = io.BytesIO()
    image.save(output_image_stream, format="PNG")
    image_content = output_image_stream.getvalue()

    return marks, drew_nodes, "\n".join(text_informations), image_content
|
||||
|
||||
|
||||
def print_nodes_with_indent(nodes, indent=0):
    """Recursively print each element's tag and attributes, two extra
    leading spaces per nesting level (debug helper)."""
    for child in nodes:
        print(" " * indent, child.tag, child.attrib)
        print_nodes_with_indent(child, indent + 2)
|
||||
|
||||
|
||||
def find_active_applications(tree, state_ns):
    """Names of applications owning an active frame, plus always-kept shells.

    One entry is produced per active frame (an app with several active frames
    appears several times). When nothing is active, a default shell-only list
    is returned instead.
    """
    active_key = "{{{:}}}active".format(state_ns)
    active_apps = [
        application.attrib.get("name")
        for application in list(tree.getroot())
        for frame in application
        if frame.attrib.get(active_key, "false") == "true"
    ]
    if active_apps:
        return active_apps + ["gnome-shell"]
    return ["gjs", "gnome-shell"]
|
||||
|
||||
|
||||
def linearize_accessibility_tree(accessibility_tree, platform="Ubuntu"):
    """Flatten an accessibility-tree XML dump into a TSV table string.

    Keeps only active applications (plus shells), filters nodes via
    filter_nodes, emits one "tag\ttext\tcenter\tsize" row per node, and
    finally removes near-duplicate rows. Returns "" on any parse failure.
    """
    if platform == "Ubuntu":
        _attributes_ns = attributes_ns_ubuntu
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
        _value_ns = value_ns_ubuntu
    elif platform == "Windows":
        _attributes_ns = attributes_ns_windows
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
        _value_ns = value_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'Ubuntu' or 'Windows'")

    try:
        tree = ET.ElementTree(ET.fromstring(accessibility_tree))
        keep_apps = find_active_applications(tree, _state_ns)

        # Remove inactive applications
        for application in list(tree.getroot()):
            if application.get("name") not in keep_apps:
                tree.getroot().remove(application)

        filtered_nodes = filter_nodes(tree.getroot(), platform, check_image=True)
        linearized_accessibility_tree = ["tag\ttext\tposition (center x & y)\tsize (w & h)"]

        # Linearize the accessibility tree nodes into a table format
        for node in filtered_nodes:
            try:
                # Merge the element text with its "name" attribute: prefer
                # text, fall back to name, combine as "name (text)" when both
                # are present and differ.
                text = node.text if node.text is not None else ""
                text = text.strip()
                name = node.get("name", "").strip()
                if text == "":
                    text = name
                elif name != "" and text != name:
                    text = f"{name} ({text})"

                text = text.replace("\n", "\\n")
                pos = node.get("{{{:}}}screencoord".format(_component_ns), "")
                size = node.get("{{{:}}}size".format(_component_ns), "")

                # NOTE(review): these patterns use an f-string where a raw
                # string was likely intended; behavior is unchanged since the
                # pattern contains no braces.
                x, y = re.match(f"\((\d+), (\d+)\)", pos).groups()
                w, h = re.match(f"\((\d+), (\d+)\)", size).groups()
                x_mid, y_mid = int(x) + int(w) // 2, int(y) + int(h) // 2

                linearized_accessibility_tree.append(
                    "{:}\t{:}\t{:}\t{:}".format(node.tag, text, f"({x_mid}, {y_mid})", size)
                )
            except Exception as e:
                # Malformed geometry on a single node: drop the row, keep going.
                continue

        # Filter out similar nodes
        linearized_accessibility_tree = filter_similar_nodes("\n".join(linearized_accessibility_tree))
    except Exception as e:
        print(f"Error in linearize_accessibility_tree: {e}")
        linearized_accessibility_tree = ""

    return linearized_accessibility_tree
|
||||
|
||||
|
||||
def trim_accessibility_tree(linearized_accessibility_tree, max_items):
    """Cap the linearized tree at *max_items* lines, appending "..." when cut.

    When the (stripped) input already fits, the original string is returned
    completely untouched.
    """
    rows = linearized_accessibility_tree.strip().split("\n")
    if len(rows) <= max_items:
        return linearized_accessibility_tree
    return "\n".join(rows[:max_items]) + "\n..."
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
import re
|
||||
|
||||
|
||||
def parse_line(line):
    """Parse one linearized-tree row such as "label Google Chrome (191, 13) (104, 17)".

    Returns a dict with keys type/text/bbox/center/size/raw, where bbox is
    (x1, y1, x2, y2) derived from the center and size, or None when the row
    does not match the expected layout.
    """
    m = re.match(r"^(\S+)\s+(.+?)\s+\((\d+), (\d+)\)\s+\((\d+), (\d+)\)", line)
    if m is None:
        return None

    node_type, text = m.group(1), m.group(2)
    cx, cy, w, h = (int(g) for g in m.groups()[2:])

    # Top-left corner from the center point and dimensions.
    left = cx - w // 2
    top = cy - h // 2
    return {
        "type": node_type,
        "text": text.strip(),
        "bbox": (left, top, left + w, top + h),
        "center": (cx, cy),
        "size": (w, h),
        "raw": line,
    }
|
||||
|
||||
|
||||
def iou(box1, box2):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes.

    Returns 0 when the union area is zero (both boxes degenerate).
    """
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h

    union = (
        (box1[2] - box1[0]) * (box1[3] - box1[1])
        + (box2[2] - box2[0]) * (box2[3] - box2[1])
        - intersection
    )
    if union == 0:
        return 0
    return intersection / union
|
||||
|
||||
|
||||
def norm_text(s):
    # Normalize text for comparison: lowercase and strip ALL whitespace runs.
    return re.sub(r"\s+", "", s.lower())
|
||||
|
||||
|
||||
def text_similarity(a, b):
    """Crude similarity: 1.0 when the normalized texts match exactly, else 0."""
    return 1.0 if norm_text(a) == norm_text(b) else 0
|
||||
|
||||
|
||||
def filter_similar_nodes(linearized_accessibility_tree):
    """Drop near-duplicate rows from a linearized accessibility tree.

    A later row is removed when it heavily overlaps an earlier row (IoU) and
    carries essentially the same text. Rows that cannot be parsed are always
    kept verbatim. Returns the surviving rows joined by newlines.
    """
    # Thresholds — tune as needed.
    IOU_THRESH = 0.2   # minimum box overlap to treat two rows as the same spot
    TEXT_THRESH = 0.9  # minimum text similarity to treat them as the same text

    rows = [row for row in linearized_accessibility_tree.split("\n") if row.strip()]
    # parse_line returns a dict or None; unparseable rows are tagged invalid.
    parsed = [parse_line(row) or {"raw": row, "invalid": True} for row in rows]

    dropped = [False] * len(parsed)
    kept = []
    for i, current in enumerate(parsed):
        if current.get("invalid"):
            kept.append(current["raw"])
            continue
        if dropped[i]:
            continue
        for j in range(i + 1, len(parsed)):
            candidate = parsed[j]
            if candidate.get("invalid"):
                continue
            if (
                iou(current["bbox"], candidate["bbox"]) > IOU_THRESH
                and text_similarity(current["text"], candidate["text"]) > TEXT_THRESH
            ):
                # Highly similar to an earlier row: drop the later one.
                dropped[j] = True
        if not dropped[i]:
            kept.append(current["raw"])
    return "\n".join(kept)
|
||||
|
||||
|
||||
# Example usage (manual smoke test for the dedup filter).
if __name__ == "__main__":
    linearized_accessibility_tree = "tag\ttext\tposition (center x & y)\tsize (w & h)\nicon\t\t(1853, 1001)\t(64, 64)\nlabel\tHome\t(1853, 1045)\t(40, 17)\nlabel\tActivities\t(49, 13)\t(63, 17)\ntext\tActivities\t(49, 13)\t(63, 17)\nlabel\tApr 17 17∶04\t(995, 13)\t(117, 27)\ntext\tApr 17 17∶04\t(995, 13)\t(87, 18)\nmenu\tSystem\t(1867, 13)\t(106, 27)\npush-button\tGoogle Chrome\t(35, 65)\t(70, 64)\npush-button\tThunderbird Mail\t(35, 133)\t(70, 64)\npush-button\tVisual Studio Code\t(35, 201)\t(70, 64)\npush-button\tVLC media player\t(35, 269)\t(70, 64)\npush-button\tLibreOffice Writer\t(35, 337)\t(70, 64)\npush-button\tLibreOffice Calc\t(35, 405)\t(70, 64)\npush-button\tLibreOffice Impress\t(35, 473)\t(70, 64)\npush-button\tGNU Image Manipulation Program\t(35, 541)\t(70, 64)\npush-button\tFiles\t(35, 609)\t(70, 64)\npush-button\tUbuntu Software\t(35, 677)\t(70, 64)\npush-button\tHelp\t(35, 745)\t(70, 64)\npush-button\tTrash\t(35, 816)\t(70, 64)\ntoggle-button\tShow Applications\t(35, 1045)\t(70, 70)"
    result = filter_similar_nodes(linearized_accessibility_tree)
    print(result)
|
||||
|
|
@ -1,259 +0,0 @@
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
# Shared module logger for the grounding-agent helpers.
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
def agent_action(func):
    """Decorator marking *func* as an action exposed to the agent.

    Only tags the callable with is_agent_action=True; the function itself
    is returned unchanged.
    """
    setattr(func, "is_agent_action", True)
    return func
|
||||
|
||||
|
||||
# pyautogui/wmctrl snippet template; callers substitute "WINDOW_ID" before use.
# NOTE(review): the snippet calls time.sleep but only imports subprocess and
# pyautogui — it presumably runs in an executor where `time` is pre-imported;
# confirm before reusing elsewhere.
switch_window_code = """import subprocess;
import pyautogui;
pyautogui.press('escape');
time.sleep(0.5);
subprocess.run(['wmctrl', '-ia', 'WINDOW_ID'])
subprocess.run(['wmctrl', '-ir', 'WINDOW_ID', '-b', 'add,maximized_vert,maximized_horz'])
print('Switch to WINDOW_ID')"""

# Shell command used to launch each supported application, keyed by friendly name.
launch_app_commands = {
    # Web Browser
    "chrome": "google-chrome --remote-debugging-port=1337",
    # File Manager
    "files": "nautilus",
    # Terminal
    "terminal": 'export DBUS_SESSION_BUS_ADDRESS="unix:path=/run/user/1000/bus" && gnome-terminal',
    # Utilities
    "gedit": "gedit",
    # Office
    "libreoffice writer": "libreoffice --writer",
    "libreoffice calc": "libreoffice --calc",
    "libreoffice impress": "libreoffice --impress",
    # System
    "settings": 'export DBUS_SESSION_BUS_ADDRESS="unix:path=/run/user/1000/bus" && gnome-control-center',
    # Multimedia
    "vlc": "vlc",
    "gimp": "gimp",
    # IDE
    "vs code": "code",
    # Email
    "thunderbird": "thunderbird",
}
|
||||
|
||||
|
||||
class GroundingAgent:
    """Action space exposed to the planning LLM.

    Every `@agent_action` method returns either a pyautogui command string
    (executed remotely), a dict action (e.g. OPEN_APP), or a control token
    ("WAIT"/"DONE"/"FAIL").  The docstrings below are runtime data: they are
    harvested via ``attr.__doc__`` and pasted into the system prompt, so they
    are kept byte-for-byte.
    """

    # Maps a tool module name to the class inside it whose print_result()
    # call is appended after generated tool code (see tool_commands).
    tool_list = {
        "libreoffice_calc": "CalcTools",
        "libreoffice_impress": "ImpressTools",
        "libreoffice_writer": "WriterTools",
        "code": "CodeTools",
        "vlc": "VLCTools",
        "google_chrome": "BrowserTools",
    }

    @classmethod
    def tool_commands(cls, code: str, tool_name: str):
        # Wrap the tool code with a star-import of its module and a trailing
        # print_result() so the tool's output comes back as the observation.
        tool_class = cls.tool_list[tool_name]
        wrapped = f"from {tool_name} import *; {code}; {tool_class}.print_result()"
        return [wrapped]

    @classmethod
    @agent_action
    def click(
        cls,
        coordinates: List,
        num_clicks: int = 1,
        button_type: str = "left",
    ):
        """
        Click on the element.

        Args:
            coordinates (List): [x, y], Coordinates of the element to click on
            num_clicks (int): number of times to click the element
            button_type (str): which mouse button to press can be "left", "middle", or "right"
        """
        x, y = coordinates
        # NOTE(review): maximizing a window may need one extra call
        # (translated from the original TODO comment).
        return (
            f"pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); "
            'print("Click Success")'
        )

    @classmethod
    @agent_action
    def type(
        cls,
        coordinates: Optional[List] = None,
        text: str = "",
        overwrite: bool = False,
        enter: bool = False,
    ):
        """
        Type text into the element.

        Args:
            coordinates (List): [x, y] Coordinates of the element to type into. If not provided, typing will start at the current cursor location.
            text (str): the text to type
            overwrite (bool): Assign it to True if the text should overwrite the existing text, otherwise assign it to False. Using this argument clears all text in an element.
            enter (bool): Assign it to True if the enter key should be pressed after typing the text, otherwise assign it to False.
        """
        fragments = []

        if coordinates is not None:
            # Focus the element first by clicking it.
            cx, cy = coordinates
            fragments.append(f"pyautogui.click({cx}, {cy}); ")

        if overwrite:
            # Select-all + backspace clears any existing text.
            fragments.append("pyautogui.hotkey('ctrl', 'a'); pyautogui.press('backspace'); ")

        fragments.append(f"pyautogui.write({repr(text)}); ")

        if enter:
            fragments.append("pyautogui.press('enter'); ")

        fragments.append("print('Type Success')")
        return "".join(fragments)

    @classmethod
    @agent_action
    def drag_and_drop(cls, drag_from_coordinates: List, drop_on_coordinates: List):
        """
        Drag element1 and drop it on element2.

        Args:
            drag_from_coordinates (List): [x, y] Coordinates of element to drag
            drop_on_coordinates (List): [x, y] Coordinates of element to drop on
        """
        src_x, src_y = drag_from_coordinates
        dst_x, dst_y = drop_on_coordinates

        # TODO: make the drag duration configurable?
        return (
            f"pyautogui.moveTo({src_x}, {src_y}); "
            f"pyautogui.dragTo({dst_x}, {dst_y}, duration=1.); pyautogui.mouseUp(); "
            "print('Drag and Drop Success')"
        )

    @classmethod
    @agent_action
    def scroll(cls, coordinates: List, direction: str):
        """
        Scroll the element in the specified direction.

        Args:
            coordinates (List): [x, y] Coordinates of the element to scroll in
            direction (str): the direction to scroll can be "up" or "down".
        """
        px, py = coordinates
        # Positive scroll amount moves up, negative moves down.
        delta = 100 if direction == "up" else -100
        return f"import pyautogui; pyautogui.moveTo({px}, {py}); pyautogui.scroll({delta}); print('Scroll Success')"

    @classmethod
    @agent_action
    def open_app(cls, app_name: str):
        """
        Open a specified application.

        App List:
        - chrome
        - files
        - terminal
        - gedit
        - libreoffice writer
        - libreoffice calc
        - libreoffice impress
        - vs code
        - vlc
        - gimp
        - settings
        - thunderbird

        Args:
            app_name (str): Name of the application to open
        """
        # Normalize before looking up the launch command.
        app_name = app_name.lower().strip()

        if app_name in launch_app_commands:
            # Structured action: the executor runs the mapped shell command.
            return {
                "action_type": "OPEN_APP",
                "parameters": {"launch_app_command": launch_app_commands[app_name], "app_name": app_name},
            }

        # Unknown app: surface the problem as the step's printed result.
        return f"print(f'{app_name} is not supported or recognized')"

    @classmethod
    @agent_action
    def switch_window(cls, window_id: str):
        """
        Switch to the window with the given window id.

        Args:
            window_id (str): the window id to switch to from the provided list of open windows
        """
        # Template substitution into the module-level switch-window snippet.
        script = switch_window_code.replace("WINDOW_ID", window_id)
        return script

    @classmethod
    @agent_action
    def hotkey(cls, keys: List):
        """
        Press a hotkey combination.

        Args:
            keys (List): the keys to press in combination in a list format (e.g. ['ctrl', 'c'] for copy, ['prtsc'] for screenshot)
        """
        # Quote each key so the generated call reads hotkey('ctrl', 'c').
        quoted = [f"'{key}'" for key in keys]
        joined = ", ".join(quoted)
        # Escape the quotes for embedding inside the printed f-string.
        escaped = joined.replace("'", "\\'")
        return f"import pyautogui; pyautogui.hotkey({joined}); print(f'Press Hotkey: {escaped}')"

    @classmethod
    @agent_action
    def quote(cls, content: str):
        """
        Quoting information from the current page for memory. Only you can see the quoted content.

        Args:
            content (str): text summarized or copied from the page for later operation.
        """
        # Echo the content so it lands in the observation/memory stream.
        return 'print("""' + content + '""")'

    @classmethod
    @agent_action
    def wait(cls):
        """
        Wait for a while.

        """
        return "WAIT"

    @classmethod
    @agent_action
    def exit(cls, success: bool):
        """
        End the current task.

        Args:
            success (bool): True if successfully finish a task, otherwise set it False
        """
        return "DONE" if success else "FAIL"
||||
|
|
@ -1,202 +0,0 @@
|
|||
import inspect
|
||||
import json
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def generate_func(json_data):
    """Render Python stub source for a list of OpenAI-style tool schemas.

    Entries named ``Class.method`` are grouped under ``class`` blocks and
    rendered with a leading ``cls`` parameter; plain names become module-level
    stubs.  Only signatures and docstrings are emitted (no bodies) — the
    output is pasted into an LLM prompt to describe the available tools.

    Args:
        json_data: list of ``{"type": "function", "function": {...}}`` dicts.

    Returns:
        tuple[str, str]: (generated source, name of the last class rendered;
        ``""`` when no class entry was present).
    """
    class_funcs = {}
    no_class_funcs = []
    cls_name = ""

    # Partition entries into "Class.method" ones and free-standing functions.
    for item in json_data:
        if item["type"] != "function":
            continue
        name_parts = item["function"]["name"].split(".")
        if len(name_parts) == 2:
            class_funcs.setdefault(name_parts[0], []).append(item)
        else:
            no_class_funcs.append(item)

    code = ""

    # Class-grouped stubs: one `class X:` header, then one stub per method.
    for class_name, funcs in class_funcs.items():
        code += f"class {class_name}:\n"
        cls_name = class_name  # the last class rendered wins, as before
        for item in funcs:
            func = item["function"]
            code += _render_stub(func, func["name"].split(".")[-1], indent="    ", first_param="cls")
        code += "\n"

    # Free-standing stubs keep their full (possibly dotted) schema name.
    for item in no_class_funcs:
        func = item["function"]
        code += _render_stub(func, func["name"], indent="", first_param=None)

    return code.strip(), cls_name


def _render_stub(func, stub_name, indent, first_param):
    """Render one `def` stub (signature + docstring) for a tool schema.

    Shared by the class-method and free-function paths; the original code
    duplicated this ~80-line body in both branches.

    Args:
        func: the ``"function"`` sub-dict of a tool schema.
        stub_name: identifier to emit after ``def``.
        indent: leading indentation of the ``def`` line ("" or "    ").
        first_param: implicit first parameter ("cls") or None.
    """
    params = func["parameters"]["properties"]
    required = func["parameters"].get("required", [])

    # Required parameters first, then the optional ones (schema order).
    param_list = [first_param] if first_param else []
    param_list += list(required)
    param_list += [p for p in params if p not in required]

    stub = f"{indent}def {stub_name}({', '.join(param_list)}):\n"

    body = indent + "    "
    arg_indent = body + "    "
    doc = f'{body}"""\n{body}{func["description"]}\n\n{body}Args:\n'
    if len(param_list) == (1 if first_param else 0):
        # No real parameters beyond the implicit one.
        doc += f"{arg_indent}None\n"
    else:
        for p in required:
            doc += f"{arg_indent}{p} ({params[p]['type']}): {params[p].get('description', '')}\n"
        for p in params:
            if p not in required:
                doc += f"{arg_indent}{p} ({params[p]['type']}, optional): {params[p].get('description', '')}\n"
    doc += f'{body}"""\n'

    return stub + doc + "\n"
|
||||
|
||||
|
||||
# System-prompt fragments assembled by Prompt.construct_procedural_memory.
# All four constants are runtime strings fed to the LLM; their exact wording
# (including the existing "defination" typo) is left untouched here.

# Role/format preamble: describes the observation contents and the required
# <think> + single-python-code-block output format.
setup_prompt = """You are an agent which follow my instruction and perform desktop computer tasks as instructed.
You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard.
For each step, you will get an observation of the desktop by 1) screenshot; 2) current application name; 3) accessibility tree, which is based on AT-SPI library; 4) application info; 5) last action result.
You should first generate a plan for completing the task, confirm the previous results, reflect on the current status, then generate operations to complete the task in python-style pseudo code using the predefined functions.

Your output should STRICTLY follow the format:
<think>
{**YOUR-PLAN-AND-THINKING**}
</think>
```python
{**ONE-LINE-OF-CODE**}
```"""

# Variant used when an app-specific tool class is available: placeholders
# {tool_class_name}, {app_name}, {class_content}.
func_def_tool_template = """You will be provided access to the following methods to interact with the UI:
1. class Agent, a grounding agent which provides basic action space to interact with desktop.
2. class {tool_class_name}, which provides tools to interact with the current application {app_name}.

Here are the defination of the classes:
```python
{class_content}
```"""

# Variant used when only the generic Agent class is exposed.
func_def_template = """You will be provided access to the following methods to interact with the UI:

```python
{class_content}
```"""

# Trailing rules section; {client_password} is filled in per environment.
note_prompt = """* Note:
- Your code should be wrapped in ```python```, and your plan and thinking should be wrapped in <think></think>.
- Only **ONE-LINE-OF-CODE** at a time.
- Each code block is context independent, and variables from the previous round cannot be used in the next round.
- Do not put anything other than python code in ```python```.
- You **can only use the above methods to interact with the UI**, do not invent new methods.
- Return with `Agent.exit(success=True)` immediately after the task is completed.
- If you think cannot complete the task, **DO NOT keep repeating actions, just return with `Agent.exit(success=False)`.**
- The computer's environment is Linux, e.g., Desktop path is '/home/user/Desktop'
- My computer's password is '{client_password}', feel free to use it when you need sudo rights"""
|
||||
|
||||
|
||||
class Prompt:
    """Assembles the three system-prompt sections from an agent class."""

    @staticmethod
    def construct_procedural_memory(agent_class, app_name=None, client_password="password"):
        """Build (setup_prompt, function-definition prompt, note prompt).

        Harvests every ``@agent_action`` method's signature and docstring from
        *agent_class*; when *app_name* is given, also renders the app's tool
        schema (tools/apis/<app_name>.json) into stub source and uses the
        tool-aware template.

        Args:
            agent_class: class whose agent-action methods are documented.
            app_name: optional application name selecting a tool JSON file.
            client_password: substituted into the note prompt's sudo hint.

        Returns:
            tuple[str, str, str]: (setup_prompt, func_def_prompt, note prompt).
        """
        agent_class_content = "Class Agent:"
        # Collect signature + docstring of every method tagged by the
        # @agent_action decorator (which sets `is_agent_action`).
        for attr_name in dir(agent_class):
            attr = getattr(agent_class, attr_name)
            if callable(attr) and hasattr(attr, "is_agent_action"):
                # Use inspect to get the full function signature
                signature = inspect.signature(attr)
                agent_class_content += f"""
    def {attr_name}{signature}:
        '''{attr.__doc__}'''
    """

        if app_name is not None:
            # Render the app's tool schema into stub source and append it.
            tool_path = os.path.join(current_dir, "tools", "apis", f"{app_name.lower()}.json")
            with open(tool_path, "r") as f:
                json_data = json.load(f)

            tool_class_content, tool_class_name = generate_func(json_data)

            agent_class_content += "\n\n{}".format(tool_class_content)
            # NOTE(review): client_password is passed but the tool template has
            # no matching placeholder; str.format ignores the extra kwarg.
            func_def_prompt = func_def_tool_template.format(
                class_content=agent_class_content.strip(),
                tool_class_name=tool_class_name,
                app_name=app_name,
                client_password=client_password,
            )
        else:
            func_def_prompt = func_def_template.format(class_content=agent_class_content.strip())
        note_prompt_formatted = note_prompt.format(client_password=client_password)

        # procedural_memory = f"{setup_prompt}\n\n{func_def_prompt}\n\n{note_prompt}".strip()
        # return procedural_memory
        return setup_prompt, func_def_prompt, note_prompt_formatted
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: print the three prompt sections built for the VLC toolset.
    from grounding_agent import GroundingAgent

    print(Prompt.construct_procedural_memory(GroundingAgent, "vlc"))
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
from .func import generate_func
|
||||
|
||||
__all__ = ["generate_func"]
|
||||
|
|
@ -1,260 +0,0 @@
|
|||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.launch_vscode",
|
||||
"description": "Launches Visual Studio Code with the specified file path or directory",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The file path or directory to open in VS Code"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"path"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.compare_files",
|
||||
"description": "Compares two files in VSCode",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file1": {
|
||||
"type": "string",
|
||||
"description": "The path to the first file"
|
||||
},
|
||||
"file2": {
|
||||
"type": "string",
|
||||
"description": "The path to the second file"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"file1",
|
||||
"file2"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.add_folder",
|
||||
"description": "Adds a folder to the last active window in VSCode",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder": {
|
||||
"type": "string",
|
||||
"description": "The folder path to add"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"folder"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.goto_file",
|
||||
"description": "Opens a file at a specific line and character position",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The file path to open"
|
||||
},
|
||||
"line": {
|
||||
"type": "integer",
|
||||
"description": "The line number to navigate to",
|
||||
"default": 1
|
||||
},
|
||||
"character": {
|
||||
"type": "integer",
|
||||
"description": "The character position to navigate to",
|
||||
"default": 1
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"file_path"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.perform_merge",
|
||||
"description": "Perform a three-way merge",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path1": {
|
||||
"type": "string",
|
||||
"description": "The path to the first version file"
|
||||
},
|
||||
"path2": {
|
||||
"type": "string",
|
||||
"description": "The path to the second version file"
|
||||
},
|
||||
"base": {
|
||||
"type": "string",
|
||||
"description": "The path to the base version file"
|
||||
},
|
||||
"result": {
|
||||
"type": "string",
|
||||
"description": "The path to save the merged result"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"path1",
|
||||
"path2",
|
||||
"base",
|
||||
"result"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.remove_folder",
|
||||
"description": "Removes a folder from the last active window in VSCode",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"folder": {
|
||||
"type": "string",
|
||||
"description": "The folder path to remove"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"folder"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.install_extension",
|
||||
"description": "Installs an extension or updates it in VSCode",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"extension_id": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the extension"
|
||||
},
|
||||
"pre_release": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to install the pre-release version",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"extension_id"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.uninstall_extension",
|
||||
"description": "Uninstalls an extension from VSCode",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"extension_id": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the extension"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"extension_id"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.list_extensions",
|
||||
"description": "Lists installed extensions in VSCode",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"show_versions": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to show extension versions",
|
||||
"default": false
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"description": "The category to filter extensions by"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.update_extensions",
|
||||
"description": "Updates all installed extensions in VSCode to the latest version",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.disable_extension",
|
||||
"description": "Disables a specific extension for the next instance of VSCode",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"extension_id": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the extension"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"extension_id"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CodeTools.toggle_sync",
|
||||
"description": "Toggles synchronization on or off in VSCode",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "The state to set ('on' or 'off')",
|
||||
"enum": ["on", "off"]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"state"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
def generate_func(json_data):
    """Render Python stub source for a list of OpenAI-style tool schemas.

    Entries named ``Class.method`` are grouped under ``class`` blocks and
    rendered with a leading ``cls`` parameter; plain names become module-level
    stubs.  Only signatures and docstrings are emitted (no bodies).

    Args:
        json_data: list of ``{"type": "function", "function": {...}}`` dicts.

    Returns:
        str: the generated stub source, stripped of surrounding whitespace.
    """
    class_funcs = {}
    no_class_funcs = []

    # Partition entries into "Class.method" ones and free-standing functions.
    for item in json_data:
        if item["type"] != "function":
            continue
        name_parts = item["function"]["name"].split(".")
        if len(name_parts) == 2:
            class_funcs.setdefault(name_parts[0], []).append(item)
        else:
            no_class_funcs.append(item)

    code = ""

    # Class-grouped stubs: one `class X:` header, then one stub per method.
    for class_name, funcs in class_funcs.items():
        code += f"class {class_name}:\n"
        for item in funcs:
            func = item["function"]
            code += _emit_stub(func, func["name"].split(".")[-1], indent="    ", first_param="cls")
        code += "\n"

    # Free-standing stubs keep their full (possibly dotted) schema name.
    for item in no_class_funcs:
        func = item["function"]
        code += _emit_stub(func, func["name"], indent="", first_param=None)

    return code.strip()


def _emit_stub(func, stub_name, indent, first_param):
    """Render one `def` stub (signature + docstring) for a tool schema.

    Shared by the class-method and free-function paths; the original code
    duplicated this ~80-line body in both branches.

    Args:
        func: the ``"function"`` sub-dict of a tool schema.
        stub_name: identifier to emit after ``def``.
        indent: leading indentation of the ``def`` line ("" or "    ").
        first_param: implicit first parameter ("cls") or None.
    """
    params = func["parameters"]["properties"]
    required = func["parameters"].get("required", [])

    # Required parameters first, then the optional ones (schema order).
    param_list = [first_param] if first_param else []
    param_list += list(required)
    param_list += [p for p in params if p not in required]

    stub = f"{indent}def {stub_name}({', '.join(param_list)}):\n"

    body = indent + "    "
    arg_indent = body + "    "
    doc = f'{body}"""\n{body}{func["description"]}\n\n{body}Args:\n'
    if len(param_list) == (1 if first_param else 0):
        # No real parameters beyond the implicit one.
        doc += f"{arg_indent}None\n"
    else:
        for p in required:
            doc += f"{arg_indent}{p} ({params[p]['type']}): {params[p].get('description', '')}\n"
        for p in params:
            if p not in required:
                doc += f"{arg_indent}{p} ({params[p]['type']}, optional): {params[p].get('description', '')}\n"
    doc += f'{body}"""\n'

    return stub + doc + "\n"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import json

    # Smoke test: render stubs for the Calc tool schema.  NOTE(review): the
    # JSON path is relative, so this only works when run from the directory
    # containing libreoffice_calc.json.
    with open("libreoffice_calc.json", "r") as f:
        json_data = json.load(f)
    print(generate_func(json_data))
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.open_profile_settings",
|
||||
"description": "Opens the profile settings page in the browser.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.open_password_settings",
|
||||
"description": "Opens the password/autofill settings page in the browser.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.open_privacy_settings",
|
||||
"description": "Opens the privacy settings page in the browser.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.open_appearance_settings",
|
||||
"description": "Opens the appearance settings page in the browser.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.open_search_engine_settings",
|
||||
"description": "Opens the search engine settings page in the browser.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.bring_back_last_tab",
|
||||
"description": "Restores the last-closed tab in the browser (equivalent to Ctrl+Shift+T).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.print",
|
||||
"description": "Opens the print dialog for the current browser page (equivalent to Ctrl+P).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.delete_browsing_data",
|
||||
"description": "Opens the 'Clear browsing data' dialog in the browser (equivalent to Ctrl+Shift+Del).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.open_extensions",
|
||||
"description": "Opens the extensions management page in the browser.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.bookmark_page",
|
||||
"description": "Bookmarks the current page in the browser (equivalent to Ctrl+D).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "BrowserTools.open_bookmarks",
|
||||
"description": "Opens the bookmarks page in the browser.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -1,634 +0,0 @@
|
|||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.get_workbook_info",
|
||||
"description": "Get workbook information, including file path, file name, sheets and active sheet.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.save",
|
||||
"description": "Save the current workbook to its current location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.get_column_data",
|
||||
"description": "Get all data from the specified column.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the column to read (e.g. 'A', 'B', etc.)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"column_name"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.switch_active_sheet",
|
||||
"description": "Switch to the specified sheet and make it active. Creates new sheet if it doesn't exist.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sheet_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the sheet to switch to or create"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"sheet_name"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.set_column_values",
|
||||
"description": "Set values to the specified column, cannot be used to set formulas.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the column (e.g. 'A', 'B', etc.) to write to"
|
||||
},
|
||||
"data": {
|
||||
"type": "array",
|
||||
"description": "List of values to write to the column"
|
||||
},
|
||||
"start_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the first row to write to, default is 2 (skip the first row)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"column_name",
|
||||
"data"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.highlight_range",
|
||||
"description": "Highlight the specified range with the specified color.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"range_str": {
|
||||
"type": "string",
|
||||
"description": "Range to highlight, in the format of 'A1:B10'"
|
||||
},
|
||||
"color": {
|
||||
"type": "integer",
|
||||
"description": "Color to highlight with, default is 0xFF0000 (red)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"range_str"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.transpose_range",
|
||||
"description": "Transpose the specified range and paste it to the target cell.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source_range": {
|
||||
"type": "string",
|
||||
"description": "Range to transpose, in the format of 'A1:B10'"
|
||||
},
|
||||
"target_cell": {
|
||||
"type": "string",
|
||||
"description": "Target cell to paste the transposed data, in the format of 'A1'"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"source_range",
|
||||
"target_cell"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.export_to_csv",
|
||||
"description": "Export the current document to a CSV file with the same path and name as the original file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.sort_column",
|
||||
"description": "Sort the data in the specified column in ascending or descending order.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the column to sort (e.g. 'A', 'B', etc.)"
|
||||
},
|
||||
"ascending": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to sort in ascending order (default True)"
|
||||
},
|
||||
"start_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the first row to sort, default is 2 (skip the first row)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"column_name"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.set_validation_list",
|
||||
"description": "Set a validation list for the specified column.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the column (e.g. 'A', 'B', etc.) to set the validation list for"
|
||||
},
|
||||
"values": {
|
||||
"type": "array",
|
||||
"description": "The list of values to use for the validation list"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"column_name",
|
||||
"values"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.hide_row_data",
|
||||
"description": "Hide rows that contain the specified value.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "The value to hide rows for, default is 'N/A'"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.reorder_columns",
|
||||
"description": "Reorder the columns in the sheet according to the specified order.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"column_order": {
|
||||
"type": "array",
|
||||
"description": "A list of column names in the desired order (e.g. ['A', 'B', 'C'])"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"column_order"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.create_pivot_table",
|
||||
"description": "Create a pivot table in the active worksheet based on data from the source sheet.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source_sheet": {
|
||||
"type": "string",
|
||||
"description": "Name of the source sheet containing the data"
|
||||
},
|
||||
"table_name": {
|
||||
"type": "string",
|
||||
"description": "Name for the new pivot table"
|
||||
},
|
||||
"row_fields": {
|
||||
"type": "array",
|
||||
"description": "List of fields to use as row labels (e.g. ['A', 'B', 'C'])"
|
||||
},
|
||||
"col_fields": {
|
||||
"type": "array",
|
||||
"description": "List of fields to use as column labels (e.g. ['A', 'B', 'C'])"
|
||||
},
|
||||
"value_fields": {
|
||||
"type": "array",
|
||||
"description": "List of fields to use as values (e.g. ['A', 'B', 'C'])"
|
||||
},
|
||||
"aggregation_function": {
|
||||
"type": "string",
|
||||
"description": "Aggregation function to use (sum, count, average, min, max), default is 'sum'"
|
||||
},
|
||||
"target_cell": {
|
||||
"type": "string",
|
||||
"description": "Target cell for the pivot table, default is 'A1'"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"source_sheet",
|
||||
"table_name",
|
||||
"value_fields"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.merge_cells",
|
||||
"description": "Merge cells in the specified range.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"range_str": {
|
||||
"type": "string",
|
||||
"description": "Range of cells to merge, in format 'A1:B10'"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"range_str"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.set_cell_value",
|
||||
"description": "Set a value to a specific cell in the active worksheet.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cell": {
|
||||
"type": "string",
|
||||
"description": "Cell reference (e.g., 'A1')"
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Value to set in the cell"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"cell",
|
||||
"value"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.format_range",
|
||||
"description": "Apply formatting to the specified range in the active worksheet.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"range_str": {
|
||||
"type": "string",
|
||||
"description": "Range to format, in the format of 'A1:B10'"
|
||||
},
|
||||
"background_color": {
|
||||
"type": "string",
|
||||
"description": "Background color in hex format (e.g., '#0000ff')"
|
||||
},
|
||||
"font_color": {
|
||||
"type": "string",
|
||||
"description": "Font color in hex format (e.g., '#ffffff')"
|
||||
},
|
||||
"bold": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to make the text bold"
|
||||
},
|
||||
"alignment": {
|
||||
"type": "string",
|
||||
"description": "Text alignment (left, center, right)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"range_str"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.create_chart",
|
||||
"description": "Create a chart in the active worksheet based on the specified data range.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"chart_type": {
|
||||
"type": "string",
|
||||
"description": "Type of chart (bar, column, line, pie, scatter, area)"
|
||||
},
|
||||
"data_range": {
|
||||
"type": "string",
|
||||
"description": "Range containing the data for the chart, in the format of 'A1:B10'"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Title for the chart"
|
||||
},
|
||||
"x_axis_title": {
|
||||
"type": "string",
|
||||
"description": "Title for the X axis"
|
||||
},
|
||||
"y_axis_title": {
|
||||
"type": "string",
|
||||
"description": "Title for the Y axis"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"chart_type",
|
||||
"data_range"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.freeze_panes",
|
||||
"description": "Freeze rows and/or columns in the active worksheet.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"rows": {
|
||||
"type": "integer",
|
||||
"description": "Number of rows to freeze from the top"
|
||||
},
|
||||
"columns": {
|
||||
"type": "integer",
|
||||
"description": "Number of columns to freeze from the left"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.rename_sheet",
|
||||
"description": "Rename a worksheet in the workbook.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_name": {
|
||||
"type": "string",
|
||||
"description": "Current name of the worksheet"
|
||||
},
|
||||
"new_name": {
|
||||
"type": "string",
|
||||
"description": "New name for the worksheet"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"old_name",
|
||||
"new_name"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.copy_sheet",
|
||||
"description": "Create a copy of an existing worksheet in the workbook.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source_sheet": {
|
||||
"type": "string",
|
||||
"description": "Name of the worksheet to copy"
|
||||
},
|
||||
"new_sheet_name": {
|
||||
"type": "string",
|
||||
"description": "Name for the new worksheet copy (optional)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"source_sheet"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.reorder_sheets",
|
||||
"description": "Change the order of worksheets in the workbook.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sheet_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the worksheet to move"
|
||||
},
|
||||
"position": {
|
||||
"type": "integer",
|
||||
"description": "New position index (0-based) for the worksheet"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"sheet_name",
|
||||
"position"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.set_chart_legend_position",
|
||||
"description": "Set the position of the legend in a chart.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"position": {
|
||||
"type": "string",
|
||||
"description": "Position of the legend (top, bottom, left, right, none)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"position"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.set_number_format",
|
||||
"description": "Apply a specific number format to a range of cells.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"range_str": {
|
||||
"type": "string",
|
||||
"description": "Range to format, in the format of 'A1:B10'"
|
||||
},
|
||||
"format_type": {
|
||||
"type": "string",
|
||||
"description": "Type of number format (general, number, currency, accounting, date, time, percentage, fraction, scientific, text)"
|
||||
},
|
||||
"decimal_places": {
|
||||
"type": "integer",
|
||||
"description": "Number of decimal places to display (optional)"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"range_str",
|
||||
"format_type"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.adjust_column_width",
|
||||
"description": "Adjust the width of specified columns.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"columns": {
|
||||
"type": "string",
|
||||
"description": "Column range to adjust (e.g., 'A:C')"
|
||||
},
|
||||
"width": {
|
||||
"type": "number",
|
||||
"description": "Width to set (in characters)"
|
||||
},
|
||||
"autofit": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to autofit columns to content"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"columns"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.adjust_row_height",
|
||||
"description": "Adjust the height of specified rows.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"rows": {
|
||||
"type": "string",
|
||||
"description": "Row range to adjust (e.g., '1:10')"
|
||||
},
|
||||
"height": {
|
||||
"type": "number",
|
||||
"description": "Height to set (in points)"
|
||||
},
|
||||
"autofit": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to autofit rows to content"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"rows"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.export_to_pdf",
|
||||
"description": "Export the current document or specified sheets to PDF.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path where to save the PDF file, default is the same path as the original file"
|
||||
},
|
||||
"sheets": {
|
||||
"type": "array",
|
||||
"description": "List of sheet names to include in PDF, default is all sheets"
|
||||
},
|
||||
"open_after_export": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to open the PDF after export, default is False"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "CalcTools.set_zoom_level",
|
||||
"description": "Adjust the zoom level of the current worksheet.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"zoom_percentage": {
|
||||
"type": "integer",
|
||||
"description": "Zoom level as a percentage (e.g., 75 for 75%, 100 for normal size, 150 for zoomed in). Valid range is typically 10-400."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"zoom_percentage"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -1,569 +0,0 @@
|
|||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.save",
|
||||
"description": "Save the current presentation to its current location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.go_to_slide",
|
||||
"description": "Navigates to a specific slide in the presentation based on its index.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to navigate to (1-based indexing)"
|
||||
}
|
||||
},
|
||||
"required": ["slide_index"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.get_slide_count",
|
||||
"description": "Gets the total number of slides in the current presentation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.duplicate_slide",
|
||||
"description": "Creates a duplicate of a specific slide and places it at the end of the presentation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to duplicate (1-based indexing)"
|
||||
}
|
||||
},
|
||||
"required": ["slide_index"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.set_slide_font",
|
||||
"description": "Sets the font style for all text elements in a specific slide.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to modify (1-based indexing)"
|
||||
},
|
||||
"font_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the font to apply (e.g., 'Arial', 'Times New Roman', 'Calibri')"
|
||||
}
|
||||
},
|
||||
"required": ["slide_index", "font_name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.write_text",
|
||||
"description": "writes text to a specific textbox on a slide",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The text content of the note to add"
|
||||
},
|
||||
"page_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to add a note to (1-based indexing)"
|
||||
},
|
||||
"box_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the textbox to modify (0-based indexing)"
|
||||
},
|
||||
"bold": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to make the text bold, default is false"
|
||||
},
|
||||
"italic": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to make the text italic, default is false"
|
||||
},
|
||||
"size": {
|
||||
"type": "integer",
|
||||
"description": "The size of the text. If None, uses the box's current font size."
|
||||
},
|
||||
"append": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to append the text, default is False. If you want to observe some formats(like a bullet at the beginning) or keep the original text, you should set up it."
|
||||
}
|
||||
},
|
||||
"required": ["content", "page_index", "box_index"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.set_style",
|
||||
"description": "Sets the style properties for the specified textbox on a slide.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to modify (1-based indexing)"
|
||||
},
|
||||
"box_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the textbox to modify (0-based indexing)"
|
||||
},
|
||||
"bold": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to make the title text bold"
|
||||
},
|
||||
"italic": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to make the title text italic"
|
||||
},
|
||||
"underline": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to underline the title text"
|
||||
}
|
||||
},
|
||||
"required": ["slide_index", "box_index"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.configure_auto_save",
|
||||
"description": "Enables or disables auto-save functionality for the current document and sets the auto-save interval.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to enable (true) or disable (false) auto-save"
|
||||
},
|
||||
"interval_minutes": {
|
||||
"type": "number",
|
||||
"description": "The interval in minutes between auto-saves (minimum 1 minute)"
|
||||
}
|
||||
},
|
||||
"required": ["enabled", "interval_minutes"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.set_background_color",
|
||||
"description": "Sets the background color for the specified textbox on a slide.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide containing the textbox (1-based indexing)"
|
||||
},
|
||||
"box_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the textbox to modify (0-based indexing)"
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "The color to apply to the textbox (e.g., 'red', 'green', 'blue', 'yellow', or hex color code)"
|
||||
}
|
||||
},
|
||||
"required": ["slide_index", "box_index", "color"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.set_text_color",
|
||||
"description": "Sets the text color for the specified textbox on a slide.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to modify (1-based indexing)"
|
||||
},
|
||||
"box_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the textbox to modify (0-based indexing)"
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "The color to apply to the title text (e.g., 'red', 'green', 'blue', 'black', or hex color code)"
|
||||
}
|
||||
},
|
||||
"required": ["slide_index", "box_index", "color"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.delete_content",
|
||||
"description": "Deletes the specified textbox from a slide.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to modify (1-based indexing)"
|
||||
},
|
||||
"box_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the textbox to modify (0-based indexing)"
|
||||
}
|
||||
},
|
||||
"required": ["slide_index", "box_index"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.set_slide_orientation",
|
||||
"description": "Changes the orientation of slides in the presentation between portrait (upright) and landscape (sideways).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"orientation": {
|
||||
"type": "string",
|
||||
"description": "The desired orientation for the slides",
|
||||
"enum": ["portrait", "landscape"]
|
||||
}
|
||||
},
|
||||
"required": ["orientation"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.position_box",
|
||||
"description": "Positions a textbox or image on a slide at a specific location or predefined position.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide containing the box (1-based indexing)"
|
||||
},
|
||||
"box_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the box to position (0-based indexing)"
|
||||
},
|
||||
"position": {
|
||||
"type": "string",
|
||||
"description": "Predefined position on the slide (left, right, center, top, bottom)",
|
||||
"enum": [
|
||||
"left",
|
||||
"right",
|
||||
"center",
|
||||
"top",
|
||||
"bottom",
|
||||
"top-left",
|
||||
"top-right",
|
||||
"bottom-left",
|
||||
"bottom-right"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["slide_index", "box_index", "position"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.insert_file",
|
||||
"description": "Inserts a video or audio file into the current or specified slide in the presentation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The full path to the file to be inserted"
|
||||
},
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to insert the file into (1-based indexing). If not provided, inserts into the current slide."
|
||||
},
|
||||
"position": {
|
||||
"type": "object",
|
||||
"description": "The position coordinates for the file",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The x-coordinate (horizontal position) as a percentage of slide width"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The y-coordinate (vertical position) as a percentage of slide height"
|
||||
}
|
||||
}
|
||||
},
|
||||
"size": {
|
||||
"type": "object",
|
||||
"description": "The size dimensions for the file",
|
||||
"properties": {
|
||||
"width": {
|
||||
"type": "number",
|
||||
"description": "The width as a percentage of slide width"
|
||||
},
|
||||
"height": {
|
||||
"type": "number",
|
||||
"description": "The height as a percentage of slide height"
|
||||
}
|
||||
}
|
||||
},
|
||||
"autoplay": {
|
||||
"type": "boolean",
|
||||
"description": "Whether the video or audio should automatically play when the slide is shown"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.set_slide_background",
|
||||
"description": "Sets the background color or image for a specific slide or all slides.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to modify (1-based indexing). If not provided, applies to all slides."
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "The background color to apply (e.g., 'red', 'green', 'blue', or hex color code)"
|
||||
},
|
||||
"image_path": {
|
||||
"type": "string",
|
||||
"description": "Path to an image file to use as background. If provided, overrides color."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.save_as",
|
||||
"description": "Saves the current document to a specified location with a given filename.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The full path where the file should be saved, including the filename and extension"
|
||||
},
|
||||
"overwrite": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to overwrite the file if it already exists (default: false)"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.insert_image",
|
||||
"description": "Inserts an image to a specific slide in the presentation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to add the image to (1-based indexing)"
|
||||
},
|
||||
"image_path": {
|
||||
"type": "string",
|
||||
"description": "The full path to the image file to be added"
|
||||
},
|
||||
"width": {
|
||||
"type": "number",
|
||||
"description": "The width of the image in centimeters"
|
||||
},
|
||||
"height": {
|
||||
"type": "number",
|
||||
"description": "The height of the image in centimeters"
|
||||
},
|
||||
"position": {
|
||||
"type": "object",
|
||||
"description": "The position coordinates for the image",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The x-coordinate (horizontal position) as a percentage of slide width"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The y-coordinate (vertical position) as a percentage of slide height"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["slide_index", "image_path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.configure_display_settings",
|
||||
"description": "Configures the display settings for LibreOffice Impress presentations, including monitor usage and presenter view options.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"use_presenter_view": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to use presenter view (showing current and next slide on one screen). Set to false to disable presenter view."
|
||||
},
|
||||
"primary_monitor_only": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to use only the primary monitor for the presentation. Set to true to use only one screen."
|
||||
},
|
||||
"monitor_for_presentation": {
|
||||
"type": "integer",
|
||||
"description": "Specify which monitor to use for the presentation (1 for primary monitor, 2 for secondary monitor, etc.)"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.set_slide_number_color",
|
||||
"description": "Sets the color of the slide number in the presentation.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "The color to apply to slide numbers (e.g., 'red', 'green', 'blue', 'black', or hex color code)"
|
||||
}
|
||||
},
|
||||
"required": ["color"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.set_text_strikethrough",
|
||||
"description": "Applies or removes strike-through formatting to specific text content in a slide.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide containing the text (1-based indexing)"
|
||||
},
|
||||
"box_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the textbox containing the text (0-based indexing)"
|
||||
},
|
||||
"line_numbers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
},
|
||||
"description": "The line numbers to apply strike-through formatting to (1-based indexing)"
|
||||
},
|
||||
"apply": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to apply (true) or remove (false) strike-through formatting"
|
||||
}
|
||||
},
|
||||
"required": ["slide_index", "box_index", "line_numbers", "apply"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.set_textbox_alignment",
|
||||
"description": "Sets the text alignment for the specified textbox on a slide.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the slide to modify (1-based indexing)"
|
||||
},
|
||||
"box_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the textbox to modify (0-based indexing)"
|
||||
},
|
||||
"alignment": {
|
||||
"type": "string",
|
||||
"description": "The text alignment to apply to the title",
|
||||
"enum": ["left", "center", "right", "justify"]
|
||||
}
|
||||
},
|
||||
"required": ["slide_index", "box_index", "alignment"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ImpressTools.export_to_image",
|
||||
"description": "Exports the current presentation or a specific slide to an image file format (PNG, JPEG, etc.).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The full path where the image file should be saved, including the filename and extension"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"description": "The image format to export to (e.g., 'png', 'jpeg', 'gif')",
|
||||
"enum": ["png", "jpeg", "jpg", "gif", "bmp", "tiff"]
|
||||
},
|
||||
"slide_index": {
|
||||
"type": "integer",
|
||||
"description": "The index of the specific slide to export (1-based indexing). If not provided, exports the entire presentation as a series of images."
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "format"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -1,412 +0,0 @@
|
|||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.save",
|
||||
"description": "Save the current document to its current location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.write_text",
|
||||
"description": "Write text at the current cursor position in the document.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to write at the current cursor position."
|
||||
},
|
||||
"bold": {
|
||||
"type": "boolean",
|
||||
"description": "Optional. Whether to write the text in bold."
|
||||
},
|
||||
"italic": {
|
||||
"type": "boolean",
|
||||
"description": "Optional. Whether to write the text in italic."
|
||||
},
|
||||
"size": {
|
||||
"type": "number",
|
||||
"description": "Optional. The size of the text."
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.set_color",
|
||||
"description": "Changes the color of matched text in the document for specified paragraphs.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The pattern to match in the document, should be a regular expression"
|
||||
},
|
||||
"color": {
|
||||
"type": "number",
|
||||
"description": "The color to apply, should be a hex color code, like 0x000000 for black"
|
||||
},
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": ["pattern", "color"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.find_and_replace",
|
||||
"description": "Finds all occurrences of a specified text pattern and replaces them with another text in the document.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The pattern to match in the document, should be a regular expression"
|
||||
},
|
||||
"replacement": {
|
||||
"type": "string",
|
||||
"description": "The text to replace the found text with."
|
||||
},
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": ["pattern", "replacement"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.set_font",
|
||||
"description": "Changes the font of text in the document or specified paragraphs.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"font_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the font to apply (e.g., 'Times New Roman', 'Arial', 'Calibri')"
|
||||
},
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": ["font_name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.set_line_spacing",
|
||||
"description": "Sets the line spacing for specified paragraphs in the document.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"spacing_value": {
|
||||
"type": "number",
|
||||
"description": "The line spacing value to apply (1.0 for single spacing, 2.0 for double spacing, etc.)."
|
||||
},
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": ["spacing_value"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.remove_highlighting",
|
||||
"description": "Removes highlighting from text in the document for specified paragraphs.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.find_highlighted_text",
|
||||
"description": "Finds all text in the document that has a specific highlight color applied to it.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"highlight_color": {
|
||||
"type": "string",
|
||||
"description": "The highlight color to search for. Can be a color name (e.g., 'yellow', 'green') or hex code."
|
||||
}
|
||||
},
|
||||
"required": ["highlight_color"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.insert_formula_at_cursor",
|
||||
"description": "Inserts a formula at the current cursor position in the document.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"formula": {
|
||||
"type": "string",
|
||||
"description": "The formula to insert at the current cursor position."
|
||||
}
|
||||
},
|
||||
"required": ["formula"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.insert_image_at_cursor",
|
||||
"description": "Inserts an image at the current cursor position in the document.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image_path": {
|
||||
"type": "string",
|
||||
"description": "Full path to the image file to insert"
|
||||
},
|
||||
"width": {
|
||||
"type": "integer",
|
||||
"description": "Optional. Width to display the image in pixels. If not specified, uses the original image width."
|
||||
},
|
||||
"height": {
|
||||
"type": "integer",
|
||||
"description": "Optional. Height to display the image in pixels. If not specified, uses the original image height."
|
||||
}
|
||||
},
|
||||
"required": ["image_path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.set_strikethrough",
|
||||
"description": "Sets the strikethrough formatting for specified text in the document.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The pattern to match in the document, should be a regular expression"
|
||||
},
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.set_font_size",
|
||||
"description": "Changes the font size of specified text in the document.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"font_size": {
|
||||
"type": "number",
|
||||
"description": "The font size to apply (in points)."
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The pattern to match in the document, should be a regular expression"
|
||||
},
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": ["font_size", "pattern"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.export_to_pdf",
|
||||
"description": "Exports the current document to PDF format.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "Optional. The full path where the PDF should be saved. If not provided, uses the same location as the original document with .pdf extension."
|
||||
},
|
||||
"output_filename": {
|
||||
"type": "string",
|
||||
"description": "Optional. The filename to use for the PDF. If not provided, uses the original document's filename with .pdf extension."
|
||||
},
|
||||
"include_comments": {
|
||||
"type": "boolean",
|
||||
"description": "Optional. Whether to include comments in the exported PDF. Defaults to false."
|
||||
},
|
||||
"quality": {
|
||||
"type": "string",
|
||||
"description": "Optional. The quality of the PDF export ('standard', 'high', 'print'). Defaults to 'standard'."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.set_paragraph_alignment",
|
||||
"description": "Sets the text alignment for specified paragraphs in the document.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"alignment": {
|
||||
"type": "string",
|
||||
"description": "The alignment to apply ('left', 'center', 'right', 'justify')."
|
||||
},
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": ["alignment"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.capitalize_words",
|
||||
"description": "Capitalizes the first letter of each word for specified paragraphs in the document.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.set_default_font",
|
||||
"description": "Sets the default font for new text in the document without changing existing text.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"font_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the font to set as default (e.g., 'Times New Roman', 'Arial', 'Calibri')"
|
||||
},
|
||||
"font_size": {
|
||||
"type": "number",
|
||||
"description": "Optional. The default font size in points."
|
||||
}
|
||||
},
|
||||
"required": ["font_name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.add_page_numbers",
|
||||
"description": "Adds page numbers to the document at the specified position.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"position": {
|
||||
"type": "string",
|
||||
"description": "Position of the page numbers ('bottom_left', 'bottom_center', 'bottom_right', 'top_left', 'top_center', 'top_right')"
|
||||
},
|
||||
"start_number": {
|
||||
"type": "integer",
|
||||
"description": "Optional. The starting page number. Defaults to 1."
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"description": "Optional. Format of the page numbers (e.g., '1', 'Page 1', '1 of N'). Defaults to simple number format."
|
||||
}
|
||||
},
|
||||
"required": ["position"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.insert_page_break",
|
||||
"description": "Inserts a page break at the current cursor position, creating a new blank page after the current one.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"position": {
|
||||
"type": "string",
|
||||
"description": "Optional. Specifies where to insert the page break: 'at_cursor' for current cursor position, 'end_of_document' for end of document. Defaults to 'at_cursor'."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WriterTools.change_text_case",
|
||||
"description": "Changes the case of text in the document or a specified selection.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"case_type": {
|
||||
"type": "string",
|
||||
"description": "The type of case conversion to apply ('lowercase', 'uppercase')."
|
||||
},
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The pattern to match in the document, should be a regular expression"
|
||||
},
|
||||
"paragraph_indices": {
|
||||
"type": "array",
|
||||
"description": "Optional. Indices of paragraphs to modify (0-based indexing). If not provided, applies to all paragraphs."
|
||||
}
|
||||
},
|
||||
"required": ["case_type", "pattern"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -1,171 +0,0 @@
|
|||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.get_playlist",
|
||||
"description": "Gets the current VLC playlist with track information including name, URI and duration.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.play",
|
||||
"description": "Starts playing the current media in VLC player.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.pause",
|
||||
"description": "Pauses the currently playing media in VLC player.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.next",
|
||||
"description": "Switches to the next media item in the VLC playlist.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.previous",
|
||||
"description": "Switches to the previous media item in the VLC playlist.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.add_to_playlist",
|
||||
"description": "Adds a media file to the VLC playlist.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"uri": {
|
||||
"type": "string",
|
||||
"description": "The URI of the media file to add to the playlist, start with 'file://' or 'https://'"
|
||||
}
|
||||
},
|
||||
"required": ["uri"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.get_current_time",
|
||||
"description": "Gets the current playback time position of the playing media in seconds.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.get_media_duration",
|
||||
"description": "Gets the total duration of the currently playing media file in seconds.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.toggle_fullscreen",
|
||||
"description": "Toggles fullscreen mode for the currently playing video in the media player. If the video is not in fullscreen mode, it will be expanded to fill the entire screen. If it's already in fullscreen mode, it will return to windowed mode.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enable": {
|
||||
"type": "boolean",
|
||||
"description": "Optional parameter to explicitly set fullscreen mode. If true, forces fullscreen mode. If false, exits fullscreen mode. If not provided, the current state is toggled."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.get_settings",
|
||||
"description": "Gets the current settings of the VLC player.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.set_settings",
|
||||
"description": "Sets the settings for the VLC player.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"description": "The name of the setting to set. i.e. input-record-path: the path to the recording folder, qt-bgcone: disable/enable splash cone icon (in 0/1), qt-max-volume: set max volume (in number), qt-minimal-view: hide/show bottom toolbar (in 0/1), global-key-play-pause: disable/enable play&pause key (in 0/1)"
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "The value to set for the specified setting, set 0/1 for boolean values"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"field",
|
||||
"value"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "VLCTools.get_media_files",
|
||||
"description": "Gets the media files for the specified path.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The path to the media files"
|
||||
},
|
||||
"suffix": {
|
||||
"type": "array",
|
||||
"description": "The suffix of the media files, default is ['mp4', 'avi', 'mkv', 'mov', 'mp3', 'm4a', 'wav']"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
@ -1,260 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class CodeTools:
    """Wrappers around the Visual Studio Code ``code`` command-line interface.

    Every method shells out to a ``code`` subcommand via ``subprocess.run``,
    stores a human-readable result message in the class attribute ``ret``,
    and returns that message. Failures from the CLI are caught and reported
    in ``ret`` rather than raised, so callers can always inspect the outcome.
    """

    # Result message of the most recently executed operation.
    ret = ""

    @classmethod
    def print_result(cls):
        """Print the result message of the last executed operation."""
        print(cls.ret)

    @classmethod
    def launch_vscode(cls, path):
        """Open a file or directory in the last active VS Code window.

        Args:
            path (str): Path to the file or directory to open.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            # "-r" reuses an existing window instead of spawning a new one.
            subprocess.run(["code", "-r", path], check=True)
            cls.ret = "Successfully launched VS Code"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error launching VS Code: {e}"
        except Exception as e:
            # e.g. FileNotFoundError when the "code" binary is not installed.
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def env_info(cls):
        """Report environment info (none is available for the VS Code tools).

        Returns:
            str: The placeholder string "None". Returning ``cls.ret`` keeps
            this method consistent with every other method of the class.
        """
        cls.ret = "None"
        return cls.ret

    @classmethod
    def compare_files(cls, file1, file2):
        """Open a side-by-side diff of two files in VS Code.

        Args:
            file1 (str): Path of the first file.
            file2 (str): Path of the second file.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            subprocess.run(["code", "-d", file1, file2], check=True)
            cls.ret = "The compared files are opened in VSCode"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error comparing files: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def add_folder(cls, folder):
        """Add a folder to the last active VS Code window.

        Args:
            folder (str): Path of the folder to add.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            subprocess.run(["code", "-a", folder], check=True)
            cls.ret = "Successfully added folder"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error adding folder: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def goto_file(cls, file_path, line=1, character=1):
        """Open a file at a specific line and character position.

        Args:
            file_path (str): Path of the file to open.
            line (int): 1-based line number to place the cursor on.
            character (int): 1-based character position on that line.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            # VS Code accepts "path:line:character" with the -g flag.
            command = f"{file_path}:{line}:{character}"
            subprocess.run(["code", "-g", command], check=True)
            cls.ret = "Successfully opened file, line: {}, character: {}".format(line, character)
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error going to file: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def perform_merge(cls, path1, path2, base, result):
        """Perform a three-way merge in VS Code.

        Args:
            path1 (str): Path of the first version of the file.
            path2 (str): Path of the second version of the file.
            base (str): Path of the common base version.
            result (str): Path where the merge result should be saved.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            subprocess.run(["code", "-m", path1, path2, base, result], check=True)
            cls.ret = "Successfully performed merge"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error performing merge: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def remove_folder(cls, folder):
        """Remove a folder from the last active VS Code window.

        Args:
            folder (str): Path of the folder to remove.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            subprocess.run(["code", "--remove", folder], check=True)
            cls.ret = "Successfully removed folder"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error removing folder: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def install_extension(cls, extension_id, pre_release=False):
        """Install or update an extension in VS Code.

        Args:
            extension_id (str): Marketplace identifier of the extension.
            pre_release (bool): If True, install the pre-release version.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            command = ["code", "--install-extension", extension_id]
            if pre_release:
                command.append("--pre-release")
            subprocess.run(command, check=True)
            cls.ret = "Successfully installed extension"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error installing extension: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def uninstall_extension(cls, extension_id):
        """Uninstall an extension from VS Code.

        Args:
            extension_id (str): Marketplace identifier of the extension.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            subprocess.run(["code", "--uninstall-extension", extension_id], check=True)
            cls.ret = "Successfully uninstalled extension"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error uninstalling extension: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def list_extensions(cls, show_versions=False, category=None):
        """List the extensions installed in VS Code.

        Args:
            show_versions (bool): If True, include each extension's version.
            category (str): Optional category name to filter by.

        Returns:
            str: The CLI's stdout on success, or an error message.
        """
        try:
            command = ["code", "--list-extensions"]
            if show_versions:
                command.append("--show-versions")
            if category:
                command.extend(["--category", category])
            cls.ret = subprocess.run(command, check=True, capture_output=True, text=True).stdout
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error listing extensions: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def update_extensions(cls):
        """Update all installed VS Code extensions to their latest versions.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            subprocess.run(["code", "--update-extensions"], check=True)
            cls.ret = "Successfully updated extensions"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error updating extensions: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def disable_extension(cls, extension_id):
        """Disable an extension for the next VS Code instance.

        Args:
            extension_id (str): Marketplace identifier of the extension.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            subprocess.run(["code", "--disable-extension", extension_id], check=True)
            cls.ret = "Successfully disabled extension"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error disabling extension: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret

    @classmethod
    def toggle_sync(cls, state):
        """Turn VS Code settings synchronization on or off.

        Args:
            state (str): Either 'on' or 'off'.

        Returns:
            str: Success or error message (also stored in ``cls.ret``).
        """
        try:
            command = ["code", "--sync", state]
            subprocess.run(command, check=True)
            cls.ret = "Successfully toggled sync"
        except subprocess.CalledProcessError as e:
            cls.ret = f"Error toggling sync: {e}"
        except Exception as e:
            cls.ret = f"Unexpected error: {e}"

        return cls.ret
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
class BrowserTools:
    """Helpers that drive Chrome browser actions.

    Methods either return an action dictionary understood by the controller
    (``OPEN_CHROME_TAB`` with a list of URLs to open) or a Python snippet
    string (pyautogui hotkeys) that the executor runs inside the VM.
    """

    # Result message of the most recently executed operation.
    ret = ""

    @classmethod
    def print_result(cls):
        """Print the result message of the last executed operation."""
        print(cls.ret)

    @classmethod
    def env_info(cls):
        """Report environment info (none is available for the browser tools)."""
        cls.ret = "None"

    @classmethod
    def open_profile_settings(cls):
        """Open the profile settings page in the browser."""
        return {"action_type": "OPEN_CHROME_TAB", "parameters": {"urls_to_open": ["chrome://settings/people"]}}

    @classmethod
    def open_password_settings(cls):
        """Open the password/autofill settings page in the browser."""
        return {"action_type": "OPEN_CHROME_TAB", "parameters": {"urls_to_open": ["chrome://settings/autofill"]}}

    @classmethod
    def open_privacy_settings(cls):
        """Open the privacy settings page in the browser."""
        return {"action_type": "OPEN_CHROME_TAB", "parameters": {"urls_to_open": ["chrome://settings/privacy"]}}

    @classmethod
    def open_appearance_settings(cls):
        """Open the appearance settings page in the browser."""
        return {"action_type": "OPEN_CHROME_TAB", "parameters": {"urls_to_open": ["chrome://settings/appearance"]}}

    @classmethod
    def open_search_engine_settings(cls):
        """Open the search engine settings page in the browser."""
        return {"action_type": "OPEN_CHROME_TAB", "parameters": {"urls_to_open": ["chrome://settings/search"]}}

    @classmethod
    def bring_back_last_tab(cls):
        """Reopen the most recently closed tab (Ctrl+Shift+T)."""
        return "import pyautogui; pyautogui.hotkey('ctrl', 'shift', 't'); print('Brought back last tab')"

    # NOTE: this method intentionally shadows the builtin name ``print``;
    # it cannot be renamed without breaking callers that dispatch by name.
    @classmethod
    def print(cls):
        """Open the print dialog for the current page (Ctrl+P)."""
        return "import pyautogui; pyautogui.hotkey('ctrl', 'p'); print('Opened print option')"

    @classmethod
    def delete_browsing_data(cls):
        """Open the clear-browsing-data dialog (Ctrl+Shift+Del)."""
        return "import pyautogui; pyautogui.hotkey('ctrl', 'shift', 'del'); print('Deleted browsing data')"

    @classmethod
    def open_extensions(cls):
        """Open the extensions page in the browser."""
        return {"action_type": "OPEN_CHROME_TAB", "parameters": {"urls_to_open": ["chrome://extensions"]}}

    @classmethod
    def bookmark_page(cls):
        """Bookmark the current page (Ctrl+D)."""
        return "import pyautogui; pyautogui.hotkey('ctrl', 'd'); print('Bookmarked page')"

    @classmethod
    def open_bookmarks(cls):
        """Open the bookmarks page in the browser."""
        return {"action_type": "OPEN_CHROME_TAB", "parameters": {"urls_to_open": ["chrome://bookmarks"]}}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,753 +0,0 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
import uno
|
||||
from com.sun.star.awt.FontSlant import ITALIC, NONE, OBLIQUE
|
||||
from com.sun.star.awt.FontWeight import BOLD, NORMAL
|
||||
from com.sun.star.beans import PropertyValue
|
||||
from com.sun.star.style.ParagraphAdjust import CENTER, LEFT, RIGHT
|
||||
from com.sun.star.text.ControlCharacter import PARAGRAPH_BREAK
|
||||
from com.sun.star.text.TextContentAnchorType import AS_CHARACTER
|
||||
|
||||
|
||||
class WriterTools:
|
||||
localContext = uno.getComponentContext()
|
||||
resolver = localContext.ServiceManager.createInstanceWithContext("com.sun.star.bridge.UnoUrlResolver", localContext)
|
||||
ctx = resolver.resolve("uno:socket,host=localhost,port=2002;urp;StarOffice.ComponentContext")
|
||||
desktop = ctx.ServiceManager.createInstanceWithContext("com.sun.star.frame.Desktop", ctx)
|
||||
doc = desktop.getCurrentComponent()
|
||||
text = doc.Text
|
||||
cursor = text.createTextCursor()
|
||||
ret = ""
|
||||
|
||||
@classmethod
|
||||
def close_other_window(cls):
|
||||
"""关闭除当前文档外的所有文档"""
|
||||
components = cls.desktop.getComponents().createEnumeration()
|
||||
current_url = cls.doc.getURL()
|
||||
while components.hasMoreElements():
|
||||
doc = components.nextElement()
|
||||
if doc.getURL() != current_url:
|
||||
doc.close(True)
|
||||
|
||||
@classmethod
|
||||
def save(cls):
|
||||
"""保存文档到当前位置"""
|
||||
try:
|
||||
if cls.doc.hasLocation():
|
||||
cls.doc.store()
|
||||
else:
|
||||
raise Exception("文档没有保存位置,请使用另存为功能")
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def maximize_window(cls):
|
||||
"""
|
||||
将窗口设置为工作区最大尺寸
|
||||
使用工作区域大小(考虑任务栏等)
|
||||
"""
|
||||
window = cls.doc.getCurrentController().getFrame().getContainerWindow()
|
||||
toolkit = window.getToolkit()
|
||||
device = toolkit.createScreenCompatibleDevice(0, 0)
|
||||
workarea = toolkit.getWorkArea()
|
||||
window.setPosSize(workarea.X, workarea.Y, workarea.Width, workarea.Height, 15)
|
||||
|
||||
@classmethod
|
||||
def print_result(cls):
|
||||
print(cls.ret)
|
||||
|
||||
@classmethod
|
||||
def write_text(cls, text, bold=False, italic=False, size=None):
|
||||
"""写入文本"""
|
||||
cls.cursor.CharWeight = 150 if bold else 100
|
||||
cls.cursor.CharPosture = ITALIC if italic else NONE
|
||||
if size:
|
||||
cls.cursor.CharHeight = size
|
||||
cls.text.insertString(cls.cursor, text, False)
|
||||
cls.ret = "Success"
|
||||
|
||||
@classmethod
|
||||
def get_paragraphs(cls, start_index=0, count=None):
|
||||
"""Retrieves paragraphs from the document as a list."""
|
||||
text = cls.doc.getText()
|
||||
paragraphs = text.createEnumeration()
|
||||
paragraph_list = []
|
||||
while paragraphs.hasMoreElements():
|
||||
paragraph = paragraphs.nextElement()
|
||||
if paragraph.supportsService("com.sun.star.text.Paragraph"):
|
||||
paragraph_list.append(paragraph.getString())
|
||||
if start_index < 0:
|
||||
start_index = 0
|
||||
elif start_index >= len(paragraph_list):
|
||||
cls.ret = []
|
||||
if count is not None:
|
||||
end_index = min(start_index + count, len(paragraph_list))
|
||||
cls.ret = paragraph_list[start_index:end_index]
|
||||
else:
|
||||
cls.ret = paragraph_list[start_index:]
|
||||
return cls.ret
|
||||
|
||||
@classmethod
|
||||
def env_info(cls):
|
||||
paras = cls.get_paragraphs()
|
||||
para_str = ""
|
||||
for i, para in enumerate(paras):
|
||||
para = para[:500] + "..." if len(para) > 500 else para
|
||||
para_str += "Paragraph " + str(i) + ": " + para.strip() + "\n"
|
||||
cls.ret = para_str
|
||||
return cls.ret
|
||||
|
||||
@classmethod
|
||||
def set_color(cls, pattern, color, paragraph_indices=None):
|
||||
"""
|
||||
Changes the color of matched text in the document for specified paragraphs.
|
||||
|
||||
Args:
|
||||
pattern (str): Regular expression pattern to match text
|
||||
color (int): Hex color code (e.g., 0x000000 for black)
|
||||
paragraph_indices (list, optional): List of paragraph indices to modify (0-based).
|
||||
If None, applies to all paragraphs.
|
||||
"""
|
||||
try:
|
||||
enum = cls.doc.Text.createEnumeration()
|
||||
paragraphs = []
|
||||
while enum.hasMoreElements():
|
||||
paragraphs.append(enum.nextElement())
|
||||
if not paragraph_indices:
|
||||
paragraphs_to_process = range(len(paragraphs))
|
||||
else:
|
||||
paragraphs_to_process = paragraph_indices
|
||||
regex = re.compile(pattern)
|
||||
for idx in paragraphs_to_process:
|
||||
if idx < 0 or idx >= len(paragraphs):
|
||||
continue
|
||||
paragraph = paragraphs[idx]
|
||||
if not paragraph.supportsService("com.sun.star.text.Paragraph"):
|
||||
continue
|
||||
para_text = paragraph.getString()
|
||||
matches = regex.finditer(para_text)
|
||||
for match in matches:
|
||||
para_cursor = cls.text.createTextCursorByRange(paragraph.getStart())
|
||||
para_cursor.goRight(match.start(), False)
|
||||
para_cursor.goRight(match.end() - match.start(), True)
|
||||
para_cursor.CharColor = color
|
||||
cls.ret = "Success"
|
||||
return True
|
||||
except Exception as e:
|
||||
cls.ret = f"Error: {str(e)}"
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def find_and_replace(cls, pattern, replacement, paragraph_indices=None):
|
||||
"""
|
||||
Finds all occurrences of a specified text pattern and replaces them with another text in the document.
|
||||
|
||||
Args:
|
||||
pattern (str): The pattern to match in the document, should be a regular expression
|
||||
replacement (str): The text to replace the found text with
|
||||
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing)
|
||||
|
||||
Returns:
|
||||
str: Success message with number of replacements made
|
||||
"""
|
||||
try:
|
||||
enum = cls.doc.Text.createEnumeration()
|
||||
paragraphs = []
|
||||
while enum.hasMoreElements():
|
||||
paragraphs.append(enum.nextElement())
|
||||
total_replacements = 0
|
||||
if not paragraph_indices:
|
||||
paragraphs_to_process = list(range(len(paragraphs)))
|
||||
else:
|
||||
paragraphs_to_process = [i for i in paragraph_indices if 0 <= i < len(paragraphs)]
|
||||
regex = re.compile(pattern)
|
||||
for idx in paragraphs_to_process:
|
||||
if idx >= len(paragraphs):
|
||||
continue
|
||||
paragraph = paragraphs[idx]
|
||||
if paragraph.supportsService("com.sun.star.text.Paragraph"):
|
||||
text_content = paragraph.getString()
|
||||
new_text, count = regex.subn(replacement, text_content)
|
||||
if count > 0:
|
||||
paragraph.setString(new_text)
|
||||
total_replacements += count
|
||||
cls.ret = f"Successfully made {total_replacements} replacements"
|
||||
return cls.ret
|
||||
except Exception as e:
|
||||
cls.ret = f"Error during find and replace: {str(e)}"
|
||||
return cls.ret
|
||||
|
||||
@classmethod
|
||||
def set_font(cls, font_name, paragraph_indices=None):
|
||||
"""
|
||||
Changes the font of text in the document or specified paragraphs.
|
||||
|
||||
Args:
|
||||
font_name (str): The name of the font to apply (e.g., 'Times New Roman', 'Arial', 'Calibri')
|
||||
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||
If not provided, applies to all paragraphs.
|
||||
"""
|
||||
try:
|
||||
text = cls.doc.getText()
|
||||
enum = text.createEnumeration()
|
||||
paragraphs = []
|
||||
while enum.hasMoreElements():
|
||||
paragraphs.append(enum.nextElement())
|
||||
if not paragraph_indices:
|
||||
paragraph_indices = range(len(paragraphs))
|
||||
for idx in paragraph_indices:
|
||||
if 0 <= idx < len(paragraphs):
|
||||
paragraph = paragraphs[idx]
|
||||
cursor = text.createTextCursorByRange(paragraph)
|
||||
cursor.CharFontName = font_name
|
||||
cls.ret = "Success"
|
||||
return True
|
||||
except Exception as e:
|
||||
cls.ret = f"Error: {str(e)}"
|
||||
return False
|
||||
|
||||
    @classmethod
    def set_line_spacing(cls, spacing_value, paragraph_indices=None):
        """
        Sets the line spacing for specified paragraphs in the document.

        Args:
            spacing_value (float): The line spacing value to apply (1.0 for single spacing, 2.0 for double spacing, etc.)
            paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
                If not provided, applies to all paragraphs.

        Returns:
            bool: True on success, False on error (details in cls.ret).
        """
        try:
            text = cls.doc.getText()
            paragraph_enum = text.createEnumeration()
            # Proportional spacing is expressed in percent: 1.5 -> 150.
            line_spacing_value = int(spacing_value * 100)
            current_index = 0

            while paragraph_enum.hasMoreElements():
                paragraph = paragraph_enum.nextElement()

                if not paragraph_indices or current_index in paragraph_indices:
                    line_spacing = uno.createUnoStruct("com.sun.star.style.LineSpacing")
                    # Mode 0 == PROP (proportional); Height is the percentage.
                    line_spacing.Mode = 0
                    line_spacing.Height = line_spacing_value
                    paragraph.ParaLineSpacing = line_spacing

                # NOTE(review): only non-blank paragraphs advance the index, so
                # paragraph_indices here counts non-empty paragraphs — unlike
                # get_paragraphs()/set_font(), which count every paragraph.
                # Confirm this asymmetry is intentional before relying on it.
                if paragraph.String.strip():
                    current_index += 1

            cls.ret = "Success"
            return True
        except Exception as e:
            cls.ret = f"Error: {str(e)}"
            return False
|
||||
|
||||
    @classmethod
    def remove_highlighting(cls, paragraph_indices=None):
        """
        Removes ALL highlighting from text in the document for specified paragraphs.

        Args:
            paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
                If not provided, applies to all paragraphs.

        Returns:
            str: Success message or error message
        """
        try:
            text = cls.doc.getText()
            paragraphs = text.createEnumeration()
            # A set gives O(1) membership tests; None means "all paragraphs".
            target_indices = set(paragraph_indices) if paragraph_indices else None
            current_index = 0

            while paragraphs.hasMoreElements():
                paragraph = paragraphs.nextElement()
                if target_indices is None or current_index in target_indices:
                    if paragraph.supportsService("com.sun.star.text.Paragraph"):
                        para_cursor = text.createTextCursorByRange(paragraph)
                        # Remove all highlighting by setting back color to -1
                        # (-1 == "automatic"/no fill for CharBackColor).
                        para_cursor.CharBackColor = -1

                        # Additional cleanup for individual text portions (optional)
                        text_portions = paragraph.createEnumeration()
                        while text_portions.hasMoreElements():
                            text_portion = text_portions.nextElement()
                            if hasattr(text_portion, "CharBackColor"):
                                portion_cursor = text.createTextCursorByRange(text_portion)
                                portion_cursor.CharBackColor = -1
                # Every enumerated element counts toward the index, matching
                # the indexing used by set_font()/set_color().
                current_index += 1

            cls.ret = "Successfully removed all highlighting"
            return cls.ret
        except Exception as e:
            cls.ret = f"Error removing highlighting: {str(e)}"
            return cls.ret
|
||||
|
||||
@classmethod
|
||||
def find_highlighted_text(cls, highlight_color):
|
||||
"""
|
||||
Finds all text in the document that has a specific highlight color applied to it.
|
||||
|
||||
Args:
|
||||
highlight_color (str): The highlight color to search for. Can be a color name (e.g., 'yellow', 'green') or hex code.
|
||||
|
||||
Returns:
|
||||
list: A list of strings containing all text segments with the specified highlight color.
|
||||
"""
|
||||
color_map = {
|
||||
"yellow": 16776960,
|
||||
"green": 65280,
|
||||
"blue": 255,
|
||||
"red": 16711680,
|
||||
"cyan": 65535,
|
||||
"magenta": 16711935,
|
||||
"black": 0,
|
||||
"white": 16777215,
|
||||
"gray": 8421504,
|
||||
"lightgray": 12632256,
|
||||
}
|
||||
target_color = None
|
||||
if highlight_color.lower() in color_map:
|
||||
target_color = color_map[highlight_color.lower()]
|
||||
elif highlight_color.startswith("#") and len(highlight_color) == 7:
|
||||
try:
|
||||
hex_color = highlight_color[1:]
|
||||
r = int(hex_color[0:2], 16)
|
||||
g = int(hex_color[2:4], 16)
|
||||
b = int(hex_color[4:6], 16)
|
||||
target_color = (r << 16) + (g << 8) + b
|
||||
except ValueError:
|
||||
cls.ret = f"Invalid hex color format: {highlight_color}"
|
||||
return []
|
||||
else:
|
||||
cls.ret = f"Unsupported color format: {highlight_color}"
|
||||
return []
|
||||
highlighted_text = []
|
||||
text = cls.doc.getText()
|
||||
enum_paragraphs = text.createEnumeration()
|
||||
while enum_paragraphs.hasMoreElements():
|
||||
paragraph = enum_paragraphs.nextElement()
|
||||
if paragraph.supportsService("com.sun.star.text.Paragraph"):
|
||||
enum_portions = paragraph.createEnumeration()
|
||||
while enum_portions.hasMoreElements():
|
||||
text_portion = enum_portions.nextElement()
|
||||
if hasattr(text_portion, "CharBackColor") and text_portion.CharBackColor == target_color:
|
||||
if text_portion.getString().strip():
|
||||
highlighted_text.append(text_portion.getString())
|
||||
cls.ret = f"Found {len(highlighted_text)} text segments with highlight color {highlight_color}"
|
||||
return highlighted_text
|
||||
|
||||
@classmethod
|
||||
def insert_formula_at_cursor(cls, formula):
|
||||
"""
|
||||
Inserts a formula at the current cursor position in the document.
|
||||
|
||||
Args:
|
||||
formula (str): The formula to insert at the current cursor position.
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
embedded_obj = cls.doc.createInstance("com.sun.star.text.TextEmbeddedObject")
|
||||
embedded_obj.setPropertyValue("CLSID", "078B7ABA-54FC-457F-8551-6147e776a997")
|
||||
embedded_obj.setPropertyValue("AnchorType", AS_CHARACTER)
|
||||
cls.text.insertTextContent(cls.cursor, embedded_obj, False)
|
||||
math_obj = embedded_obj.getEmbeddedObject()
|
||||
math_obj.Formula = formula
|
||||
cls.ret = "Formula inserted successfully"
|
||||
return True
|
||||
except Exception as e:
|
||||
cls.ret = f"Error inserting formula: {str(e)}"
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def insert_image_at_cursor(cls, image_path, width=None, height=None):
|
||||
"""
|
||||
Inserts an image at the current cursor position in the document.
|
||||
|
||||
Args:
|
||||
image_path (str): Full path to the image file to insert
|
||||
width (int, optional): Width to display the image in pixels
|
||||
height (int, optional): Height to display the image in pixels
|
||||
|
||||
Returns:
|
||||
str: Success message or error message
|
||||
"""
|
||||
try:
|
||||
if image_path.startswith("~"):
|
||||
image_path = os.path.expanduser(image_path)
|
||||
if not os.path.exists(image_path):
|
||||
cls.ret = f"Error: Image file not found at {image_path}"
|
||||
return cls.ret
|
||||
image_path = os.path.abspath(image_path)
|
||||
if os.name == "nt":
|
||||
file_url = "file:///" + image_path.replace("\\", "/")
|
||||
else:
|
||||
file_url = "file://" + image_path
|
||||
graphic = cls.doc.createInstance("com.sun.star.text.GraphicObject")
|
||||
graphic.GraphicURL = file_url
|
||||
graphic.AnchorType = AS_CHARACTER
|
||||
if width is not None:
|
||||
graphic.Width = width * 100
|
||||
if height is not None:
|
||||
graphic.Height = height * 100
|
||||
cls.text.insertTextContent(cls.cursor, graphic, False)
|
||||
cls.ret = "Success: Image inserted"
|
||||
return cls.ret
|
||||
except Exception as e:
|
||||
cls.ret = f"Error: {str(e)}"
|
||||
return cls.ret
|
||||
|
||||
@classmethod
|
||||
def set_strikethrough(cls, pattern, paragraph_indices=None):
|
||||
"""
|
||||
Sets the strikethrough formatting for text matching the specified pattern in the document.
|
||||
|
||||
Args:
|
||||
pattern (str): The regular expression pattern to match in the document
|
||||
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||
If not provided, applies to all paragraphs.
|
||||
|
||||
Returns:
|
||||
str: Success message or error information
|
||||
"""
|
||||
try:
|
||||
paragraphs = cls.doc.getText().createEnumeration()
|
||||
para_index = 0
|
||||
found_matches = 0
|
||||
while paragraphs.hasMoreElements():
|
||||
paragraph = paragraphs.nextElement()
|
||||
if paragraph.supportsService("com.sun.star.text.Paragraph"):
|
||||
if paragraph_indices and para_index not in paragraph_indices:
|
||||
para_index += 1
|
||||
continue
|
||||
para_text = paragraph.getString()
|
||||
matches = list(re.finditer(pattern, para_text))
|
||||
for match in matches:
|
||||
text_range = paragraph.getStart()
|
||||
cursor = cls.doc.getText().createTextCursorByRange(text_range)
|
||||
cursor.goRight(match.start(), False)
|
||||
cursor.goRight(match.end() - match.start(), True)
|
||||
cursor.CharStrikeout = 1
|
||||
found_matches += 1
|
||||
para_index += 1
|
||||
cls.ret = f"Successfully applied strikethrough to {found_matches} matches of pattern: {pattern}"
|
||||
return cls.ret
|
||||
except Exception as e:
|
||||
cls.ret = f"Error applying strikethrough: {str(e)}"
|
||||
return cls.ret
|
||||
|
||||
@classmethod
|
||||
def set_font_size(cls, font_size, pattern, paragraph_indices=None):
|
||||
"""
|
||||
Changes the font size of specified text in the document.
|
||||
|
||||
Args:
|
||||
font_size (float): The font size to apply (in points).
|
||||
pattern (str): The pattern to match in the document, should be a regular expression.
|
||||
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||
If not provided, applies to all paragraphs.
|
||||
|
||||
Returns:
|
||||
str: Result message indicating success or failure.
|
||||
"""
|
||||
try:
|
||||
regex = re.compile(pattern)
|
||||
paragraphs = cls.doc.getText().createEnumeration()
|
||||
current_index = 0
|
||||
while paragraphs.hasMoreElements():
|
||||
paragraph = paragraphs.nextElement()
|
||||
if paragraph_indices and current_index not in paragraph_indices:
|
||||
current_index += 1
|
||||
continue
|
||||
if paragraph.supportsService("com.sun.star.text.Paragraph"):
|
||||
para_cursor = cls.text.createTextCursorByRange(paragraph)
|
||||
para_text = paragraph.getString()
|
||||
matches = list(regex.finditer(para_text))
|
||||
for match in reversed(matches):
|
||||
start_pos = match.start()
|
||||
end_pos = match.end()
|
||||
para_cursor.gotoStart(False)
|
||||
para_cursor.goRight(start_pos, False)
|
||||
para_cursor.goRight(end_pos - start_pos, True)
|
||||
para_cursor.CharHeight = font_size
|
||||
current_index += 1
|
||||
cls.ret = f"Successfully changed font size to {font_size} for text matching '{pattern}'"
|
||||
return cls.ret
|
||||
except Exception as e:
|
||||
cls.ret = f"Error changing font size: {str(e)}"
|
||||
return cls.ret
|
||||
|
||||
@classmethod
|
||||
def export_to_pdf(cls, output_path=None, output_filename=None, include_comments=False, quality="standard"):
|
||||
"""
|
||||
Exports the current document to PDF format.
|
||||
|
||||
Args:
|
||||
output_path (str, optional): The full path where the PDF should be saved.
|
||||
If not provided, uses the same location as the original document.
|
||||
output_filename (str, optional): The filename to use for the PDF.
|
||||
If not provided, uses the original document's filename with .pdf extension.
|
||||
include_comments (bool, optional): Whether to include comments in the exported PDF.
|
||||
Defaults to False.
|
||||
quality (str, optional): The quality of the PDF export ('standard', 'high', 'print').
|
||||
Defaults to 'standard'.
|
||||
|
||||
Returns:
|
||||
str: Path to the exported PDF file or error message
|
||||
"""
|
||||
try:
|
||||
doc_url = cls.doc.getURL()
|
||||
if not doc_url and not output_path:
|
||||
return "Error: Document has not been saved and no output path provided"
|
||||
if doc_url:
|
||||
doc_path = uno.fileUrlToSystemPath(os.path.dirname(doc_url))
|
||||
doc_filename = os.path.basename(doc_url)
|
||||
doc_name = os.path.splitext(doc_filename)[0]
|
||||
else:
|
||||
doc_path = ""
|
||||
doc_name = "export"
|
||||
final_path = output_path if output_path else doc_path
|
||||
final_filename = output_filename if output_filename else f"{doc_name}.pdf"
|
||||
if not final_filename.lower().endswith(".pdf"):
|
||||
final_filename += ".pdf"
|
||||
full_output_path = os.path.join(final_path, final_filename)
|
||||
output_url = uno.systemPathToFileUrl(full_output_path)
|
||||
export_props = []
|
||||
if quality == "high":
|
||||
export_props.append(PropertyValue(Name="SelectPdfVersion", Value=1))
|
||||
elif quality == "print":
|
||||
export_props.append(PropertyValue(Name="SelectPdfVersion", Value=2))
|
||||
else:
|
||||
export_props.append(PropertyValue(Name="SelectPdfVersion", Value=0))
|
||||
export_props.append(PropertyValue(Name="ExportNotes", Value=include_comments))
|
||||
export_props.extend(
|
||||
[
|
||||
PropertyValue(Name="FilterName", Value="writer_pdf_Export"),
|
||||
PropertyValue(Name="Overwrite", Value=True),
|
||||
]
|
||||
)
|
||||
cls.doc.storeToURL(output_url, tuple(export_props))
|
||||
cls.ret = f"PDF exported to: {full_output_path}"
|
||||
return full_output_path
|
||||
except Exception as e:
|
||||
cls.ret = f"Error exporting to PDF: {str(e)}"
|
||||
return cls.ret
|
||||
|
||||
@classmethod
|
||||
def set_paragraph_alignment(cls, alignment, paragraph_indices=None):
|
||||
"""
|
||||
Sets the text alignment for specified paragraphs in the document.
|
||||
|
||||
Args:
|
||||
alignment (str): The alignment to apply ('left', 'center', 'right', 'justify').
|
||||
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||
If not provided, applies to all paragraphs.
|
||||
|
||||
Returns:
|
||||
str: Success message or error message
|
||||
"""
|
||||
try:
|
||||
alignment_map = {"left": LEFT, "center": CENTER, "right": RIGHT, "justify": 3}
|
||||
if alignment.lower() not in alignment_map:
|
||||
cls.ret = f"Error: Invalid alignment '{alignment}'. Use 'left', 'center', 'right', or 'justify'."
|
||||
return cls.ret
|
||||
alignment_value = alignment_map[alignment.lower()]
|
||||
text = cls.doc.getText()
|
||||
paragraph_enum = text.createEnumeration()
|
||||
paragraphs = []
|
||||
while paragraph_enum.hasMoreElements():
|
||||
paragraph = paragraph_enum.nextElement()
|
||||
if paragraph.supportsService("com.sun.star.text.Paragraph"):
|
||||
paragraphs.append(paragraph)
|
||||
if paragraph_indices:
|
||||
valid_indices = [i for i in paragraph_indices if 0 <= i < len(paragraphs)]
|
||||
if len(valid_indices) != len(paragraph_indices):
|
||||
cls.ret = f"Warning: Some paragraph indices were out of range (0-{len(paragraphs) - 1})"
|
||||
for idx in valid_indices:
|
||||
paragraphs[idx].ParaAdjust = alignment_value
|
||||
else:
|
||||
for paragraph in paragraphs:
|
||||
paragraph.ParaAdjust = alignment_value
|
||||
cls.ret = f"Successfully applied '{alignment}' alignment to paragraphs"
|
||||
return cls.ret
|
||||
except Exception as e:
|
||||
cls.ret = f"Error setting paragraph alignment: {str(e)}"
|
||||
return cls.ret
|
||||
|
||||
@classmethod
|
||||
def capitalize_words(cls, paragraph_indices=None):
|
||||
"""
|
||||
Capitalizes the first letter of each word for specified paragraphs in the document.
|
||||
|
||||
Args:
|
||||
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||
If not provided, applies to all paragraphs.
|
||||
|
||||
Returns:
|
||||
str: Success message or error message
|
||||
"""
|
||||
try:
|
||||
text = cls.doc.getText()
|
||||
enum = text.createEnumeration()
|
||||
paragraphs = []
|
||||
while enum.hasMoreElements():
|
||||
paragraph = enum.nextElement()
|
||||
if paragraph.supportsService("com.sun.star.text.Paragraph"):
|
||||
paragraphs.append(paragraph)
|
||||
if not paragraph_indices:
|
||||
target_paragraphs = list(range(len(paragraphs)))
|
||||
else:
|
||||
target_paragraphs = paragraph_indices
|
||||
valid_indices = [idx for idx in target_paragraphs if 0 <= idx < len(paragraphs)]
|
||||
for idx in valid_indices:
|
||||
paragraph = paragraphs[idx]
|
||||
text_content = paragraph.getString()
|
||||
if not text_content.strip():
|
||||
continue
|
||||
capitalized_text = " ".join(word.capitalize() if word else "" for word in text_content.split(" "))
|
||||
para_cursor = text.createTextCursorByRange(paragraph.getStart())
|
||||
para_cursor.gotoRange(paragraph.getEnd(), True)
|
||||
para_cursor.setString(capitalized_text)
|
||||
cls.ret = f"Successfully capitalized words in {len(valid_indices)} paragraphs"
|
||||
return cls.ret
|
||||
except Exception as e:
|
||||
cls.ret = f"Error capitalizing words: {str(e)}"
|
||||
return cls.ret
|
||||
|
||||
@classmethod
|
||||
def set_default_font(cls, font_name, font_size=None):
|
||||
"""
|
||||
Sets the default font for new text in the document without changing existing text.
|
||||
|
||||
Args:
|
||||
font_name (str): The name of the font to set as default (e.g., 'Times New Roman', 'Arial', 'Calibri')
|
||||
font_size (float, optional): The default font size in points.
|
||||
|
||||
Returns:
|
||||
str: Success message or error message
|
||||
"""
|
||||
try:
|
||||
style_families = cls.doc.getStyleFamilies()
|
||||
paragraph_styles = style_families.getByName("ParagraphStyles")
|
||||
default_style_names = ["Default", "Standard", "Normal"]
|
||||
standard_style = None
|
||||
for style_name in default_style_names:
|
||||
if paragraph_styles.hasByName(style_name):
|
||||
standard_style = paragraph_styles.getByName(style_name)
|
||||
break
|
||||
if standard_style is None:
|
||||
style_names = paragraph_styles.getElementNames()
|
||||
if style_names:
|
||||
standard_style = paragraph_styles.getByName(style_names[0])
|
||||
else:
|
||||
raise Exception("Could not find default paragraph style")
|
||||
standard_style.setPropertyValue("CharFontName", font_name)
|
||||
standard_style.setPropertyValue("CharFontNameAsian", font_name)
|
||||
standard_style.setPropertyValue("CharFontNameComplex", font_name)
|
||||
if font_size is not None:
|
||||
standard_style.setPropertyValue("CharHeight", float(font_size))
|
||||
standard_style.setPropertyValue("CharHeightAsian", float(font_size))
|
||||
standard_style.setPropertyValue("CharHeightComplex", float(font_size))
|
||||
cls.cursor.setPropertyValue("CharFontName", font_name)
|
||||
cls.cursor.setPropertyValue("CharFontNameAsian", font_name)
|
||||
cls.cursor.setPropertyValue("CharFontNameComplex", font_name)
|
||||
if font_size is not None:
|
||||
cls.cursor.setPropertyValue("CharHeight", float(font_size))
|
||||
cls.cursor.setPropertyValue("CharHeightAsian", float(font_size))
|
||||
cls.cursor.setPropertyValue("CharHeightComplex", float(font_size))
|
||||
cls.ret = f"Default font set to '{font_name}'" + (f" with size {font_size}pt" if font_size else "")
|
||||
return cls.ret
|
||||
except Exception as e:
|
||||
cls.ret = f"Error setting default font: {str(e)}"
|
||||
return cls.ret
|
||||
|
||||
    @classmethod
    def add_page_numbers(cls, position, start_number=1, format=None):
        """
        Adds page numbers to the document at the specified position.

        Args:
            position (str): Position of the page numbers ('bottom_left', 'bottom_center', 'bottom_right',
                'top_left', 'top_center', 'top_right')
            start_number (int, optional): The starting page number. Defaults to 1.
            format (str, optional): Format of the page numbers (e.g. '1', 'Page 1', '1 of N').
                Defaults to simple number format.

        Returns:
            str: Success message or error message
        """
        try:
            page_styles = cls.doc.StyleFamilies.getByName("PageStyles")
            default_style = page_styles.getByName("Standard")
            try:
                default_style.setPropertyValue("PageNumberOffset", start_number)
            except:
                # Best effort: some page styles may reject the offset.
                pass
            # Enable the header or footer and write into its text object.
            if position.startswith("top"):
                default_style.HeaderIsOn = True
                target = default_style.HeaderText
            else:
                default_style.FooterIsOn = True
                target = default_style.FooterText
            cursor = target.createTextCursor()
            # Select everything currently in the header/footer and clear it.
            cursor.gotoStart(False)
            cursor.gotoEnd(True)
            cursor.setString("")
            cursor.gotoStart(False)
            if position.endswith("_left"):
                cursor.ParaAdjust = LEFT
            elif position.endswith("_center"):
                cursor.ParaAdjust = CENTER
            elif position.endswith("_right"):
                cursor.ParaAdjust = RIGHT
            # Format dispatch.  NOTE(review): these branches rely on Python's
            # `or`/`and` precedence — e.g. the second condition parses as
            # `format == "Page 1" or ("Page" in format and "of" not in format)` —
            # which happens to route "Page N of M" style strings to the third
            # branch; verify before editing.  NumberingType 4 == ARABIC.
            if not format or format == "1":
                page_number = cls.doc.createInstance("com.sun.star.text.TextField.PageNumber")
                page_number.NumberingType = 4
                target.insertTextContent(cursor, page_number, False)
            elif format == "Page 1" or "Page" in format and "of" not in format:
                target.insertString(cursor, "Page ", False)
                page_number = cls.doc.createInstance("com.sun.star.text.TextField.PageNumber")
                page_number.NumberingType = 4
                target.insertTextContent(cursor, page_number, False)
            elif format == "1 of N" or format == "Page {page} of {total}" or "of" in format:
                if "Page" in format:
                    target.insertString(cursor, "Page ", False)
                page_number = cls.doc.createInstance("com.sun.star.text.TextField.PageNumber")
                page_number.NumberingType = 4
                target.insertTextContent(cursor, page_number, False)
                target.insertString(cursor, " of ", False)
                page_count = cls.doc.createInstance("com.sun.star.text.TextField.PageCount")
                page_count.NumberingType = 4
                target.insertTextContent(cursor, page_count, False)
            else:
                # Unrecognized format: fall back to a bare page number.
                page_number = cls.doc.createInstance("com.sun.star.text.TextField.PageNumber")
                page_number.NumberingType = 4
                target.insertTextContent(cursor, page_number, False)
            cls.ret = "Successfully added page numbers"
            return cls.ret
        except Exception as e:
            cls.ret = f"Error adding page numbers: {str(e)}"
            return cls.ret
|
||||
|
||||
    @classmethod
    def insert_page_break(cls, position="at_cursor"):
        """
        Inserts a page break at the specified position.

        Args:
            position (str): Where to insert the page break: 'at_cursor' for current cursor position,
                'end_of_document' for end of document. Defaults to 'at_cursor'.

        Returns:
            bool: True on success, False on error (details in cls.ret).
        """
        try:
            if position == "end_of_document":
                cls.cursor.gotoEnd(False)
            # Start a new paragraph, then flag it as beginning on a new page.
            cls.text.insertControlCharacter(cls.cursor, PARAGRAPH_BREAK, False)
            cls.cursor.gotoStartOfParagraph(True)
            cls.cursor.BreakType = uno.Enum("com.sun.star.style.BreakType", "PAGE_BEFORE")
            cls.ret = "Page break inserted successfully"
            return True
        except Exception as e:
            cls.ret = f"Error inserting page break: {str(e)}"
            return False
|
||||
|
|
@ -1,233 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
|
||||
class VLCTools:
    # Thin client for VLC's built-in HTTP interface (the "lua http" module),
    # which serves status/playlist XML under /requests on the media player.
    host = "localhost"
    port = 8080
    base_url = f"http://{host}:{port}/requests"
    # VLC's HTTP interface authenticates with an empty username and this password.
    password = "password"
    auth = HTTPBasicAuth("", password)
    # Result of the most recent operation, readable via print_result().
    ret = ""
|
||||
|
||||
    @classmethod
    def print_result(cls):
        """Echo the result of the most recent operation (cls.ret) to stdout."""
        print(cls.ret)
|
||||
|
||||
@classmethod
|
||||
def _make_request(cls, endpoint, params=None):
|
||||
url = f"{cls.base_url}/{endpoint}"
|
||||
try:
|
||||
response = requests.get(url, params=params, auth=cls.auth)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
except requests.exceptions.RequestException as e:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_status(cls):
|
||||
response = cls._make_request("status.xml")
|
||||
if response:
|
||||
return ET.fromstring(response.content)
|
||||
return None
|
||||
|
||||
    @classmethod
    def env_info(cls):
        # No environment summary is implemented for VLC; always reports "None".
        cls.ret = "None"
|
||||
|
||||
@classmethod
|
||||
def get_playlist(cls):
|
||||
response = cls._make_request("playlist.xml")
|
||||
if response:
|
||||
info = ET.fromstring(response.content)
|
||||
playlist_node = info.find('.//node[@name="Playlist"]')
|
||||
if playlist_node is not None:
|
||||
playlist_items = []
|
||||
for leaf in playlist_node.findall("leaf"):
|
||||
item = {"name": leaf.get("name"), "uri": leaf.get("uri"), "duration": leaf.get("duration") + "s"}
|
||||
playlist_items.append(item)
|
||||
cls.ret = f"Playlist: {playlist_items}"
|
||||
return cls.ret
|
||||
cls.ret = "Error getting playlist"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def play(cls):
|
||||
response = cls._make_request("status.xml", {"command": "pl_play"})
|
||||
if response:
|
||||
cls.ret = "Start playing the media"
|
||||
return cls.ret
|
||||
cls.ret = "Error playing the media"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def pause(cls):
|
||||
response = cls._make_request("status.xml", {"command": "pl_pause"})
|
||||
if response:
|
||||
cls.ret = "Pause the media"
|
||||
return cls.ret
|
||||
cls.ret = "Error pausing the media"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def next(cls):
|
||||
response = cls._make_request("status.xml", {"command": "pl_next"})
|
||||
if response:
|
||||
cls.ret = "Switch to next media"
|
||||
return cls.ret
|
||||
cls.ret = "Error switching to next media"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def previous(cls):
|
||||
response = cls._make_request("status.xml", {"command": "pl_previous"})
|
||||
if response:
|
||||
cls.ret = "Switch to previous media"
|
||||
return cls.ret
|
||||
cls.ret = "Error switching to previous media"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def add_to_playlist(cls, uri):
|
||||
if uri.startswith("http"):
|
||||
encoded_uri = uri
|
||||
else:
|
||||
encoded_uri = "file://" + quote(uri.replace("file://", ""))
|
||||
|
||||
response = cls._make_request("status.xml", {"command": "in_play", "input": encoded_uri})
|
||||
if response:
|
||||
cls.ret = f"Add {uri} to playlist"
|
||||
return cls.ret
|
||||
cls.ret = f"Error adding {uri} to playlist"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_current_time(cls):
|
||||
status = cls._get_status()
|
||||
if status is not None:
|
||||
time = status.find("time")
|
||||
cls.ret = int(time.text) if time is not None else None
|
||||
return cls.ret
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_media_duration(cls):
|
||||
status = cls._get_status()
|
||||
if status is not None:
|
||||
length = status.find("length")
|
||||
if length is not None:
|
||||
cls.ret = f"Media duration: {length.text} seconds"
|
||||
return cls.ret
|
||||
cls.ret = "Error getting media duration"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
settings = {}
|
||||
with open(Path.home() / ".config/vlc/vlcrc", "r") as f:
|
||||
for line in f:
|
||||
if line:
|
||||
try:
|
||||
key, value = line.split("=")
|
||||
if key.strip().startswith("#"):
|
||||
continue
|
||||
settings[key.strip()] = value.strip()
|
||||
except:
|
||||
continue
|
||||
cls.ret = json.dumps(settings, indent=4, ensure_ascii=False)
|
||||
return cls.ret
|
||||
|
||||
@classmethod
def set_settings(cls, field, value):
    """Set a key in ~/.config/vlc/vlcrc, replacing (and uncommenting) an
    existing entry or appending a new one.

    Args:
        field (str): setting name, e.g. "volume".
        value: new value; written verbatim as "field=value".

    Returns:
        str: confirmation message.

    Raises:
        OSError: if the vlcrc file cannot be read or written.
    """
    config_path = Path.home() / ".config/vlc/vlcrc"
    with open(config_path, "r") as rf:
        settings = rf.read()

    # Anchor at line start so "volume" cannot match inside "sub-volume="
    # (the previous unanchored pattern could corrupt longer keys).
    pattern = re.compile(r"^#? *" + re.escape(field) + r"=.*$", re.MULTILINE)
    if pattern.search(settings):
        settings = pattern.sub(f"{field}={value}", settings)
    else:
        # Keep the file well-formed when it does not end with a newline.
        if settings and not settings.endswith("\n"):
            settings += "\n"
        settings += f"{field}={value}\n"

    with open(config_path, "w") as wf:
        wf.write(settings)

    cls.ret = f"Set {field} to {value}"
    return cls.ret
|
||||
|
||||
@classmethod
def toggle_fullscreen(cls, enable=None):
    """
    Toggle fullscreen mode or set it explicitly based on the enable parameter.

    Args:
        enable (bool, optional): If provided, explicitly set fullscreen mode (True for fullscreen, False for windowed)

    Returns:
        str: Success or error message
    """
    if enable is not None:
        # NOTE(review): VLC's HTTP interface documents "fullscreen" as a
        # toggle only; "fullscreen off" may not be a recognized command
        # value -- confirm against the VLC http endpoint before relying on it.
        command = "fullscreen" if enable else "fullscreen off"
    else:
        command = "fullscreen"
    response = cls._make_request("status.xml", {"command": command})
    if response:
        # Report what was requested, not the actual resulting state.
        action = "enabled" if enable is True else "disabled" if enable is False else "toggled"
        cls.ret = f"Fullscreen mode {action}"
        return cls.ret
    cls.ret = "Error changing fullscreen mode"
    return None
|
||||
|
||||
@classmethod
def get_media_files(cls, path, suffix=None):
    """
    Gets the media files for the specified path.

    Args:
        path (str): The path to the media files
        suffix (List[str], optional): The suffix of the media files.
            Defaults to ['mp4', 'avi', 'mkv', 'mov', 'mp3', 'm4a', 'wav']

    Returns:
        list[str]: full paths of matching files, or None on bad input / scan error.
    """
    wanted = suffix if suffix is not None else ["mp4", "avi", "mkv", "mov", "mp3", "m4a", "wav"]

    # Guard against empty or non-existent paths before walking.
    if not path:
        cls.ret = "Path cannot be empty"
        return None
    if not os.path.exists(path):
        cls.ret = f"Path not found: {path}"
        return None

    # Case-insensitive extension matching via a tuple of ".ext" suffixes.
    extensions = tuple(f".{s.lower()}" for s in wanted)

    found = []
    try:
        for root, _, filenames in os.walk(path):
            found.extend(
                os.path.join(root, name)
                for name in filenames
                if name.lower().endswith(extensions)
            )
    except Exception as exc:
        cls.ret = f"Error while scanning directory: {str(exc)}"
        return None

    cls.ret = found
    return cls.ret
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
"""
|
||||
AutoGLM agent implementation
|
||||
"""
|
||||
|
||||
from .main import AutoGLMAgent
|
||||
|
||||
__all__ = ["AutoGLMAgent"]
|
||||
|
|
@ -1,265 +0,0 @@
|
|||
import logging
|
||||
import re
|
||||
from base64 import b64encode
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
|
||||
from .prompt.accessibility_tree_handle import linearize_accessibility_tree, trim_accessibility_tree
|
||||
from .prompt.grounding_agent import GroundingAgent as Agent
|
||||
from .tools.package.google_chrome import BrowserTools
|
||||
from .prompt.procedural_memory import Prompt
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
pure_text_settings = ["a11y_tree"]
|
||||
|
||||
def resize_image(image, w, h):
    """Decode image bytes, resize to (w, h), and re-encode as PNG bytes.

    Args:
        image (bytes): encoded image data readable by PIL.
        w (int): target width in pixels.
        h (int): target height in pixels.

    Returns:
        bytes: the resized image re-encoded as PNG.
    """
    resized = Image.open(BytesIO(image)).resize((w, h))
    buffer = BytesIO()
    # Always re-encode as PNG regardless of the input format.
    resized.save(buffer, format='PNG')
    return buffer.getvalue()
|
||||
|
||||
def parse_code_from_string(input_string):
    """Extract executable snippets (and WAIT/DONE/FAIL markers) from a model reply.

    A reply that is exactly one sentinel word is returned as-is. Otherwise every
    ``` fenced block (with or without a language tag) is captured; a fenced
    block whose last line is a sentinel is split into the code part plus the
    sentinel as separate entries.

    Args:
        input_string (str): raw model response.

    Returns:
        list[str]: code snippets / sentinel commands, in order of appearance.
    """
    sentinels = ["WAIT", "DONE", "FAIL"]  # fixme: updates this part when we have more commands
    stripped = input_string.strip()
    if stripped in sentinels:
        return [stripped]

    # Matches ```code``` and ```lang code```; DOTALL lets fences span lines.
    fence_re = r"```(?:\w+\s+)?(.*?)```"
    snippets = []
    for block in re.findall(fence_re, input_string, re.DOTALL):
        block = block.strip()
        if block in sentinels:
            snippets.append(block)
            continue
        *body, tail = block.split("\n")
        if tail in sentinels:
            # Code followed by a sentinel on its own last line.
            if body:
                snippets.append("\n".join(body))
            snippets.append(tail)
        else:
            snippets.append(block)
    return snippets
|
||||
|
||||
|
||||
class AutoGLMAgent:
    """LLM-driven desktop agent.

    Builds prompts from environment observations (running apps, optional
    accessibility tree and screenshot), queries an injected generation
    function, parses fenced code from the reply, and grounds it into
    executable action strings via the GroundingAgent / per-app tool classes.
    """

    def __init__(
        self,
        action_space="autoglm_computer_use",
        observation_type="a11y_tree",
        max_trajectory_length=3,
        a11y_tree_max_items=300,
        with_image: bool = True,
        screen_size = (1920, 1080),
        image_size=(1920, 1080),
        with_atree: bool = False,
        glm41v_format: bool = True,
        relative_coordinate: bool = True,
        client_password="password",
        gen_func=None,
        tool_in_sys_msg: bool = True,
    ):
        # Only one action space / observation type is supported today.
        self.action_space = action_space
        self.observation_type = observation_type
        assert action_space in ["autoglm_computer_use"], "Invalid action space"
        assert observation_type in ["a11y_tree"], "Invalid observation type"
        self.max_trajectory_length = max_trajectory_length
        self.a11y_tree_max_items = a11y_tree_max_items
        self.with_image = with_image
        self.screen_size = screen_size
        self.image_size = image_size
        self.with_atree = with_atree
        self.glm41v_format = glm41v_format
        self.relative_coordinate = relative_coordinate
        self.client_password = client_password
        # callable(messages) -> str; must be injected before predict() is used.
        self.gen_func = gen_func
        # Whether tool/function definitions live in the system message or the user turn.
        self.tool_in_sys_msg = tool_in_sys_msg

        # Maps a normalized app identifier to the tool class exposed to the model.
        self.tool_list = {
            "libreoffice_calc": "CalcTools",
            "libreoffice_impress": "ImpressTools",
            "libreoffice_writer": "WriterTools",
            "code": "CodeTools",
            "vlc": "VLCTools",
            "google_chrome": "BrowserTools",
        }

        # Propagate the coordinate mode to the grounding agent (class-level switch).
        Agent.relative_coordinate = relative_coordinate

        # Per-episode trajectory: one dict per turn (see predict()).
        self.contents = []

    @property
    def turn_number(self):
        # Number of turns taken so far in the current episode.
        return len(self.contents)

    def prepare(self, instruction: str, obs: Dict, history: List, last_result: str = "") -> List:
        """
        Predict the next action(s) based on the current observation.

        Builds the full chat message list: system prompt (procedural memory +
        task), prior history, and a user turn describing the current apps,
        optional a11y tree, app info, previous result, and optional screenshot.
        """
        # Back-fill the previous turn's execution result when the env supplies it.
        if "exe_result" in obs and not last_result:
            last_result = obs["exe_result"]
            if self.contents:
                self.contents[-1]["exe_result"] = last_result

        cur_app = obs["cur_app"]
        logger.info(f"current app is {cur_app}")

        # Normalize the app name into a tool key; None when no tool applies.
        if cur_app:
            tool_name = cur_app.strip().lower().replace("-", "_")
            tool_name = tool_name if tool_name in self.tool_list.keys() else None
        else:
            tool_name = None

        setup_prompt, func_def_prompt, note_prompt = Prompt.construct_procedural_memory(
            Agent, app_name=tool_name, client_password=self.client_password, with_image=self.with_image, with_atree=self.with_atree, relative_coordinate=self.relative_coordinate, glm41v_format=self.glm41v_format
        )
        if self.tool_in_sys_msg:
            system_message = setup_prompt + "\n\n" + func_def_prompt + "\n\n" + note_prompt
        else:
            system_message = setup_prompt + "\n\n" + note_prompt
        system_message += "\n\n**IMPORTANT** You are asked to complete the following task: {}".format(instruction)

        messages = [
            {
                "role": "system",
                "content": system_message,
            }
        ]
        messages.extend(history)

        # Tabulate the open windows for the prompt.
        if obs["apps"]:
            app_str = "Window ID App Name Title\n"
            for window_id, app in obs["apps"].items():
                app_str += f"{window_id} {app['app_name']} {app['title']}\n"
        else:
            app_str = "None"

        # Truncate long results/info so the prompt stays bounded.
        last_result = last_result.strip() if last_result else "None"
        last_result = last_result[:2000] + "..." if len(last_result) > 2000 else last_result

        tree = linearize_accessibility_tree(obs["accessibility_tree"], "Ubuntu")
        tree = trim_accessibility_tree(tree, 300)

        app_info = obs["app_info"].strip() if obs["app_info"] else "None"
        app_info = app_info[:5000] + "..." if len(app_info) > 5000 else app_info

        prompt = "* Apps: {}\n\n* Current App: {}{}\n\n* App Info: {}\n\n* Previous Action Result: {}".format(
            app_str.strip(),
            obs["cur_window_id"].strip() if obs["cur_window_id"] in app_str else "None",
            '\n\n* A11y Tree: {}'.format(tree.strip()) if self.with_atree else "",
            app_info,
            last_result if last_result else "None",
        ) + (
            "\n\n" + func_def_prompt if not self.tool_in_sys_msg else ""
        )

        content = [{"type": "text", "text": prompt}]
        if self.with_image and obs.get('screenshot'):
            # The screenshot is placed before the text part of the user turn.
            screenshot = resize_image(obs['screenshot'], self.image_size[0], self.image_size[1])
            content = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{b64encode(screenshot).decode('utf-8')}",
                        "detail": "high",
                    },
                }
            ] + content

        messages.append({"role": "user", "content": content})

        return messages

    def execute(self, response, obs):
        """Parse the model response into grounded, executable action strings.

        Returns an empty list when parsing or grounding fails.
        """
        try:
            actions = parse_code_from_string(response)
            action = actions[0]
            logger.info(f"The pesudo action is {action}")

            # SECURITY NOTE(review): eval() executes model-produced text
            # in-process; the action string is only as trusted as the model.
            if "Agent." in action:
                actions = [
                    eval(action),
                ]
            elif "BrowserTools." in action:  # TODO: special check for BrowserTools
                actions = [
                    eval(action),
                ]
            else:
                # App-specific tool call: wrap it via the grounding agent.
                actions = Agent.tool_commands(action, obs["cur_app"].strip().replace("-", "_").lower())
            logger.info(f"The grounded action is {actions[0]}")
        except Exception as e:
            print("Failed to parse action from response", e)
            actions = []

        return actions

    def format_history(self, max_turns=30):
        """Rebuild the chat history from the stored trajectory.

        Long entries are truncated; at most the last max_turns exchanges
        (user + assistant pairs) are kept.
        """
        history = []
        for ix in range(self.turn_number):
            if ix == 0:
                env_input = "**Environment State (Omitted)**"
            else:
                env_input = (
                    f"**Environment State (Omitted)**\nPrevious Action Result: {self.contents[ix - 1]['exe_result']}"
                )

            env_input = env_input[:2000] + "..." if len(env_input) > 2000 else env_input
            response = (
                self.contents[ix]["response"][:1500] + "..."
                if len(self.contents[ix]["response"]) > 1500
                else self.contents[ix]["response"]
            )
            history.append({"role": "user", "content": [{"type": "text", "text": env_input}]})
            history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})

        return history[-max_turns * 2:]

    def predict(self, instruction: str, obs: Dict) -> List:
        """Run one agent turn: prompt the model, ground its reply, record the turn.

        Returns:
            (response, actions): raw model text and the grounded action list
            (empty when generation or parsing failed).
        """
        history = self.format_history()
        messages = self.prepare(instruction, obs, history)

        assert self.gen_func is not None, "gen_func is not set"
        try:
            response = self.gen_func(messages)
        except Exception as e:
            logger.error("Failed to call gen_func, Error: " + str(e))
            response = ""

        logger.info("RESPONSE: %s", response)

        actions = self.execute(response, obs)

        # update the contents
        self.contents.append(
            {
                "instruction": instruction,
                "index": len(self.contents),
                "response": response,
                "action": "Parse error" if not actions else actions[0],
                "exe_result": "Invalid action" if not actions else "",
                **obs,
            }
        )
        return response, actions

    def reset(self, _logger=None):
        """Clear the trajectory and optionally swap in an episode-scoped logger."""
        global logger
        logger = _logger if _logger is not None else logging.getLogger("desktopenv.aguvis_agent")

        self.contents = []
|
||||
|
|
@ -1,329 +0,0 @@
|
|||
import io
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import List, Tuple
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from .deduplicate_node import filter_similar_nodes
|
||||
|
||||
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
|
||||
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
|
||||
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
|
||||
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
|
||||
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
|
||||
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
|
||||
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
|
||||
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
|
||||
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
|
||||
|
||||
|
||||
def find_leaf_nodes(xlm_file_str):
    """Return all leaf elements (nodes without children) of an XML string.

    Args:
        xlm_file_str (str): serialized XML document; falsy input yields [].

    Returns:
        list: leaf Element objects in document order.
    """
    if not xlm_file_str:
        return []
    root = ET.fromstring(xlm_file_str)
    # Element.iter() walks the tree in document order; a leaf has no children.
    return [element for element in root.iter() if len(element) == 0]
|
||||
|
||||
|
||||
def judge_node(node: ET, platform="Ubuntu", check_image=False) -> bool:
    """Decide whether an accessibility node is worth keeping.

    A node is kept when its tag is interactive/contentful, it is visible
    (and "showing", on Ubuntu), it carries a name or text (or an image flag
    when check_image is True), and it has a valid on-screen position and size.

    Args:
        node: accessibility-tree element to judge.
        platform: "Ubuntu" or "Windows"; selects the attribute namespaces.
        check_image: also keep nameless nodes flagged with image="true".

    Returns:
        bool: True when the node should be kept.

    Raises:
        ValueError: if platform is neither "Ubuntu" nor "Windows".
    """
    if platform == "Ubuntu":
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
    elif platform == "Windows":
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'Ubuntu' or 'Windows'")

    # Tag-based filter: only element kinds a user could see or interact with.
    keeps: bool = (
        node.tag.startswith("document")
        or node.tag.endswith("item")
        or node.tag.endswith("button")
        or node.tag.endswith("heading")
        or node.tag.endswith("label")
        or node.tag.endswith("scrollbar")
        or node.tag.endswith("searchbox")
        or node.tag.endswith("textbox")
        or node.tag.endswith("link")
        or node.tag.endswith("tabelement")
        or node.tag.endswith("textfield")
        or node.tag.endswith("textarea")
        or node.tag.endswith("menu")
        or node.tag
        in {
            "alert",
            "canvas",
            "check-box",
            "combo-box",
            "entry",
            "icon",
            "image",
            "paragraph",
            "scroll-bar",
            "section",
            "slider",
            "static",
            "table-cell",
            "terminal",
            "text",
            "netuiribbontab",
            "start",
            "trayclockwclass",
            "traydummysearchcontrol",
            "uiimage",
            "uiproperty",
            "uiribboncommandbar",
        }
    )
    # Visibility filter (Ubuntu additionally requires "showing") plus a
    # content filter: named, has text, or (optionally) flagged as an image.
    keeps = (
        keeps
        and (
            platform == "Ubuntu"
            and node.get("{{{:}}}showing".format(_state_ns), "false") == "true"
            and node.get("{{{:}}}visible".format(_state_ns), "false") == "true"
            or platform == "Windows"
            and node.get("{{{:}}}visible".format(_state_ns), "false") == "true"
        )
        and (
            node.get("name", "") != ""
            or node.text is not None
            and len(node.text) > 0
            or check_image
            and node.get("image", "false") == "true"
        )
    )

    # Parse "(x, y)" geometry strings safely instead of eval(): the
    # accessibility tree is externally produced and must not reach eval().
    coords_raw = node.get("{{{:}}}screencoord".format(_component_ns), "(-1, -1)")
    sizes_raw = node.get("{{{:}}}size".format(_component_ns), "(-1, -1)")
    try:
        coordinates: Tuple[int, int] = tuple(int(v) for v in coords_raw.strip("()").split(","))
        sizes: Tuple[int, int] = tuple(int(v) for v in sizes_raw.strip("()").split(","))
    except ValueError:
        # Malformed geometry: treat the node as off-screen rather than crash.
        return False
    keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0
    return keeps
|
||||
|
||||
|
||||
def filter_nodes(root: ET, platform="Ubuntu", check_image=False):
    """Walk the whole tree and keep only the nodes accepted by judge_node.

    Args:
        root: root element of the accessibility tree.
        platform: forwarded to judge_node ("Ubuntu" or "Windows").
        check_image: forwarded to judge_node.

    Returns:
        list: accepted elements in document order.
    """
    return [candidate for candidate in root.iter() if judge_node(candidate, platform, check_image)]
|
||||
|
||||
|
||||
def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0, platform="Ubuntu"):
    """Draw red, indexed bounding boxes for the given a11y nodes on a screenshot.

    Args:
        nodes: accessibility elements carrying screencoord/size attributes.
        image_file_content (bytes): raw screenshot bytes.
        down_sampling_ratio (float): scale applied to both image and boxes.
        platform (str): "Ubuntu" or "Windows"; selects attribute namespaces.

    Returns:
        (marks, drew_nodes, text_info, image_bytes): per-node [x, y, w, h]
        boxes in ORIGINAL (pre-downsampling) coordinates, the nodes actually
        drawn, a TSV table "index\\ttag\\tname\\ttext", and the annotated PNG.

    Raises:
        ValueError: if platform is neither "Ubuntu" nor "Windows".
    """
    if platform == "Ubuntu":
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
        _value_ns = value_ns_ubuntu
    elif platform == "Windows":
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
        _value_ns = value_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'Ubuntu' or 'Windows'")

    # Load the screenshot image
    image_stream = io.BytesIO(image_file_content)
    image = Image.open(image_stream)
    if float(down_sampling_ratio) != 1.0:
        image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio)))
    draw = ImageDraw.Draw(image)
    marks = []
    drew_nodes = []
    text_informations: List[str] = ["index\ttag\tname\ttext"]

    try:
        # Adjust the path to the font file you have or use a default one
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        # Fallback to a basic font if the specified font can't be loaded
        font = ImageFont.load_default()

    index = 1

    # Loop over all the visible nodes and draw their bounding boxes
    for _node in nodes:
        coords_str = _node.attrib.get("{{{:}}}screencoord".format(_component_ns))
        size_str = _node.attrib.get("{{{:}}}size".format(_component_ns))

        if coords_str and size_str:
            try:
                # Parse the coordinates and size from the strings
                coords = tuple(map(int, coords_str.strip("()").split(", ")))
                size = tuple(map(int, size_str.strip("()").split(", ")))

                import copy

                # Keep the pre-downsampling geometry for the returned marks.
                original_coords = copy.deepcopy(coords)
                original_size = copy.deepcopy(size)

                if float(down_sampling_ratio) != 1.0:
                    # Downsample the coordinates and size
                    coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
                    size = tuple(int(s * down_sampling_ratio) for s in size)

                # Check for negative sizes
                if size[0] <= 0 or size[1] <= 0:
                    raise ValueError(f"Size must be positive, got: {size}")

                # Calculate the bottom-right corner of the bounding box
                bottom_right = (coords[0] + size[0], coords[1] + size[1])

                # Check that bottom_right > coords (x1 >= x0, y1 >= y0)
                if bottom_right[0] < coords[0] or bottom_right[1] < coords[1]:
                    raise ValueError(f"Invalid coordinates or size, coords: {coords}, size: {size}")

                # Check if the area only contains one color
                cropped_image = image.crop((*coords, *bottom_right))
                if len(set(list(cropped_image.getdata()))) == 1:
                    continue

                # Draw rectangle on image
                draw.rectangle([coords, bottom_right], outline="red", width=1)

                # Draw index number at the bottom left of the bounding box with black background
                text_position = (coords[0], bottom_right[1])  # Adjust Y to be above the bottom right
                text_bbox: Tuple[int, int, int, int] = draw.textbbox(text_position, str(index), font=font, anchor="lb")
                # offset: int = bottom_right[1]-text_bbox[3]
                # text_bbox = (text_bbox[0], text_bbox[1]+offset, text_bbox[2], text_bbox[3]+offset)

                # draw.rectangle([text_position, (text_position[0] + 25, text_position[1] + 18)], fill='black')
                draw.rectangle(text_bbox, fill="black")
                draw.text(text_position, str(index), font=font, anchor="lb", fill="white")

                # each mark is an x, y, w, h tuple
                marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
                drew_nodes.append(_node)

                # CSV-style quoting for node text (embedded quotes doubled).
                if _node.text:
                    node_text = _node.text if '"' not in _node.text else '"{:}"'.format(_node.text.replace('"', '""'))
                elif _node.get("{{{:}}}class".format(class_ns_windows), "").endswith("EditWrapper") and _node.get(
                    "{{{:}}}value".format(_value_ns)
                ):
                    node_text = _node.get("{{{:}}}value".format(_value_ns), "")
                    node_text = node_text if '"' not in node_text else '"{:}"'.format(node_text.replace('"', '""'))
                else:
                    node_text = '""'
                text_information: str = "{:d}\t{:}\t{:}\t{:}".format(index, _node.tag, _node.get("name", ""), node_text)
                text_informations.append(text_information)

                index += 1

            except ValueError:
                # Skip nodes with unparseable or degenerate geometry.
                pass

    output_image_stream = io.BytesIO()
    image.save(output_image_stream, format="PNG")
    image_content = output_image_stream.getvalue()

    return marks, drew_nodes, "\n".join(text_informations), image_content
|
||||
|
||||
|
||||
def print_nodes_with_indent(nodes, indent=0):
    """Depth-first print of each node's tag and attributes; children are
    indented two extra spaces per level."""
    # Iterative preorder traversal; reversed() keeps document order on the stack.
    stack = [(element, indent) for element in reversed(list(nodes))]
    while stack:
        element, depth = stack.pop()
        print(" " * depth, element.tag, element.attrib)
        stack.extend((child, depth + 2) for child in reversed(list(element)))
|
||||
|
||||
|
||||
def find_active_applications(tree, state_ns):
    """Return the application names whose frames are active, plus shell helpers.

    Args:
        tree: ElementTree whose root children are <application> elements.
        state_ns (str): namespace URI used for the frame "active" attribute.

    Returns:
        list[str]: active app names + "gnome-shell", or the shell defaults
        ["gjs", "gnome-shell"] when no frame is active.
    """
    active_key = "{{{:}}}active".format(state_ns)
    active_apps = [
        application.attrib.get("name")
        for application in list(tree.getroot())
        for frame in application
        if frame.attrib.get(active_key, "false") == "true"
    ]
    return active_apps + ["gnome-shell"] if active_apps else ["gjs", "gnome-shell"]
|
||||
|
||||
|
||||
def linearize_accessibility_tree(accessibility_tree, platform="Ubuntu"):
    """Flatten an a11y XML dump into a TSV table of visible, deduplicated nodes.

    Keeps only applications reported active by find_active_applications, runs
    filter_nodes over the remainder, and emits one
    "tag\\ttext\\tposition\\tsize" row per node (position is the box center).
    Near-duplicate rows are removed via filter_similar_nodes.

    Args:
        accessibility_tree (str): serialized a11y XML.
        platform (str): "Ubuntu" or "Windows"; selects attribute namespaces.

    Returns:
        str: the linearized table, or "" when parsing fails entirely.

    Raises:
        ValueError: if platform is neither "Ubuntu" nor "Windows".
    """
    if platform == "Ubuntu":
        _attributes_ns = attributes_ns_ubuntu
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
        _value_ns = value_ns_ubuntu
    elif platform == "Windows":
        _attributes_ns = attributes_ns_windows
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
        _value_ns = value_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'Ubuntu' or 'Windows'")

    try:
        tree = ET.ElementTree(ET.fromstring(accessibility_tree))
        keep_apps = find_active_applications(tree, _state_ns)

        # Remove inactive applications
        for application in list(tree.getroot()):
            if application.get("name") not in keep_apps:
                tree.getroot().remove(application)

        filtered_nodes = filter_nodes(tree.getroot(), platform, check_image=True)
        linearized_accessibility_tree = ["tag\ttext\tposition (center x & y)\tsize (w & h)"]

        # Linearize the accessibility tree nodes into a table format
        for node in filtered_nodes:
            try:
                # Merge the element text and its accessible name into one string.
                text = node.text if node.text is not None else ""
                text = text.strip()
                name = node.get("name", "").strip()
                if text == "":
                    text = name
                elif name != "" and text != name:
                    text = f"{name} ({text})"

                text = text.replace("\n", "\\n")
                pos = node.get("{{{:}}}screencoord".format(_component_ns), "")
                size = node.get("{{{:}}}size".format(_component_ns), "")

                # NOTE(review): these patterns are f-strings, not raw strings;
                # "\(" only works because Python passes unknown escapes
                # through (with a deprecation warning). Prefer r"..." here.
                x, y = re.match(f"\((\d+), (\d+)\)", pos).groups()
                w, h = re.match(f"\((\d+), (\d+)\)", size).groups()
                x_mid, y_mid = int(x) + int(w) // 2, int(y) + int(h) // 2

                linearized_accessibility_tree.append(
                    "{:}\t{:}\t{:}\t{:}".format(node.tag, text, f"({x_mid}, {y_mid})", size)
                )
            except Exception as e:
                # Skip nodes with missing or malformed geometry attributes.
                continue

        # Filter out similar nodes
        linearized_accessibility_tree = filter_similar_nodes("\n".join(linearized_accessibility_tree))
    except Exception as e:
        print(f"Error in linearize_accessibility_tree: {e}")
        linearized_accessibility_tree = ""

    return linearized_accessibility_tree
|
||||
|
||||
|
||||
def trim_accessibility_tree(linearized_accessibility_tree, max_items):
    """Cap the linearized tree at max_items lines, appending "..." when truncated.

    Args:
        linearized_accessibility_tree (str): newline-separated table rows.
        max_items (int): maximum number of rows to keep.

    Returns:
        str: the input unchanged when short enough, otherwise the first
        max_items rows followed by a "..." marker line.
    """
    rows = linearized_accessibility_tree.strip().split("\n")
    if len(rows) <= max_items:
        return linearized_accessibility_tree
    return "\n".join(rows[:max_items]) + "\n..."
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
import re
|
||||
|
||||
|
||||
def parse_line(line):
    """Parse one linearized a11y row, e.g. "label Google Chrome (191, 13) (104, 17)".

    The row format is: type, text, (center x, center y), (width, height).

    Args:
        line (str): a single table row.

    Returns:
        dict | None: {"type", "text", "bbox", "center", "size", "raw"} with
        bbox as (x1, y1, x2, y2), or None when the row does not match.
    """
    row_re = r"^(\S+)\s+(.+?)\s+\((\d+), (\d+)\)\s+\((\d+), (\d+)\)"
    match = re.match(row_re, line)
    if match is None:
        return None
    node_type, text = match.group(1), match.group(2)
    cx, cy, w, h = (int(v) for v in match.groups()[2:])
    # Derive the top-left corner from the center point and size.
    left = cx - w // 2
    top = cy - h // 2
    return {
        "type": node_type,
        "text": text.strip(),
        "bbox": (left, top, left + w, top + h),
        "center": (cx, cy),
        "size": (w, h),
        "raw": line,
    }
|
||||
|
||||
|
||||
def iou(box1, box2):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes.

    Returns:
        float | int: IoU in [0, 1]; 0 when the union area is empty.
    """
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h
    union = (
        (box1[2] - box1[0]) * (box1[3] - box1[1])
        + (box2[2] - box2[0]) * (box2[3] - box2[1])
        - intersection
    )
    if union == 0:
        return 0
    return intersection / union
|
||||
|
||||
|
||||
def norm_text(s):
    """Normalize text for comparison: lowercase and strip all whitespace."""
    return "".join(s.lower().split())
|
||||
|
||||
|
||||
def text_similarity(a, b):
    """Exact-match similarity after normalization: 1.0 when equal, else 0."""
    return 1.0 if norm_text(a) == norm_text(b) else 0
|
||||
|
||||
|
||||
def filter_similar_nodes(linearized_accessibility_tree):
    """Drop near-duplicate rows from a linearized a11y tree.

    Two parseable rows are duplicates when their boxes overlap (IoU above
    IOU_THRESH) and their normalized texts match (similarity above
    TEXT_THRESH); the LATER row is removed. Rows that cannot be parsed are
    kept untouched.

    Args:
        linearized_accessibility_tree (str): newline-separated table rows.

    Returns:
        str: the surviving rows joined with newlines, original order kept.
    """
    lines = [ln for ln in linearized_accessibility_tree.split("\n") if ln.strip()]
    # parse all nodes
    nodes = []
    for ln in lines:
        node = parse_line(ln)
        if node:
            nodes.append(node)
        else:
            # Keep rows that cannot be parsed.
            nodes.append({"raw": ln, "invalid": True})
    filtered = []
    removed = [False] * len(nodes)
    # Thresholds are tunable.
    IOU_THRESH = 0.2
    TEXT_THRESH = 0.9
    for i, ni in enumerate(nodes):
        if ni.get("invalid"):
            filtered.append(ni["raw"])
            continue
        if removed[i]:
            continue
        for j in range(i + 1, len(nodes)):
            nj = nodes[j]
            if nj.get("invalid"):
                continue
            iou_val = iou(ni["bbox"], nj["bbox"])
            text_sim = text_similarity(ni["text"], nj["text"])
            if iou_val > IOU_THRESH and text_sim > TEXT_THRESH:
                # The two rows are near-identical; drop the later one.
                removed[j] = True
                # print(f"removed: {nj['raw']} (highly similar to {ni['raw']})")
        # Keep rows not marked for removal.
        if not removed[i]:
            filtered.append(ni["raw"])
    return "\n".join(filtered)
|
||||
|
||||
|
||||
# Example usage: run this module directly to see deduplication on a sample tree.
if __name__ == "__main__":
    linearized_accessibility_tree = "tag\ttext\tposition (center x & y)\tsize (w & h)\nicon\t\t(1853, 1001)\t(64, 64)\nlabel\tHome\t(1853, 1045)\t(40, 17)\nlabel\tActivities\t(49, 13)\t(63, 17)\ntext\tActivities\t(49, 13)\t(63, 17)\nlabel\tApr 17 17∶04\t(995, 13)\t(117, 27)\ntext\tApr 17 17∶04\t(995, 13)\t(87, 18)\nmenu\tSystem\t(1867, 13)\t(106, 27)\npush-button\tGoogle Chrome\t(35, 65)\t(70, 64)\npush-button\tThunderbird Mail\t(35, 133)\t(70, 64)\npush-button\tVisual Studio Code\t(35, 201)\t(70, 64)\npush-button\tVLC media player\t(35, 269)\t(70, 64)\npush-button\tLibreOffice Writer\t(35, 337)\t(70, 64)\npush-button\tLibreOffice Calc\t(35, 405)\t(70, 64)\npush-button\tLibreOffice Impress\t(35, 473)\t(70, 64)\npush-button\tGNU Image Manipulation Program\t(35, 541)\t(70, 64)\npush-button\tFiles\t(35, 609)\t(70, 64)\npush-button\tUbuntu Software\t(35, 677)\t(70, 64)\npush-button\tHelp\t(35, 745)\t(70, 64)\npush-button\tTrash\t(35, 816)\t(70, 64)\ntoggle-button\tShow Applications\t(35, 1045)\t(70, 70)"
    result = filter_similar_nodes(linearized_accessibility_tree)
    print(result)
|
||||
|
|
@ -1,260 +0,0 @@
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
def agent_action(func):
    """Decorator: tag *func* as an action exposed to the agent (is_agent_action=True)."""
    setattr(func, "is_agent_action", True)
    return func
|
||||
|
||||
|
||||
# pyautogui/subprocess snippet template: the caller substitutes the target
# window id for WINDOW_ID; the snippet dismisses any open menu, focuses the
# window via wmctrl, maximizes it, and reports the switch.
switch_window_code = """import subprocess;
import pyautogui;
pyautogui.press('escape');
time.sleep(0.5);
subprocess.run(['wmctrl', '-ia', 'WINDOW_ID'])
subprocess.run(['wmctrl', '-ir', 'WINDOW_ID', '-b', 'add,maximized_vert,maximized_horz'])
print('Switch to WINDOW_ID')"""

# Shell commands used to launch each supported desktop application.
launch_app_commands = {
    # Web Browser
    "chrome": "google-chrome --remote-debugging-port=1337",
    # File Manager
    "files": "nautilus",
    # Terminal
    "terminal": 'export DBUS_SESSION_BUS_ADDRESS="unix:path=/run/user/1000/bus" && gnome-terminal',
    # Utilities
    "gedit": "gedit",
    # Office
    "libreoffice writer": "libreoffice --writer",
    "libreoffice calc": "libreoffice --calc",
    "libreoffice impress": "libreoffice --impress",
    # System
    "settings": 'export DBUS_SESSION_BUS_ADDRESS="unix:path=/run/user/1000/bus" && gnome-control-center',
    # Multimedia
    "vlc": "vlc",
    "gimp": "gimp",
    # IDE
    "vs code": "code",
    # Email
    "thunderbird": "thunderbird",
}
|
||||
|
||||
|
||||
class GroundingAgent:
|
||||
|
||||
tool_list = {
|
||||
"libreoffice_calc": "CalcTools",
|
||||
"libreoffice_impress": "ImpressTools",
|
||||
"libreoffice_writer": "WriterTools",
|
||||
"code": "CodeTools",
|
||||
"vlc": "VLCTools",
|
||||
"google_chrome": "BrowserTools",
|
||||
}
|
||||
|
||||
relative_coordinate = True # whether the coordinates are relative (0-1000) or absolute (e.g. 1920x1080)
|
||||
|
||||
@classmethod
def tool_commands(cls, code: str, tool_name: str):
    """Wrap model-produced tool code into a single runnable snippet.

    The snippet star-imports the tool module, runs the code, then prints the
    tool's captured result via <ToolClass>.print_result().

    Args:
        code: the tool invocation emitted by the model.
        tool_name: key into cls.tool_list selecting the tool class.

    Returns:
        list[str]: a one-element list with the composed command string.
    """
    tool_class = cls.tool_list[tool_name]
    snippet = f"from {tool_name} import *; {code}; {tool_class}.print_result()"
    return [snippet]
|
||||
|
||||
@classmethod
@agent_action
def click(
    cls,
    coordinate: List,
    num_clicks: int = 1,
    button_type: str = "left",
):
    """
    Click on the element

    Args:
        coordinate (List): [x, y], coordinate of the element to click on
        num_clicks (int): number of times to click the element
        button_type (str): which mouse button to press ("left", "middle", or "right")
    """
    x, y = coordinate
    if cls.relative_coordinate:
        # Map 0-1000 relative coordinates onto a 1920x1080 screen.
        x, y = round(x * 1920 / 1000), round(y * 1080 / 1000)
    # TODO: maximizing the window still requires a separate call.
    return (
        f"pyautogui.click({x}, {y}, clicks={num_clicks}, "
        f'button={repr(button_type)}); print("Click Success")'
    )
|
||||
|
||||
@classmethod
|
||||
@agent_action
|
||||
def type(
|
||||
cls,
|
||||
coordinate: Optional[List] = None,
|
||||
text: str = "",
|
||||
overwrite: bool = False,
|
||||
enter: bool = False,
|
||||
):
|
||||
"""
|
||||
Type text into the element
|
||||
|
||||
Args:
|
||||
coordinate (List): [x, y], coordinate of the element to type into. If None, typing starts at current cursor location
|
||||
text (str): the text to type
|
||||
overwrite (bool): True to overwrite existing text, False otherwise
|
||||
enter (bool): True to press enter after typing, False otherwise
|
||||
"""
|
||||
|
||||
command = ""
|
||||
|
||||
if coordinate is not None:
|
||||
# Start typing at the center of the element
|
||||
x, y = coordinate
|
||||
if cls.relative_coordinate:
|
||||
x, y = round(x * 1920 / 1000), round(y * 1080 / 1000)
|
||||
command += f"pyautogui.click({x}, {y}); "
|
||||
|
||||
if overwrite:
|
||||
command += f"pyautogui.hotkey('ctrl', 'a'); pyautogui.press('backspace'); "
|
||||
|
||||
command += f"pyautogui.write({repr(text)}); "
|
||||
|
||||
if enter:
|
||||
command += "pyautogui.press('enter'); "
|
||||
|
||||
command += "print('Type Success')"
|
||||
|
||||
return command
|
||||
|
||||
@classmethod
|
||||
@agent_action
|
||||
def drag_and_drop(cls, drag_from_coordinate: List, drop_on_coordinate: List):
|
||||
"""
|
||||
Drag element1 and drop it on element2
|
||||
|
||||
Args:
|
||||
drag_from_coordinate (List): [x, y], coordinate of element to drag
|
||||
drop_on_coordinate (List): [x, y], coordinate of element to drop on
|
||||
"""
|
||||
x1, y1 = drag_from_coordinate
|
||||
if cls.relative_coordinate:
|
||||
x1, y1 = round(x1 * 1920 / 1000), round(y1 * 1080 / 1000)
|
||||
x2, y2 = drop_on_coordinate
|
||||
if cls.relative_coordinate:
|
||||
x2, y2 = round(x2 * 1920 / 1000), round(y2 * 1080 / 1000)
|
||||
|
||||
command = f"pyautogui.moveTo({x1}, {y1}); "
|
||||
# TODO: specified duration?
|
||||
command += f"pyautogui.dragTo({x2}, {y2}, duration=1.); pyautogui.mouseUp(); "
|
||||
|
||||
command += "print('Drag and Drop Success')"
|
||||
|
||||
return command
|
||||
|
||||
@classmethod
|
||||
@agent_action
|
||||
def scroll(cls, coordinate: List, direction: str):
|
||||
"""
|
||||
Scroll the element in the specified direction
|
||||
|
||||
Args:
|
||||
coordinate (List): [x, y], coordinate of the element to scroll in
|
||||
direction (str): the direction to scroll ("up" or "down")
|
||||
"""
|
||||
x, y = coordinate
|
||||
if cls.relative_coordinate:
|
||||
x, y = round(x * 1920 / 1000), round(y * 1080 / 1000)
|
||||
amount = 100 if direction == "up" else -100
|
||||
return f"import pyautogui; pyautogui.moveTo({x}, {y}); pyautogui.scroll({amount}); print('Scroll Success')"
|
||||
|
||||
@classmethod
|
||||
@agent_action
|
||||
def open_app(cls, app_name: str):
|
||||
"""
|
||||
Open a specified application
|
||||
|
||||
Supported apps: chrome, files, terminal, gedit, libreoffice writer,
|
||||
libreoffice calc, libreoffice impress, vs code, vlc, gimp, settings, thunderbird
|
||||
|
||||
Args:
|
||||
app_name (str): name of the application to open
|
||||
"""
|
||||
|
||||
app_name = app_name.lower().strip()
|
||||
|
||||
if app_name not in launch_app_commands:
|
||||
command = f"print(f'{app_name} is not supported or recognized')"
|
||||
else:
|
||||
command = {
|
||||
"action_type": "OPEN_APP",
|
||||
"parameters": {"launch_app_command": launch_app_commands[app_name], "app_name": app_name},
|
||||
}
|
||||
|
||||
return command
|
||||
|
||||
@classmethod
|
||||
@agent_action
|
||||
def switch_window(cls, window_id: str):
|
||||
"""
|
||||
Switch to the window with the given window id
|
||||
|
||||
Args:
|
||||
window_id (str): the window id to switch to from the provided list of open windows
|
||||
"""
|
||||
return switch_window_code.replace("WINDOW_ID", window_id)
|
||||
|
||||
@classmethod
|
||||
@agent_action
|
||||
def hotkey(cls, keys: List):
|
||||
"""
|
||||
Press a hotkey combination
|
||||
|
||||
Args:
|
||||
keys (List): the keys to press in combination (e.g. ['ctrl', 'c'] for copy, ['prtsc'] for screenshot)
|
||||
"""
|
||||
# add quotes around the keys
|
||||
keys = [f"'{key}'" for key in keys]
|
||||
key_str = ", ".join(keys).replace("'", "\\'")
|
||||
return f"import pyautogui; pyautogui.hotkey({', '.join(keys)}); print(f'Press Hotkey: {key_str}')"
|
||||
|
||||
@classmethod
|
||||
@agent_action
|
||||
def quote(cls, content: str):
|
||||
"""
|
||||
Quote information from the current page for memory
|
||||
|
||||
Args:
|
||||
content (str): text summarized or copied from the page for later operation
|
||||
"""
|
||||
return f'''print("""{content}""")'''
|
||||
|
||||
@classmethod
|
||||
@agent_action
|
||||
def wait(cls):
|
||||
"""
|
||||
Wait for a while
|
||||
|
||||
"""
|
||||
return "WAIT"
|
||||
|
||||
@classmethod
|
||||
@agent_action
|
||||
def exit(cls, success: bool):
|
||||
"""
|
||||
End the current task
|
||||
|
||||
Args:
|
||||
success (bool): True if successfully finish a task, False otherwise
|
||||
"""
|
||||
if success:
|
||||
return "DONE"
|
||||
else:
|
||||
return "FAIL"
|
||||
|
|
@ -1,194 +0,0 @@
|
|||
import inspect
|
||||
import json
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
# NOTE(review): despite the name, this resolves to the PARENT of this file's
# directory (dirname applied twice); tool API JSON paths are joined onto it.
current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def _render_function(func, indent, include_cls):
    """Render one function stub (signature + docstring, no body) at `indent`.

    Args:
        func (dict): the "function" object from an OpenAI-style tool spec.
        indent (str): leading indentation ("    " inside a class, "" at top level).
        include_cls (bool): prepend `cls` to the parameter list and strip the
            class qualifier from the function name.

    Returns:
        str: the rendered stub followed by one blank line.
    """
    # Inside a class the "Class.method" qualifier is dropped; at top level
    # the declared name is kept exactly as-is.
    name = func["name"].split(".")[-1] if include_cls else func["name"]
    params = func["parameters"]["properties"]
    required = func["parameters"].get("required", [])

    # Required parameters first, then the optional ones, each group keeping
    # its JSON declaration order. NOTE: optional parameters are emitted
    # WITHOUT a default value.
    ordered = list(required) + [p for p in params if p not in required]
    if include_cls:
        ordered = ["cls"] + ordered

    stub = f"{indent}def {name}({', '.join(ordered)}):\n"

    doc = f'{indent}    """\n{indent}    {func["description"]}\n\n{indent}    Args:\n'
    if [p for p in ordered if p != "cls"]:
        for p in required:
            doc += f'{indent}        {p} ({params[p]["type"]}): {params[p].get("description", "")}\n'
        for p in params:
            if p not in required:
                doc += f'{indent}        {p} ({params[p]["type"]}, optional): {params[p].get("description", "")}\n'
    else:
        doc += f"{indent}        None\n"
    doc += f'{indent}    """\n'

    return stub + doc + "\n"


def generate_func(json_data):
    """Generate Python stub source (signatures + docstrings) from a list of
    OpenAI-style tool/function specs.

    Specs named "Class.method" are grouped into a generated class whose
    methods take `cls` as the first parameter; unqualified names become
    top-level functions.

    Args:
        json_data (list): items of the form
            {"type": "function", "function": {...}}.

    Returns:
        tuple[str, str]: (generated source, name of the LAST class emitted —
        "" when no class-qualified spec was present).
    """
    class_funcs = {}      # class name -> list of its specs, in input order
    no_class_funcs = []   # specs with no (or a multi-part) class qualifier
    cls_name = ""

    for item in json_data:
        if item["type"] != "function":
            continue
        name_parts = item["function"]["name"].split(".")
        if len(name_parts) == 2:
            class_funcs.setdefault(name_parts[0], []).append(item)
        else:
            no_class_funcs.append(item)

    code = ""

    for class_name, funcs in class_funcs.items():
        code += f"class {class_name}:\n"
        cls_name = class_name  # NOTE: only the last class name survives
        for item in funcs:
            code += _render_function(item["function"], "    ", include_cls=True)
        code += "\n"

    for item in no_class_funcs:
        code += _render_function(item["function"], "", include_cls=False)

    return code.strip(), cls_name
|
||||
|
||||
|
||||
# Prompt templates assembled by Prompt.construct_procedural_memory.
# {observation_list} enumerates the modalities the model receives;
# {format_hint} pins the required answer format.
setup_prompt = """You are a GUI operation agent. You will be given a task and your action history, with current observation ({observation_list}). You should help me control the computer, output the best action step by step to accomplish the task.
You should first generate a plan, reflect on the current observation, then generate actions to complete the task in python-style pseudo code using the predefined functions.

* Output Format:
{format_hint}"""

# {class_content} is the rendered Agent class (plus any tool class) that
# lists the callable actions.
func_def_template = """* Available Functions:
```python
{class_content}
```"""

# {relative_coordinate_hint} is either the 0-1000 coordinate note or "";
# {client_password} is surfaced so the agent can use sudo.
note_prompt = """* Note:
- Your code should only be wrapped in ```python```.
- Only **ONE-LINE-OF-CODE** at a time.
- Each code block is context independent, and variables from the previous round cannot be used in the next round.
{relative_coordinate_hint}- Return with `Agent.exit(success=True)` immediately after the task is completed.
- The computer's environment is Linux, e.g., Desktop path is '/home/user/Desktop'
- My computer's password is '{client_password}', feel free to use it when you need sudo rights"""
|
||||
|
||||
|
||||
class Prompt:
    """Assembles the three-part system prompt (setup, available functions,
    notes) handed to the GUI agent model."""

    @staticmethod
    def construct_procedural_memory(agent_class, app_name=None, client_password="password", with_image=True, with_atree=False, relative_coordinate=True, glm41v_format=True):
        """
        Build the prompt sections for `agent_class`.

        Args:
            agent_class: class whose `@agent_action`-tagged methods are
                rendered (signature + docstring) into the functions section.
            app_name (str, optional): when given, the tool spec at
                tools/apis/<app_name>.json is rendered and appended.
            client_password (str): password surfaced in the notes section.
            with_image (bool): include "screenshot" in the observation list.
            with_atree (bool): include the a11y tree in the observation list.
            relative_coordinate (bool): add the 0-1000 coordinate hint.
            glm41v_format (bool): wrap the answer in <answer> tags
                (GLM-4.1V style) instead of a bare code block.

        Returns:
            tuple: (setup prompt, function-definition prompt, note prompt).
        """
        agent_class_content = "Class Agent:"
        for attr_name in dir(agent_class):
            attr = getattr(agent_class, attr_name)
            # Only methods tagged by the @agent_action decorator are exposed.
            if callable(attr) and hasattr(attr, "is_agent_action"):
                # Use inspect to get the full function signature
                signature = inspect.signature(attr)
                # NOTE(review): the whitespace inside this template defines
                # the rendered prompt layout — verify against the expected
                # prompt before changing it.
                agent_class_content += f"""
    def {attr_name}{signature}:
        '''{attr.__doc__}'''
    """

        if app_name is not None:
            # Append the app-specific tool API rendered from its JSON spec.
            tool_path = os.path.join(current_dir, "tools", "apis", f"{app_name.lower()}.json")
            with open(tool_path, "r") as f:
                json_data = json.load(f)

            tool_class_content, tool_class_name = generate_func(json_data)

            agent_class_content += "\n\n{}".format(tool_class_content)

        func_def_prompt = func_def_template.format(class_content=agent_class_content.strip())

        # --- dynamic observation list ---
        obs_items = []
        if with_image:
            obs_items.append("screenshot")
        obs_items.append("current app name")
        if with_atree:
            obs_items.append("a11y tree (based on AT-SPI library)")
        obs_items.append("app info")
        obs_items.append("last action result")
        observation_list = ", ".join(obs_items)

        setup_prompt_formatted = setup_prompt.format(
            observation_list=observation_list,
            format_hint="<think>\n{**YOUR-PLAN-AND-THINKING**}</think>\n<answer>```python\n{**ONE-LINE-OF-CODE**}\n```</answer>" if glm41v_format else "<think>\n{**YOUR-PLAN-AND-THINKING**}\n</think>\n```python\n{**ONE-LINE-OF-CODE**}\n```"
        )

        note_prompt_formatted = note_prompt.format(
            relative_coordinate_hint="- The coordinate [x, y] should be normalized to 0-1000, which usually should be the center of a specific target element.\n" if relative_coordinate else "",
            client_password=client_password
        )

        return setup_prompt_formatted, func_def_prompt, note_prompt_formatted
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: render the procedural-memory prompt for the
    # GroundingAgent with the VLC tool API and print all three sections.
    from grounding_agent import GroundingAgent

    print(Prompt.construct_procedural_memory(GroundingAgent, "vlc"))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue