open-webui/backend/open_webui/routers/functions.py

504 lines
15 KiB
Python
Raw Normal View History

2024-08-28 06:10:27 +08:00
import os
2025-05-27 03:52:22 +08:00
import re
import logging
2025-05-27 03:52:22 +08:00
import aiohttp
2024-08-28 06:10:27 +08:00
from pathlib import Path
from typing import Optional
2024-06-20 15:37:02 +08:00
2024-12-10 16:54:13 +08:00
from open_webui.models.functions import (
2024-06-20 15:37:02 +08:00
FunctionForm,
FunctionModel,
FunctionResponse,
2024-08-28 06:10:27 +08:00
Functions,
2024-06-20 15:37:02 +08:00
)
from open_webui.utils.plugin import (
load_function_module_by_id,
replace_imports,
get_function_module_from_cache,
)
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
2024-08-28 06:10:27 +08:00
from fastapi import APIRouter, Depends, HTTPException, Request, status
2024-12-09 08:01:56 +08:00
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
2025-05-27 03:52:22 +08:00
from pydantic import BaseModel, HttpUrl
# Per-module logger; verbosity follows the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(level=SRC_LOG_LEVELS["MAIN"])

# Every endpoint in this module is registered on this router.
router = APIRouter()
############################
# GetFunctions
############################


@router.get("/", response_model=list[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)):
    """List all functions; available to any verified user.

    The result is serialized through ``FunctionResponse``.
    """
    all_functions = Functions.get_functions()
    return all_functions
############################
# ExportFunctions
############################


@router.get("/export", response_model=list[FunctionModel])
async def get_functions(user=Depends(get_admin_user)):
    """Admin-only export of all functions, serialized as ``FunctionModel``.

    NOTE(review): this handler reuses the name ``get_functions``, so the module
    attribute now refers to this export handler and shadows the list handler
    defined earlier. Both routes are still registered, because the decorator
    runs at definition time. Consider renaming to ``export_functions`` once
    nothing imports it by name.
    """
    exported = Functions.get_functions()
    return exported
############################
# LoadFunctionFromLink
############################


class LoadUrlForm(BaseModel):
    """Request body for ``POST /load/url``: the URL to fetch function source from."""

    # Validated by pydantic as an http(s) URL before the handler runs.
    url: HttpUrl
# Matches github.com "tree" (folder) and "blob" (file) page URLs. A trailing
# query string or fragment (e.g. "?plain=1", "#L10-L20") is matched and discarded.
_GITHUB_PAGE_URL_RE = re.compile(
    r"https?://(?:www\.)?github\.com/"
    r"(?P<owner>[^/]+)/(?P<repo>[^/]+)/(?P<kind>tree|blob)/(?P<ref>[^/?#]+)"
    r"(?:/(?P<path>[^?#]*))?"
    r"(?:[?#].*)?"
)


def github_url_to_raw_url(url: str) -> str:
    """Rewrite a GitHub page URL to the matching raw.githubusercontent.com URL.

    * ``.../tree/<ref>/<dir>`` (a folder) maps to ``<dir>/main.py`` inside it.
      A tree URL without a directory maps to ``main.py`` at the repo root.
    * ``.../blob/<ref>/<file>`` maps to that file.

    ``<ref>`` is placed directly after the repo, not under ``refs/heads/``.
    This lets tag and commit-SHA permalinks resolve, not only branch names.
    Because ``<ref>`` is taken as a single path segment, a branch name that
    contains ``/`` is split between ref and path. The concatenated raw URL is
    the same either way.

    Args:
        url: Any URL. Anything that is not a GitHub tree/blob URL is returned
            unchanged.

    Returns:
        The raw-content URL, or ``url`` itself when no rewrite applies.
    """
    match = _GITHUB_PAGE_URL_RE.fullmatch(url)
    if match is None:
        # No match; return as-is
        return url

    owner, repo, kind, ref, path = match.group("owner", "repo", "kind", "ref", "path")
    base = f"https://raw.githubusercontent.com/{owner}/{repo}/{ref}"
    path = (path or "").strip("/")

    if kind == "tree":
        # Folder URL: point at the folder's main.py. Avoid a "//" when the
        # folder is the repository root.
        return f"{base}/{path}/main.py" if path else f"{base}/main.py"

    if not path:
        # A blob URL without a file path cannot be mapped; leave it alone.
        return url
    return f"{base}/{path}"
def _function_name_from_url(url: str) -> str:
    """Derive a default function name from the path of ``url``.

    ``.../foo.py`` yields ``foo``. If the last segment is a generic
    entry-point file (``main.py``, ``index.py``, ``__init__.py``) or is not a
    ``.py`` file, the parent segment is used instead. If there is no parent
    segment, the result is ``"function"``.
    """
    # Ignore any fragment / query string so that "foo.py#L10" or
    # "foo.py?raw=1" are still recognised as "foo.py".
    path = url.split("#", 1)[0].split("?", 1)[0]
    url_parts = path.rstrip("/").split("/")
    file_name = url_parts[-1]
    stem = file_name[:-3]
    if (
        file_name.endswith(".py")
        and stem
        and file_name not in ("main.py", "index.py", "__init__.py")
    ):
        return stem
    return url_parts[-2] if len(url_parts) > 1 else "function"


@router.post("/load/url", response_model=Optional[dict])
async def load_function_from_url(
    request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
    """Download function source from a URL so an admin can review and import it.

    GitHub ``tree``/``blob`` page URLs are first rewritten to raw-content URLs
    with ``github_url_to_raw_url``.

    Returns:
        ``{"name": <suggested function name>, "content": <response body text>}``

    Raises:
        HTTPException: with the upstream status if the fetch does not return
            200; 400 if the response body is empty; 500 for any other failure
            (network error, decoding error, ...).
    """
    # NOTE: This is NOT a SSRF vulnerability:
    # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
    # and does NOT accept untrusted user input. Access is enforced by authentication.
    url = str(form_data.url)
    if not url:
        raise HTTPException(status_code=400, detail="Please enter a valid URL")

    url = github_url_to_raw_url(url)
    function_name = _function_name_from_url(url)

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                url, headers={"Content-Type": "application/json"}
            ) as resp:
                if resp.status != 200:
                    raise HTTPException(
                        status_code=resp.status, detail="Failed to fetch the function"
                    )
                data = await resp.text()
                if not data:
                    raise HTTPException(
                        status_code=400, detail="No data received from the URL"
                    )
                return {
                    "name": function_name,
                    "content": data,
                }
    except HTTPException:
        # The deliberate errors raised above already carry their intended
        # status and detail. Re-raise them unchanged so the generic handler
        # below does not turn them into 500s.
        raise
    except Exception as e:
        log.exception(f"Error importing function from {url}: {e}")
        raise HTTPException(status_code=500, detail=f"Error importing function: {e}")
############################
# SyncFunctions
############################


class SyncFunctionsForm(BaseModel):
    """Request body for ``POST /sync``: the list of functions to synchronise.

    This form previously subclassed ``FunctionForm``. That made the payload
    inherit single-function fields (``id``, ``content``, ``meta``, ...) which
    the sync endpoint never reads, and which clients would have had to send if
    they lack defaults. Only ``functions`` is consumed.
    """

    # pydantic copies field defaults per instance, so a literal [] is safe here.
    functions: list[FunctionModel] = []
@router.post("/sync", response_model=Optional[FunctionModel])
async def sync_functions(
    request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
):
    """Sync the stored functions against ``form_data.functions``; admin only.

    Delegates to ``Functions.sync_functions`` with the caller's user id and
    returns its result.

    NOTE(review): ``response_model`` declares a single optional
    ``FunctionModel``, but this handler forwards whatever
    ``Functions.sync_functions`` returns. If that is a list, FastAPI's response
    validation will fail. Confirm, and switch to ``list[FunctionModel]`` if so.
    """
    synced = Functions.sync_functions(user.id, form_data.functions)
    return synced
############################
# CreateNewFunction
############################


@router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function(
    request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
):
    """Create a function from admin-supplied source and register its module.

    Steps:
    1. The id must be a valid Python identifier and is lower-cased.
    2. Imports in the source are rewritten with ``replace_imports``.
    3. The source is loaded as a module, and its frontmatter is stored as
       ``meta.manifest``.
    4. The row is inserted into the database.
    5. Only then is the module published in ``request.app.state.FUNCTIONS``
       and its cache directory created.

    Raises:
        HTTPException: 400 if the id is not an identifier, the id is already
            taken, the module fails to load, or the database insert fails.
    """
    if not form_data.id.isidentifier():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only alphanumeric characters and underscores are allowed in the id",
        )

    form_data.id = form_data.id.lower()

    if Functions.get_function_by_id(form_data.id) is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ID_TAKEN,
        )

    try:
        form_data.content = replace_imports(form_data.content)
        function_module, function_type, frontmatter = load_function_module_by_id(
            form_data.id,
            content=form_data.content,
        )
        form_data.meta.manifest = frontmatter

        function = Functions.insert_new_function(user.id, function_type, form_data)
        if not function:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
            )

        # Publish the loaded module only after the DB row exists, so a failed
        # insert does not leave an orphaned module in app state.
        request.app.state.FUNCTIONS[form_data.id] = function_module

        function_cache_dir = CACHE_DIR / "functions" / form_data.id
        function_cache_dir.mkdir(parents=True, exist_ok=True)

        return function
    except HTTPException:
        # Already a well-formed API error; do not wrap it a second time.
        raise
    except Exception as e:
        log.exception(f"Failed to create a new function: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
############################
# GetFunctionById
############################


@router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
    """Return the function with ``id``; admin only.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.NOT_FOUND`` when no such
            function exists (note: 401 rather than 404).
    """
    function = Functions.get_function_by_id(id)
    if not function:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    return function
############################
# ToggleFunctionById
############################


@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
    """Flip the ``is_active`` flag of function ``id`` and return the updated row.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` if the function does not exist;
            400 if the update returns nothing.
    """
    existing = Functions.get_function_by_id(id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    toggled = Functions.update_function_by_id(
        id, {"is_active": not existing.is_active}
    )
    if not toggled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
        )
    return toggled
############################
# ToggleGlobalById
############################


@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
    """Flip the ``is_global`` flag of function ``id`` and return the updated row.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` if the function does not exist;
            400 if the update returns nothing.
    """
    current = Functions.get_function_by_id(id)
    if not current:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    new_flag = not current.is_global
    result = Functions.update_function_by_id(id, {"is_global": new_flag})
    if not result:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
        )
    return result
############################
# UpdateFunctionById
############################


@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_function_by_id(
    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
):
    """Replace the source of function ``id`` and swap in its reloaded module.

    Imports in the new source are rewritten with ``replace_imports``, and the
    source is loaded as a module. Its frontmatter becomes ``meta.manifest``.
    Every form field except ``id``, plus the detected ``type``, is written to
    the database. The module in ``request.app.state.FUNCTIONS`` is replaced
    only after the database update succeeds.

    Raises:
        HTTPException: 400 if the module fails to load or the database
            update returns nothing (for example, an unknown ``id``).
    """
    try:
        form_data.content = replace_imports(form_data.content)
        function_module, function_type, frontmatter = load_function_module_by_id(
            id, content=form_data.content
        )
        form_data.meta.manifest = frontmatter

        updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
        log.debug(updated)

        function = Functions.update_function_by_id(id, updated)
        if not function:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
            )

        # Swap the running module only once the DB reflects the new source,
        # so a failed update cannot desynchronise memory and storage.
        request.app.state.FUNCTIONS[id] = function_module
        return function
    except HTTPException:
        # Already a well-formed API error; do not wrap it a second time.
        raise
    except Exception as e:
        log.exception(f"Failed to update function {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
############################
# DeleteFunctionById
############################


@router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id(
    request: Request, id: str, user=Depends(get_admin_user)
):
    """Delete function ``id``; admin only.

    Returns:
        Whatever ``Functions.delete_function_by_id`` returns. When that value
        is truthy, the module cached under ``id`` in
        ``request.app.state.FUNCTIONS`` is also dropped.
    """
    deleted = Functions.delete_function_by_id(id)
    if deleted:
        # Evict the in-memory module, if one was loaded for this id.
        request.app.state.FUNCTIONS.pop(id, None)
    return deleted
############################
# GetFunctionValves
############################


@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
    """Return the stored (admin-level) valves of function ``id``.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` if the function does not exist;
            400 if reading the valves raises.
    """
    if not Functions.get_function_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    try:
        return Functions.get_function_valves_by_id(id)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
############################
# GetFunctionValvesSpec
############################


@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_function_valves_spec_by_id(
    request: Request, id: str, user=Depends(get_admin_user)
):
    """Return the JSON schema of the function module's ``Valves`` class, if any.

    Returns:
        The schema dict, or ``None`` when the module defines no ``Valves``.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` if no function with ``id`` exists.
    """
    if not Functions.get_function_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    function_module, _, _ = get_function_module_from_cache(request, id)
    if hasattr(function_module, "Valves"):
        # ``.schema()`` is the deprecated pydantic v1 spelling. This router
        # already uses the v2 API (``model_dump``) on Valves, and
        # ``model_json_schema()`` returns the same schema without a
        # DeprecationWarning.
        return function_module.Valves.model_json_schema()
    return None
############################
# UpdateFunctionValves
############################


@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_function_valves_by_id(
    request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
):
    """Validate ``form_data`` against the module's ``Valves`` and persist it.

    Returns:
        The stored valves as a plain dict (``Valves(...).model_dump()``).

    Raises:
        HTTPException: 401 ``NOT_FOUND`` if the function does not exist or its
            module defines no ``Valves``; 400 if validation or saving fails.
    """
    if not Functions.get_function_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    function_module, _, _ = get_function_module_from_cache(request, id)
    if not hasattr(function_module, "Valves"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    Valves = function_module.Valves
    try:
        # Drop None entries so they are not passed to the Valves constructor
        # (presumably letting the model's field defaults apply instead).
        form_data = {k: v for k, v in form_data.items() if v is not None}
        valves = Valves(**form_data)
        Functions.update_function_valves_by_id(id, valves.model_dump())
        return valves.model_dump()
    except Exception as e:
        log.exception(f"Error updating function valves by id {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
############################
# FunctionUserValves
############################


@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
    """Return the calling user's stored valves for function ``id``.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` if the function does not exist;
            400 if reading the user valves raises.
    """
    if not Functions.get_function_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    try:
        return Functions.get_user_valves_by_id_and_user_id(id, user.id)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_function_user_valves_spec_by_id(
    request: Request, id: str, user=Depends(get_verified_user)
):
    """Return the JSON schema of the module's ``UserValves`` class, if any.

    Returns:
        The schema dict, or ``None`` when the module defines no ``UserValves``.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` if no function with ``id`` exists.
    """
    if not Functions.get_function_by_id(id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    function_module, _, _ = get_function_module_from_cache(request, id)
    if hasattr(function_module, "UserValves"):
        # pydantic v2 API, consistent with ``model_dump`` used on UserValves
        # elsewhere in this router; ``.schema()`` is the deprecated v1 alias.
        return function_module.UserValves.model_json_schema()
    return None
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_function_user_valves_by_id(
    request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
    """Validate ``form_data`` against ``UserValves`` and store it for the caller.

    Returns:
        The stored user valves as a plain dict.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` if the function does not exist or its
            module defines no ``UserValves``; 400 if validation or saving fails.
    """
    not_found = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.NOT_FOUND,
    )

    if not Functions.get_function_by_id(id):
        raise not_found

    fn_module, _, _ = get_function_module_from_cache(request, id)
    if not hasattr(fn_module, "UserValves"):
        raise not_found

    user_valves_cls = fn_module.UserValves
    try:
        # Skip None values so they never reach the UserValves constructor.
        cleaned = {key: value for key, value in form_data.items() if value is not None}
        user_valves = user_valves_cls(**cleaned)
        Functions.update_user_valves_by_id_and_user_id(
            id, user.id, user_valves.model_dump()
        )
        return user_valves.model_dump()
    except Exception as e:
        log.exception(f"Error updating function user valves by id {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )