fix: image generation

This commit is contained in:
Timothy Jaeryang Baek 2025-08-22 16:58:25 +04:00
parent fbff4e19de
commit 72b25ab78b
1 changed files with 25 additions and 5 deletions

View File

@ -10,7 +10,15 @@ from typing import Optional
from urllib.parse import quote from urllib.parse import quote
import requests import requests
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
UploadFile,
BackgroundTasks,
)
from open_webui.config import CACHE_DIR from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
@ -460,7 +468,7 @@ def load_url_image_data(url, headers=None):
return None return None
def upload_image(request, image_data, content_type, metadata, user): def upload_image(request, background_tasks, image_data, content_type, metadata, user):
image_format = mimetypes.guess_extension(content_type) image_format = mimetypes.guess_extension(content_type)
file = UploadFile( file = UploadFile(
file=io.BytesIO(image_data), file=io.BytesIO(image_data),
@ -470,7 +478,12 @@ def upload_image(request, image_data, content_type, metadata, user):
}, },
) )
file_item = upload_file( file_item = upload_file(
request, file=file, metadata=metadata, process=False, user=user request,
background_tasks,
file=file,
metadata=metadata,
process=False,
user=user,
) )
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url return url
@ -479,6 +492,7 @@ def upload_image(request, image_data, content_type, metadata, user):
@router.post("/generations") @router.post("/generations")
async def image_generations( async def image_generations(
request: Request, request: Request,
background_tasks: BackgroundTasks,
form_data: GenerateImageForm, form_data: GenerateImageForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
@ -552,7 +566,9 @@ async def image_generations(
else: else:
image_data, content_type = load_b64_image_data(image["b64_json"]) image_data, content_type = load_b64_image_data(image["b64_json"])
url = upload_image(request, image_data, content_type, data, user) url = upload_image(
request, background_tasks, image_data, content_type, data, user
)
images.append({"url": url}) images.append({"url": url})
return images return images
@ -586,7 +602,9 @@ async def image_generations(
image_data, content_type = load_b64_image_data( image_data, content_type = load_b64_image_data(
image["bytesBase64Encoded"] image["bytesBase64Encoded"]
) )
url = upload_image(request, image_data, content_type, data, user) url = upload_image(
request, background_tasks, image_data, content_type, data, user
)
images.append({"url": url}) images.append({"url": url})
return images return images
@ -637,6 +655,7 @@ async def image_generations(
image_data, content_type = load_url_image_data(image["url"], headers) image_data, content_type = load_url_image_data(image["url"], headers)
url = upload_image( url = upload_image(
request, request,
background_tasks,
image_data, image_data,
content_type, content_type,
form_data.model_dump(exclude_none=True), form_data.model_dump(exclude_none=True),
@ -690,6 +709,7 @@ async def image_generations(
image_data, content_type = load_b64_image_data(image) image_data, content_type = load_b64_image_data(image)
url = upload_image( url = upload_image(
request, request,
background_tasks,
image_data, image_data,
content_type, content_type,
{**data, "info": res["info"]}, {**data, "info": res["info"]},