This commit is contained in:
Timothy Jaeryang Baek 2025-08-22 17:19:57 +04:00
parent 72b25ab78b
commit 37a3de0703
2 changed files with 37 additions and 24 deletions

View File

@ -144,6 +144,17 @@ def upload_file(
metadata: Optional[dict | str] = Form(None), metadata: Optional[dict | str] = Form(None),
process: bool = Query(True), process: bool = Query(True),
user=Depends(get_verified_user), user=Depends(get_verified_user),
):
return upload_file_handler(request, file, metadata, process, user, background_tasks)
def upload_file_handler(
request: Request,
file: UploadFile = File(...),
metadata: Optional[dict | str] = Form(None),
process: bool = Query(True),
user=Depends(get_verified_user),
background_tasks: Optional[BackgroundTasks] = None,
): ):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
@ -214,16 +225,27 @@ def upload_file(
) )
if process: if process:
background_tasks.add_task( if background_tasks:
process_uploaded_file, background_tasks.add_task(
request, process_uploaded_file,
file, request,
file_path, file,
file_item, file_path,
file_metadata, file_item,
user, file_metadata,
) user,
return {"status": True, **file_item.model_dump()} )
return {"status": True, **file_item.model_dump()}
else:
process_uploaded_file(
request,
file,
file_path,
file_item,
file_metadata,
user,
)
return {"status": True, **file_item.model_dump()}
else: else:
if file_item: if file_item:
return file_item return file_item

View File

@ -16,13 +16,12 @@ from fastapi import (
HTTPException, HTTPException,
Request, Request,
UploadFile, UploadFile,
BackgroundTasks,
) )
from open_webui.config import CACHE_DIR from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file from open_webui.routers.files import upload_file_handler
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import ( from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm, ComfyUIGenerateImageForm,
@ -468,7 +467,7 @@ def load_url_image_data(url, headers=None):
return None return None
def upload_image(request, background_tasks, image_data, content_type, metadata, user): def upload_image(request, image_data, content_type, metadata, user):
image_format = mimetypes.guess_extension(content_type) image_format = mimetypes.guess_extension(content_type)
file = UploadFile( file = UploadFile(
file=io.BytesIO(image_data), file=io.BytesIO(image_data),
@ -477,9 +476,8 @@ def upload_image(request, background_tasks, image_data, content_type, metadata,
"content-type": content_type, "content-type": content_type,
}, },
) )
file_item = upload_file( file_item = upload_file_handler(
request, request,
background_tasks,
file=file, file=file,
metadata=metadata, metadata=metadata,
process=False, process=False,
@ -492,7 +490,6 @@ def upload_image(request, background_tasks, image_data, content_type, metadata,
@router.post("/generations") @router.post("/generations")
async def image_generations( async def image_generations(
request: Request, request: Request,
background_tasks: BackgroundTasks,
form_data: GenerateImageForm, form_data: GenerateImageForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
@ -566,9 +563,7 @@ async def image_generations(
else: else:
image_data, content_type = load_b64_image_data(image["b64_json"]) image_data, content_type = load_b64_image_data(image["b64_json"])
url = upload_image( url = upload_image(request, image_data, content_type, data, user)
request, background_tasks, image_data, content_type, data, user
)
images.append({"url": url}) images.append({"url": url})
return images return images
@ -602,9 +597,7 @@ async def image_generations(
image_data, content_type = load_b64_image_data( image_data, content_type = load_b64_image_data(
image["bytesBase64Encoded"] image["bytesBase64Encoded"]
) )
url = upload_image( url = upload_image(request, image_data, content_type, data, user)
request, background_tasks, image_data, content_type, data, user
)
images.append({"url": url}) images.append({"url": url})
return images return images
@ -655,7 +648,6 @@ async def image_generations(
image_data, content_type = load_url_image_data(image["url"], headers) image_data, content_type = load_url_image_data(image["url"], headers)
url = upload_image( url = upload_image(
request, request,
background_tasks,
image_data, image_data,
content_type, content_type,
form_data.model_dump(exclude_none=True), form_data.model_dump(exclude_none=True),
@ -709,7 +701,6 @@ async def image_generations(
image_data, content_type = load_b64_image_data(image) image_data, content_type = load_b64_image_data(image)
url = upload_image( url = upload_image(
request, request,
background_tasks,
image_data, image_data,
content_type, content_type,
{**data, "info": res["info"]}, {**data, "info": res["info"]},