149 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			149 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
| import os
 | |
| import requests
 | |
| from fastapi import (
 | |
|     FastAPI,
 | |
|     Request,
 | |
|     Depends,
 | |
|     HTTPException,
 | |
|     status,
 | |
|     UploadFile,
 | |
|     File,
 | |
|     Form,
 | |
| )
 | |
| from fastapi.middleware.cors import CORSMiddleware
 | |
| from faster_whisper import WhisperModel
 | |
| 
 | |
| from constants import ERROR_MESSAGES
 | |
| from utils.utils import (
 | |
|     get_current_user,
 | |
|     get_admin_user,
 | |
| )
 | |
| from utils.misc import calculate_sha256
 | |
| from typing import Optional
 | |
| from pydantic import BaseModel
 | |
| from config import AUTOMATIC1111_BASE_URL
 | |
| 
 | |
app = FastAPI()

# Runtime-mutable settings. The admin endpoints in this module read and
# overwrite these values at runtime.
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.ENABLED = False  # flipped by GET /enabled/toggle

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
 | |
| 
 | |
| 
 | |
@app.get("/enabled", response_model=bool)
async def get_enable_status(request: Request, user=Depends(get_admin_user)):
    """Report the current value of ``app.state.ENABLED``.

    Access is restricted by the ``get_admin_user`` dependency.
    """
    is_enabled = app.state.ENABLED
    return is_enabled
 | |
| 
 | |
| 
 | |
@app.get("/enabled/toggle", response_model=bool)
async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
    """Invert ``app.state.ENABLED`` and return the value it now holds.

    Access is restricted by the ``get_admin_user`` dependency.
    """
    # NOTE(review): this mutates state on a GET, so link prefetchers or a
    # repeated request will flip it again; consider POST if clients allow it.
    flipped = not app.state.ENABLED
    app.state.ENABLED = flipped
    return flipped
 | |
| 
 | |
| 
 | |
class UrlUpdateForm(BaseModel):
    """Request body accepted by ``POST /url/update``."""

    # Candidate AUTOMATIC1111 base URL; it is probed with HEAD before being stored.
    url: str
 | |
| 
 | |
| 
 | |
@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
    """Return the AUTOMATIC1111 base URL currently stored in ``app.state``.

    Admin only. Despite the ``openai`` in the function name, the value
    returned is the AUTOMATIC1111 URL.
    """
    current_url = app.state.AUTOMATIC1111_BASE_URL
    return {"AUTOMATIC1111_BASE_URL": current_url}
 | |
| 
 | |
| 
 | |
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    """Probe a candidate AUTOMATIC1111 base URL and store it if it responds OK.

    A ``HEAD`` request is sent to ``form_data.url``. The URL is saved to
    ``app.state.AUTOMATIC1111_BASE_URL`` only when that request returns a
    success status, with trailing slashes removed. The response always
    echoes the currently stored URL, so a caller can compare it with what
    they sent to see whether the update took effect.

    Raises:
        HTTPException: if the probe itself raises (connection failure,
            malformed URL, ...). The upstream status is used when an error
            response was received, otherwise 500.
    """
    # Pre-bind so the except clause cannot hit UnboundLocalError when
    # requests.head raises before ``r`` is assigned.
    r = None
    try:
        # NOTE(review): blocking call inside an async handler with no timeout;
        # an unreachable host stalls the event loop until the OS gives up.
        r = requests.head(form_data.url)
        if r.ok:
            # Drop trailing slashes so later f"{base}/sdapi/..." joins cleanly.
            app.state.AUTOMATIC1111_BASE_URL = form_data.url.rstrip("/")
        return {
            "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
            "status": True,
        }
    except Exception as e:
        code = (
            r.status_code
            if r is not None and not r.ok
            else status.HTTP_500_INTERNAL_SERVER_ERROR
        )
        raise HTTPException(status_code=code, detail=ERROR_MESSAGES.DEFAULT(e)) from e
 | |
| 
 | |
| 
 | |
@app.get("/models")
def get_models(user=Depends(get_current_user)):
    """Return the checkpoint list from AUTOMATIC1111's ``/sdapi/v1/sd-models``.

    The upstream JSON body is returned unchanged.

    Raises:
        HTTPException: if the request fails or the body is not valid JSON.
            The upstream status is used when it was an error status,
            otherwise 500.
    """
    # Pre-bind so the except clause cannot hit UnboundLocalError when the
    # request itself raises (e.g. connection refused).
    r = None
    try:
        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
        models = r.json()
        return models
    except Exception as e:
        # Reuse the upstream status only when it signals an error; a 2xx with
        # an unparsable body must not be reported to the client as success.
        code = (
            r.status_code
            if r is not None and not r.ok
            else status.HTTP_500_INTERNAL_SERVER_ERROR
        )
        raise HTTPException(status_code=code, detail=ERROR_MESSAGES.DEFAULT(e)) from e
 | |
| 
 | |
| 
 | |
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return the checkpoint currently active on the AUTOMATIC1111 server.

    Reads ``sd_model_checkpoint`` from ``/sdapi/v1/options`` and returns it
    as ``{"model": ...}``. Admin only.

    Raises:
        HTTPException: on request failure, a non-JSON body, or a missing
            ``sd_model_checkpoint`` key. The upstream status is used when it
            was an error status, otherwise 500.
    """
    # Pre-bind so the except clause cannot hit UnboundLocalError when the
    # request itself raises before ``r`` is assigned.
    r = None
    try:
        # NOTE(review): blocking requests call inside an async handler.
        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
        options = r.json()

        return {"model": options["sd_model_checkpoint"]}
    except Exception as e:
        # Reuse the upstream status only when it signals an error; a 2xx with
        # a bad body must not be reported to the client as success.
        code = (
            r.status_code
            if r is not None and not r.ok
            else status.HTTP_500_INTERNAL_SERVER_ERROR
        )
        raise HTTPException(status_code=code, detail=ERROR_MESSAGES.DEFAULT(e)) from e
 | |
| 
 | |
| 
 | |
class UpdateModelForm(BaseModel):
    """Request body accepted by ``POST /models/default/update``."""

    # Checkpoint name to make active on the AUTOMATIC1111 server.
    model: str
 | |
| 
 | |
| 
 | |
def set_model_handler(model: str):
    """Make ``model`` the active checkpoint on the AUTOMATIC1111 server.

    Fetches the server's options from ``/sdapi/v1/options``. If their
    ``sd_model_checkpoint`` already equals ``model``, nothing is written.
    Otherwise the options are POSTed back with that key replaced.

    Returns:
        The options dict, carrying ``model`` as its checkpoint.
    """
    options_endpoint = f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
    options = requests.get(url=options_endpoint).json()

    if model == options["sd_model_checkpoint"]:
        return options

    options["sd_model_checkpoint"] = model
    # NOTE(review): the POST response is not inspected, so a failed switch is
    # still reported to callers as success.
    requests.post(url=options_endpoint, json=options)
    return options
 | |
| 
 | |
| 
 | |
@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_current_user),
):
    """Switch the server's active checkpoint to ``form_data.model``.

    Returns the options dict produced by ``set_model_handler``.
    """
    # NOTE(review): GET /models/default is admin-only, but any authenticated
    # user can change the default here; confirm that asymmetry is intended.
    requested_model = form_data.model
    return set_model_handler(requested_model)
 | |
| 
 | |
| 
 | |
class GenerateImageForm(BaseModel):
    """Request body accepted by ``POST /generations``."""

    # Checkpoint to switch to before generating; None keeps the current one.
    model: Optional[str] = None
    # Text prompt forwarded to txt2img.
    prompt: str
    # Number of images; forwarded as txt2img's ``batch_size``.
    n: int = 1
    # Output dimensions, written as "<width>x<height>".
    size: str = "512x512"
    # Optional negative prompt forwarded to txt2img.
    negative_prompt: Optional[str] = None
 | |
| 
 | |
| 
 | |
@app.post("/generations")
def generate_image(
    form_data: GenerateImageForm,
    user=Depends(get_current_user),
):
    """Generate images from a text prompt via AUTOMATIC1111's txt2img API.

    ``form_data.size`` must have the form ``"<width>x<height>"`` with
    positive integers. If ``form_data.model`` is set, the server's active
    checkpoint is switched to it (see ``set_model_handler``) before
    generating.

    Returns:
        The decoded JSON body of the txt2img response.

    Raises:
        HTTPException: 400 if ``size`` is malformed. If the txt2img call fails
            or returns non-JSON, the upstream status is used when it was an
            error status, otherwise 500.
    """
    # NOTE(review): app.state.ENABLED is not consulted here; confirm that
    # gating is enforced elsewhere.

    # Validate the size before touching the server, so a bad request does not
    # leave the active checkpoint switched as a side effect.
    try:
        width, height = (int(part) for part in form_data.size.split("x"))
    except ValueError:
        width = height = 0  # fall through to the 400 below
    if width <= 0 or height <= 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid size {form_data.size!r}; expected '<width>x<height>'",
        )

    if form_data.model:
        set_model_handler(form_data.model)

    # Pre-bind so the except clause cannot hit UnboundLocalError.
    r = None
    try:
        # txt2img is a POST-only route on the AUTOMATIC1111 API; a GET is
        # rejected with 405.
        r = requests.post(
            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
            json={
                "prompt": form_data.prompt,
                "negative_prompt": form_data.negative_prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            },
        )
        return r.json()
    except Exception as e:
        code = (
            r.status_code
            if r is not None and not r.ok
            else status.HTTP_500_INTERNAL_SERVER_ERROR
        )
        raise HTTPException(status_code=code, detail=ERROR_MESSAGES.DEFAULT(e)) from e
 |