open-webui/backend/open_webui/routers/models.py

324 lines
9.2 KiB
Python

from typing import Optional
import io
import base64
import json
import asyncio
import logging
from open_webui.models.models import (
ModelForm,
ModelModel,
ModelResponse,
ModelUserResponse,
Models,
)
from pydantic import BaseModel
from open_webui.constants import ERROR_MESSAGES
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
status,
Response,
)
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Router on which the endpoint functions in this module are registered.
router = APIRouter()
###########################
# GetModels
###########################
@router.get("/", response_model=list[ModelUserResponse])
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
    """List the models the requesting user may see.

    An admin receives every model when ``BYPASS_ADMIN_ACCESS_CONTROL`` is
    truthy; any other caller receives the result of
    ``Models.get_models_by_user_id`` for their own id. The ``id`` query
    parameter is accepted but not used by this handler.
    """
    sees_everything = user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL
    if not sees_everything:
        return Models.get_models_by_user_id(user.id)
    return Models.get_models()
###########################
# GetBaseModels
###########################
@router.get("/base", response_model=list[ModelResponse])
async def get_base_models(user=Depends(get_admin_user)):
    """Return whatever ``Models.get_base_models`` yields.

    Access is gated by the ``get_admin_user`` dependency.
    """
    base_models = Models.get_base_models()
    return base_models
############################
# CreateNewModel
############################
@router.post("/create", response_model=Optional[ModelModel])
async def create_new_model(
    request: Request,
    form_data: ModelForm,
    user=Depends(get_verified_user),
):
    """Create a new model on behalf of the requesting user.

    Returns:
        The model returned by ``Models.insert_new_model``.

    Raises:
        HTTPException: 401 when a non-admin lacks the ``workspace.models``
            permission, when ``form_data.id`` already belongs to a model,
            or when the insert returns a falsy result.
    """
    is_admin = user.role == "admin"
    if not is_admin and not has_permission(
        user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    # Model ids must be unique; refuse to overwrite an existing one.
    if Models.get_model_by_id(form_data.id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
        )

    created = Models.insert_new_model(form_data, user.id)
    if not created:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return created
############################
# ExportModels
############################
@router.get("/export", response_model=list[ModelModel])
async def export_models(user=Depends(get_admin_user)):
    """Return every model from ``Models.get_models`` for export.

    Access is gated by the ``get_admin_user`` dependency.
    """
    exported = Models.get_models()
    return exported
############################
# ImportModels
############################
class ModelsImportForm(BaseModel):
    """Request body carrying the models to import."""

    # One raw dictionary per model; the dictionaries are not validated here.
    models: list[dict]
@router.post("/import", response_model=bool)
async def import_models(
    user=Depends(get_admin_user), form_data: ModelsImportForm = (...)
):
    """Import a batch of models, updating existing ids and inserting new ones.

    For each entry in ``form_data.models``:

    * entries without a truthy ``"id"`` are skipped (and logged);
    * missing ``"meta"`` / ``"params"`` keys default to ``{}``;
    * if a model with that id exists, the stored model's fields are overlaid
      with the imported ones and the model is updated;
    * otherwise a new model is inserted, owned by the requesting user.

    Entries are processed one by one, so a failure part-way through leaves
    the earlier entries applied.

    Returns:
        ``True`` once every entry has been processed.

    Raises:
        HTTPException: 400 if the payload is not a list; 500 (with the
            original error text) for any other failure.
    """
    try:
        data = form_data.models
        if not isinstance(data, list):
            raise HTTPException(status_code=400, detail="Invalid JSON format")

        for model_data in data:
            model_id = model_data.get("id")
            if not model_id:
                log.warning("Skipping imported model entry without an id")
                continue

            # Only fills the keys when absent; explicit values are kept.
            model_data.setdefault("meta", {})
            model_data.setdefault("params", {})

            existing_model = Models.get_model_by_id(model_id)
            if existing_model:
                # Imported fields take precedence over the stored ones.
                updated_model = ModelForm(
                    **{**existing_model.model_dump(), **model_data}
                )
                Models.update_model_by_id(model_id, updated_model)
            else:
                new_model = ModelForm(**model_data)
                Models.insert_new_model(user_id=user.id, form_data=new_model)

        return True
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must not be turned
        # into a generic 500 by the catch-all below.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=500, detail=str(e))
############################
# SyncModels
############################
class SyncModelsForm(BaseModel):
    """Request body carrying the models to synchronise."""

    # Defaults to an empty list when the field is omitted.
    models: list[ModelModel] = []
@router.post("/sync", response_model=list[ModelModel])
async def sync_models(
    request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
):
    """Hand the submitted models to ``Models.sync_models`` and return its result.

    The requesting user's id is passed along as the first argument. The
    ``request`` parameter is accepted but not used by this handler.
    """
    requester_id = user.id
    submitted = form_data.models
    return Models.sync_models(requester_id, submitted)
###########################
# GetModelById
###########################
# Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id
@router.get("/model", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
    """Fetch a single model by id if the requester is allowed to read it.

    The model is returned when the requester is an admin and
    ``BYPASS_ADMIN_ACCESS_CONTROL`` is truthy, owns the model, or has
    "read" access according to ``has_access``.

    Returns:
        The model, or ``None`` when it exists but the requester may not
        read it.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.NOT_FOUND`` when no model
            has the given id.
    """
    # NOTE(review): the source's indentation was lost; this assumes the
    # original ``else`` paired with ``if model:`` — confirm against history.
    found = Models.get_model_by_id(id)
    if not found:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    admin_bypass = user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL
    if (
        admin_bypass
        or found.user_id == user.id
        or has_access(user.id, "read", found.access_control)
    ):
        return found

    # No read access: the response is an empty (null) body, not an error.
    return None
###########################
# GetModelProfileImage
###########################
@router.get("/model/profile/image")
async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
    """Serve a model's profile image, falling back to the static favicon.

    Behaviour depends on ``model.meta.profile_image_url``:

    * starts with ``"http"``: respond with a 302 redirect to that URL;
    * starts with ``"data:image"``: decode the base64 payload after the
      first comma and stream the bytes;
    * anything else (no model, no URL, unknown scheme, undecodable data):
      serve ``favicon.png`` from ``STATIC_DIR``.
    """
    model = Models.get_model_by_id(id)
    image_url = model.meta.profile_image_url if model else None

    if image_url:
        if image_url.startswith("http"):
            return Response(
                status_code=status.HTTP_302_FOUND,
                headers={"Location": image_url},
            )

        if image_url.startswith("data:image"):
            try:
                _, base64_data = image_url.split(",", 1)
                image_data = base64.b64decode(base64_data)
            except ValueError as e:
                # Covers a missing comma (unpacking error) and invalid
                # base64 (binascii.Error subclasses ValueError). Serving the
                # favicon is the intended best-effort fallback, but the
                # cause is no longer discarded silently.
                log.debug("Invalid profile image data URL for model %r: %s", id, e)
            else:
                # NOTE(review): the media type is fixed to PNG regardless of
                # the type declared in the data URL header.
                return StreamingResponse(
                    io.BytesIO(image_data),
                    media_type="image/png",
                    headers={"Content-Disposition": "inline; filename=image.png"},
                )

    return FileResponse(f"{STATIC_DIR}/favicon.png")
############################
# ToggleModelById
############################
@router.post("/model/toggle", response_model=Optional[ModelResponse])
async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
    """Toggle a model via ``Models.toggle_model_by_id`` and return the result.

    The requester must be an admin, own the model, or have "write" access
    according to ``has_access``.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.NOT_FOUND`` when no model
            has the given id; 401 with ``ERROR_MESSAGES.UNAUTHORIZED`` when
            the requester may not modify it; 400 when the toggle itself
            returns a falsy result.
    """
    model = Models.get_model_by_id(id)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    can_write = (
        user.role == "admin"
        or model.user_id == user.id
        or has_access(user.id, "write", model.access_control)
    )
    if not can_write:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    toggled = Models.toggle_model_by_id(id)
    if not toggled:
        # The message previously said "function"; this endpoint updates a model.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("Error updating model"),
        )
    return toggled
############################
# UpdateModelById
############################
@router.post("/model/update", response_model=Optional[ModelModel])
async def update_model_by_id(
    id: str,
    form_data: ModelForm,
    user=Depends(get_verified_user),
):
    """Apply ``form_data`` to the model with the given id.

    The requester must own the model, have "write" access according to
    ``has_access``, or be an admin.

    Returns:
        Whatever ``Models.update_model_by_id`` returns.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.NOT_FOUND`` when no model
            has the given id; 400 with ``ERROR_MESSAGES.ACCESS_PROHIBITED``
            when the requester may not modify it.
    """
    existing = Models.get_model_by_id(id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    permitted = (
        existing.user_id == user.id
        or has_access(user.id, "write", existing.access_control)
        or user.role == "admin"
    )
    if not permitted:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    return Models.update_model_by_id(id, form_data)
############################
# DeleteModelById
############################
@router.delete("/model/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
    """Delete the model with the given id.

    The requester must be an admin, own the model, or have "write" access
    according to ``has_access``.

    Returns:
        Whatever ``Models.delete_model_by_id`` returns.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.NOT_FOUND`` when no model
            has the given id; 401 with ``ERROR_MESSAGES.UNAUTHORIZED`` when
            the requester may not delete it.
    """
    target = Models.get_model_by_id(id)
    if not target:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    may_delete = (
        user.role == "admin"
        or target.user_id == user.id
        or has_access(user.id, "write", target.access_control)
    )
    if not may_delete:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    return Models.delete_model_by_id(id)
@router.delete("/delete/all", response_model=bool)
async def delete_all_models(user=Depends(get_admin_user)):
    """Delete every model and return the result of ``Models.delete_all_models``.

    Access is gated by the ``get_admin_user`` dependency.
    """
    return Models.delete_all_models()