| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  | from fastapi import ( | 
					
						
							|  |  |  |     FastAPI, | 
					
						
							|  |  |  |     Request, | 
					
						
							|  |  |  |     Response, | 
					
						
							|  |  |  |     HTTPException, | 
					
						
							|  |  |  |     Depends, | 
					
						
							|  |  |  |     status, | 
					
						
							|  |  |  |     UploadFile, | 
					
						
							|  |  |  |     File, | 
					
						
							|  |  |  |     BackgroundTasks, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-01-05 05:06:31 +08:00
										 |  |  | from fastapi.middleware.cors import CORSMiddleware | 
					
						
							|  |  |  | from fastapi.responses import StreamingResponse | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  | from fastapi.concurrency import run_in_threadpool | 
					
						
							| 
									
										
										
										
											2023-11-15 08:28:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 17:41:22 +08:00
										 |  |  | from pydantic import BaseModel, ConfigDict | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2024-03-22 15:55:59 +08:00
										 |  |  | import copy | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | import random | 
					
						
							| 
									
										
										
										
											2023-11-15 08:28:51 +08:00
										 |  |  | import requests | 
					
						
							|  |  |  | import json | 
					
						
							| 
									
										
										
										
											2024-01-18 11:19:44 +08:00
										 |  |  | import uuid | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | import aiohttp | 
					
						
							|  |  |  | import asyncio | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  | from urllib.parse import urlparse | 
					
						
							|  |  |  | from typing import Optional, List, Union | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-15 08:28:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-19 08:47:12 +08:00
										 |  |  | from apps.web.models.users import Users | 
					
						
							|  |  |  | from constants import ERROR_MESSAGES | 
					
						
							| 
									
										
										
										
											2024-02-09 08:05:01 +08:00
										 |  |  | from utils.utils import decode_token, get_current_user, get_admin_user | 
					
						
							| 
									
										
										
										
											2023-11-15 08:28:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-31 16:13:39 +08:00
										 |  |  | from config import ( | 
					
						
							|  |  |  |     SRC_LOG_LEVELS, | 
					
						
							|  |  |  |     OLLAMA_BASE_URLS, | 
					
						
							|  |  |  |     MODEL_FILTER_ENABLED, | 
					
						
							|  |  |  |     MODEL_FILTER_LIST, | 
					
						
							|  |  |  |     UPLOAD_DIR, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-03-25 07:04:03 +08:00
										 |  |  | from utils.misc import calculate_sha256 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
# Component logger; level comes from the per-subsystem SRC_LOG_LEVELS config.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])


app = FastAPI()
# NOTE(review): CORS is wide open (all origins/methods/headers, with
# credentials); presumably this sub-app is mounted behind the main app's
# auth layer — confirm before exposing directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Model-visibility filtering: when enabled, "user"-role accounts only see
# models whose names appear in MODEL_FILTER_LIST (see get_ollama_tags).
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


# Ordered list of Ollama backend base URLs; the position in this list is the
# `url_idx` path parameter used by the routes below.
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
# Cache of known models keyed by model name; populated lazily by
# get_all_models() via the check_url middleware.
app.state.MODELS = {}


# IDs of in-flight streaming requests. Removing an ID (see
# cancel_ollama_request) signals the corresponding stream to stop.
REQUEST_POOL = []


# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | @app.middleware("http") | 
					
						
							|  |  |  | async def check_url(request: Request, call_next): | 
					
						
							|  |  |  |     if len(app.state.MODELS) == 0: | 
					
						
							|  |  |  |         await get_all_models() | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = await call_next(request) | 
					
						
							|  |  |  |     return response | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-02 17:03:13 +08:00
										 |  |  | @app.head("/") | 
					
						
							|  |  |  | @app.get("/") | 
					
						
							|  |  |  | async def get_status(): | 
					
						
							|  |  |  |     return {"status": True} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | @app.get("/urls") | 
					
						
							|  |  |  | async def get_ollama_api_urls(user=Depends(get_admin_user)): | 
					
						
							|  |  |  |     return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} | 
					
						
							| 
									
										
										
										
											2024-01-05 05:06:31 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-15 08:28:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-05 05:06:31 +08:00
class UrlUpdateForm(BaseModel):
    """Request body for /urls/update."""

    # Full replacement list for app.state.OLLAMA_BASE_URLS; order matters,
    # since routes address backends by list index (url_idx).
    urls: List[str]
					
						
							| 
									
										
										
										
											2024-01-05 05:06:31 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | @app.post("/urls/update") | 
					
						
							| 
									
										
										
										
											2024-02-17 15:30:38 +08:00
										 |  |  | async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     app.state.OLLAMA_BASE_URLS = form_data.urls | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} | 
					
						
							| 
									
										
										
										
											2024-01-05 17:25:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-18 11:19:44 +08:00
										 |  |  | @app.get("/cancel/{request_id}") | 
					
						
							|  |  |  | async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)): | 
					
						
							|  |  |  |     if user: | 
					
						
							|  |  |  |         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |             REQUEST_POOL.remove(request_id) | 
					
						
							|  |  |  |         return True | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | async def fetch_url(url): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         async with aiohttp.ClientSession() as session: | 
					
						
							|  |  |  |             async with session.get(url) as response: | 
					
						
							|  |  |  |                 return await response.json() | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         # Handle connection error here | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.error(f"Connection error: {e}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def merge_models_lists(model_lists): | 
					
						
							|  |  |  |     merged_models = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for idx, model_list in enumerate(model_lists): | 
					
						
							| 
									
										
										
										
											2024-03-18 16:11:48 +08:00
										 |  |  |         if model_list is not None: | 
					
						
							|  |  |  |             for model in model_list: | 
					
						
							|  |  |  |                 digest = model["digest"] | 
					
						
							|  |  |  |                 if digest not in merged_models: | 
					
						
							|  |  |  |                     model["urls"] = [idx] | 
					
						
							|  |  |  |                     merged_models[digest] = model | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     merged_models[digest]["urls"].append(idx) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     return list(merged_models.values()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # user=Depends(get_current_user) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | async def get_all_models(): | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info("get_all_models()") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] | 
					
						
							|  |  |  |     responses = await asyncio.gather(*tasks) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     models = { | 
					
						
							|  |  |  |         "models": merge_models_lists( | 
					
						
							| 
									
										
										
										
											2024-03-18 16:11:48 +08:00
										 |  |  |             map(lambda response: response["models"] if response else None, responses) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2024-03-12 15:26:14 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     app.state.MODELS = {model["model"]: model for model in models["models"]} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return models | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.get("/api/tags") | 
					
						
							|  |  |  | @app.get("/api/tags/{url_idx}") | 
					
						
							|  |  |  | async def get_ollama_tags( | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, user=Depends(get_current_user) | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							| 
									
										
										
										
											2024-03-10 13:19:20 +08:00
										 |  |  |         models = await get_all_models() | 
					
						
							| 
									
										
										
										
											2024-03-10 13:29:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-10 13:19:20 +08:00
										 |  |  |         if app.state.MODEL_FILTER_ENABLED: | 
					
						
							|  |  |  |             if user.role == "user": | 
					
						
							| 
									
										
										
										
											2024-03-10 13:29:04 +08:00
										 |  |  |                 models["models"] = list( | 
					
						
							|  |  |  |                     filter( | 
					
						
							| 
									
										
										
										
											2024-03-10 13:47:01 +08:00
										 |  |  |                         lambda model: model["name"] in app.state.MODEL_FILTER_LIST, | 
					
						
							| 
									
										
										
										
											2024-03-10 13:29:04 +08:00
										 |  |  |                         models["models"], | 
					
						
							|  |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2024-03-10 13:19:20 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  |                 return models | 
					
						
							|  |  |  |         return models | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     else: | 
					
						
							|  |  |  |         url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             r = requests.request(method="GET", url=f"{url}/api/tags") | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return r.json() | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |             log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |             if r is not None: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     res = r.json() | 
					
						
							|  |  |  |                     if "error" in res: | 
					
						
							|  |  |  |                         error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |                 except: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |                 detail=error_detail, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.get("/api/version") | 
					
						
							|  |  |  | @app.get("/api/version/{url_idx}") | 
					
						
							|  |  |  | async def get_ollama_versions(url_idx: Optional[int] = None): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # returns lowest version | 
					
						
							|  |  |  |         tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] | 
					
						
							|  |  |  |         responses = await asyncio.gather(*tasks) | 
					
						
							|  |  |  |         responses = list(filter(lambda x: x is not None, responses)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-12 15:26:14 +08:00
										 |  |  |         if len(responses) > 0: | 
					
						
							|  |  |  |             lowest_version = min( | 
					
						
							| 
									
										
										
										
											2024-04-07 09:55:51 +08:00
										 |  |  |                 responses, | 
					
						
							|  |  |  |                 key=lambda x: tuple(map(int, x["version"].split("-")[0].split("."))), | 
					
						
							| 
									
										
										
										
											2024-03-12 15:26:14 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-12 15:26:14 +08:00
										 |  |  |             return {"version": lowest_version["version"]} | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=500, | 
					
						
							|  |  |  |                 detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     else: | 
					
						
							|  |  |  |         url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             r = requests.request(method="GET", url=f"{url}/api/version") | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return r.json() | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |             log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |             if r is not None: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     res = r.json() | 
					
						
							|  |  |  |                     if "error" in res: | 
					
						
							|  |  |  |                         error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |                 except: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |                 detail=error_detail, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class ModelNameForm(BaseModel):
    """Request body carrying a single model name (e.g. for /api/pull)."""

    # Model identifier as understood by the Ollama backend.
    name: str
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.post("/api/pull") | 
					
						
							|  |  |  | @app.post("/api/pull/{url_idx}") | 
					
						
							|  |  |  | async def pull_model( | 
					
						
							|  |  |  |     form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 18:12:55 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     r = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 18:12:55 +08:00
										 |  |  |     def get_request(): | 
					
						
							|  |  |  |         nonlocal url | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         nonlocal r | 
					
						
							| 
									
										
										
										
											2024-03-24 04:12:54 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         request_id = str(uuid.uuid4()) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-03-24 04:12:54 +08:00
										 |  |  |             REQUEST_POOL.append(request_id) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             def stream_content(): | 
					
						
							| 
									
										
										
										
											2024-03-24 04:12:54 +08:00
										 |  |  |                 try: | 
					
						
							|  |  |  |                     yield json.dumps({"id": request_id, "done": False}) + "\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     for chunk in r.iter_content(chunk_size=8192): | 
					
						
							|  |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             yield chunk | 
					
						
							|  |  |  |                         else: | 
					
						
							| 
									
										
										
										
											2024-04-01 03:17:29 +08:00
										 |  |  |                             log.warning("User: canceled request") | 
					
						
							| 
									
										
										
										
											2024-03-24 04:12:54 +08:00
										 |  |  |                             break | 
					
						
							|  |  |  |                 finally: | 
					
						
							|  |  |  |                     if hasattr(r, "close"): | 
					
						
							|  |  |  |                         r.close() | 
					
						
							|  |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             REQUEST_POOL.remove(request_id) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             r = requests.request( | 
					
						
							|  |  |  |                 method="POST", | 
					
						
							|  |  |  |                 url=f"{url}/api/pull", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |                 data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |                 stream=True, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return StreamingResponse( | 
					
						
							|  |  |  |                 stream_content(), | 
					
						
							|  |  |  |                 status_code=r.status_code, | 
					
						
							|  |  |  |                 headers=dict(r.headers), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-03-05 18:12:55 +08:00
										 |  |  |         return await run_in_threadpool(get_request) | 
					
						
							| 
									
										
										
										
											2024-03-24 04:12:54 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class PushModelForm(BaseModel):
    """Request body for /api/push, forwarded to the Ollama push API."""

    name: str
    # Optional flags; omitted from the forwarded payload when None
    # (callers serialize with exclude_none=True).
    insecure: Optional[bool] = None
    stream: Optional[bool] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.delete("/api/push") | 
					
						
							|  |  |  | @app.delete("/api/push/{url_idx}") | 
					
						
							|  |  |  | async def push_model( | 
					
						
							|  |  |  |     form_data: PushModelForm, | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, | 
					
						
							|  |  |  |     user=Depends(get_admin_user), | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							|  |  |  |         if form_data.name in app.state.MODELS: | 
					
						
							|  |  |  |             url_idx = app.state.MODELS[form_data.name]["urls"][0] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=400, | 
					
						
							| 
									
										
										
										
											2024-03-05 17:07:59 +08:00
										 |  |  |                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.debug(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     r = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_request(): | 
					
						
							|  |  |  |         nonlocal url | 
					
						
							|  |  |  |         nonlocal r | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def stream_content(): | 
					
						
							|  |  |  |                 for chunk in r.iter_content(chunk_size=8192): | 
					
						
							|  |  |  |                     yield chunk | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r = requests.request( | 
					
						
							|  |  |  |                 method="POST", | 
					
						
							|  |  |  |                 url=f"{url}/api/push", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |                 data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return StreamingResponse( | 
					
						
							|  |  |  |                 stream_content(), | 
					
						
							|  |  |  |                 status_code=r.status_code, | 
					
						
							|  |  |  |                 headers=dict(r.headers), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return await run_in_threadpool(get_request) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class CreateModelForm(BaseModel):
    """Request body for the /api/create proxy endpoint (Ollama model creation)."""

    # Name (tag) of the model to create.
    name: str
    # Raw Modelfile contents, forwarded verbatim to the backend.
    modelfile: Optional[str] = None
    # Whether the backend should stream progress responses.
    stream: Optional[bool] = None
    # Server-side path to a Modelfile — presumably a legacy alternative
    # to `modelfile`; verify against the Ollama API before relying on it.
    path: Optional[str] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.post("/api/create") | 
					
						
							|  |  |  | @app.post("/api/create/{url_idx}") | 
					
						
							|  |  |  | async def create_model( | 
					
						
							|  |  |  |     form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.debug(f"form_data: {form_data}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 18:19:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     r = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_request(): | 
					
						
							|  |  |  |         nonlocal url | 
					
						
							|  |  |  |         nonlocal r | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def stream_content(): | 
					
						
							|  |  |  |                 for chunk in r.iter_content(chunk_size=8192): | 
					
						
							|  |  |  |                     yield chunk | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r = requests.request( | 
					
						
							|  |  |  |                 method="POST", | 
					
						
							|  |  |  |                 url=f"{url}/api/create", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |                 data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |                 stream=True, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |             log.debug(f"r: {r}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return StreamingResponse( | 
					
						
							|  |  |  |                 stream_content(), | 
					
						
							|  |  |  |                 status_code=r.status_code, | 
					
						
							|  |  |  |                 headers=dict(r.headers), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return await run_in_threadpool(get_request) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class CopyModelForm(BaseModel):
    """Request body for the /api/copy proxy endpoint (duplicate a model)."""

    # Existing model name (copy source).
    source: str
    # New model name to create (copy target).
    destination: str
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.post("/api/copy") | 
					
						
							|  |  |  | @app.post("/api/copy/{url_idx}") | 
					
						
							|  |  |  | async def copy_model( | 
					
						
							|  |  |  |     form_data: CopyModelForm, | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, | 
					
						
							|  |  |  |     user=Depends(get_admin_user), | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							|  |  |  |         if form_data.source in app.state.MODELS: | 
					
						
							|  |  |  |             url_idx = app.state.MODELS[form_data.source]["urls"][0] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=400, | 
					
						
							| 
									
										
										
										
											2024-03-05 17:07:59 +08:00
										 |  |  |                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         r = requests.request( | 
					
						
							|  |  |  |             method="POST", | 
					
						
							|  |  |  |             url=f"{url}/api/copy", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |             data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.debug(f"r.text: {r.text}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return True | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.delete("/api/delete") | 
					
						
							|  |  |  | @app.delete("/api/delete/{url_idx}") | 
					
						
							|  |  |  | async def delete_model( | 
					
						
							|  |  |  |     form_data: ModelNameForm, | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, | 
					
						
							|  |  |  |     user=Depends(get_admin_user), | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							|  |  |  |         if form_data.name in app.state.MODELS: | 
					
						
							|  |  |  |             url_idx = app.state.MODELS[form_data.name]["urls"][0] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=400, | 
					
						
							| 
									
										
										
										
											2024-03-05 17:07:59 +08:00
										 |  |  |                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         r = requests.request( | 
					
						
							|  |  |  |             method="DELETE", | 
					
						
							|  |  |  |             url=f"{url}/api/delete", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |             data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.debug(f"r.text: {r.text}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return True | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.post("/api/show") | 
					
						
							|  |  |  | async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)): | 
					
						
							|  |  |  |     if form_data.name not in app.state.MODELS: | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=400, | 
					
						
							| 
									
										
										
										
											2024-03-05 17:07:59 +08:00
										 |  |  |             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         r = requests.request( | 
					
						
							|  |  |  |             method="POST", | 
					
						
							|  |  |  |             url=f"{url}/api/show", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |             data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return r.json() | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class GenerateEmbeddingsForm(BaseModel):
    """Request body for the /api/embeddings proxy endpoint."""

    # Model to compute embeddings with.
    model: str
    # Text to embed.
    prompt: str
    # Extra model options forwarded verbatim to the backend.
    options: Optional[dict] = None
    # Keep-alive duration — presumably seconds (int) or a duration
    # string, per the Ollama API; confirm against backend docs.
    keep_alive: Optional[Union[int, str]] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.post("/api/embeddings") | 
					
						
							|  |  |  | @app.post("/api/embeddings/{url_idx}") | 
					
						
							|  |  |  | async def generate_embeddings( | 
					
						
							|  |  |  |     form_data: GenerateEmbeddingsForm, | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, | 
					
						
							|  |  |  |     user=Depends(get_current_user), | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							| 
									
										
										
										
											2024-04-15 04:51:13 +08:00
										 |  |  |         model = form_data.model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if ":" not in model: | 
					
						
							|  |  |  |             model = f"{model}:latest" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if model in app.state.MODELS: | 
					
						
							| 
									
										
										
										
											2024-04-15 04:52:59 +08:00
										 |  |  |             url_idx = random.choice(app.state.MODELS[model]["urls"]) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=400, | 
					
						
							| 
									
										
										
										
											2024-03-05 17:07:59 +08:00
										 |  |  |                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         r = requests.request( | 
					
						
							|  |  |  |             method="POST", | 
					
						
							|  |  |  |             url=f"{url}/api/embeddings", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |             data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |         r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return r.json() | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |         log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  | def generate_ollama_embeddings( | 
					
						
							|  |  |  |     form_data: GenerateEmbeddingsForm, | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2024-04-15 06:47:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  |     log.info(f"generate_ollama_embeddings {form_data}") | 
					
						
							| 
									
										
										
										
											2024-04-15 06:47:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |     if url_idx == None: | 
					
						
							|  |  |  |         model = form_data.model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if ":" not in model: | 
					
						
							|  |  |  |             model = f"{model}:latest" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if model in app.state.MODELS: | 
					
						
							|  |  |  |             url_idx = random.choice(app.state.MODELS[model]["urls"]) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=400, | 
					
						
							|  |  |  |                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							|  |  |  |     log.info(f"url: {url}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         r = requests.request( | 
					
						
							|  |  |  |             method="POST", | 
					
						
							|  |  |  |             url=f"{url}/api/embeddings", | 
					
						
							|  |  |  |             data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         data = r.json() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-15 07:48:15 +08:00
										 |  |  |         log.info(f"generate_ollama_embeddings {data}") | 
					
						
							| 
									
										
										
										
											2024-04-15 06:47:45 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-15 05:55:00 +08:00
										 |  |  |         if "embedding" in data: | 
					
						
							|  |  |  |             return data["embedding"] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise "Something went wrong :/" | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         log.exception(e) | 
					
						
							|  |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise error_detail | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
class GenerateCompletionForm(BaseModel):
    """Request body for the /api/generate proxy endpoint (Ollama completion)."""

    # Model to generate with.
    model: str
    # Prompt text to complete.
    prompt: str
    # Base64-encoded images for multimodal models — presumably; confirm
    # against the Ollama API.
    images: Optional[List[str]] = None
    # Response format hint (e.g. "json").
    format: Optional[str] = None
    # Extra model options forwarded verbatim to the backend.
    options: Optional[dict] = None
    # System prompt override.
    system: Optional[str] = None
    # Prompt template override.
    template: Optional[str] = None
    # Conversation context returned by a previous request.
    context: Optional[str] = None
    # Streaming is on by default for completions.
    stream: Optional[bool] = True
    # Bypass templating and send the prompt as-is.
    raw: Optional[bool] = None
    # Keep-alive duration — presumably seconds (int) or a duration string.
    keep_alive: Optional[Union[int, str]] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.post("/api/generate") | 
					
						
							|  |  |  | @app.post("/api/generate/{url_idx}") | 
					
						
							|  |  |  | async def generate_completion( | 
					
						
							|  |  |  |     form_data: GenerateCompletionForm, | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, | 
					
						
							|  |  |  |     user=Depends(get_current_user), | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							| 
									
										
										
										
											2024-04-15 04:51:13 +08:00
										 |  |  |         model = form_data.model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if ":" not in model: | 
					
						
							|  |  |  |             model = f"{model}:latest" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if model in app.state.MODELS: | 
					
						
							| 
									
										
										
										
											2024-04-15 04:52:59 +08:00
										 |  |  |             url_idx = random.choice(app.state.MODELS[model]["urls"]) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=400, | 
					
						
							| 
									
										
										
										
											2024-04-01 04:40:57 +08:00
										 |  |  |                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     r = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_request(): | 
					
						
							|  |  |  |         nonlocal form_data | 
					
						
							|  |  |  |         nonlocal r | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         request_id = str(uuid.uuid4()) | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             REQUEST_POOL.append(request_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def stream_content(): | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     if form_data.stream: | 
					
						
							|  |  |  |                         yield json.dumps({"id": request_id, "done": False}) + "\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     for chunk in r.iter_content(chunk_size=8192): | 
					
						
							|  |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             yield chunk | 
					
						
							|  |  |  |                         else: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |                             log.warning("User: canceled request") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |                             break | 
					
						
							|  |  |  |                 finally: | 
					
						
							|  |  |  |                     if hasattr(r, "close"): | 
					
						
							|  |  |  |                         r.close() | 
					
						
							|  |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             REQUEST_POOL.remove(request_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r = requests.request( | 
					
						
							|  |  |  |                 method="POST", | 
					
						
							|  |  |  |                 url=f"{url}/api/generate", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |                 data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |                 stream=True, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return StreamingResponse( | 
					
						
							|  |  |  |                 stream_content(), | 
					
						
							|  |  |  |                 status_code=r.status_code, | 
					
						
							|  |  |  |                 headers=dict(r.headers), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return await run_in_threadpool(get_request) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class ChatMessage(BaseModel):
    """A single message in an Ollama /api/chat request (role/content pair)."""

    role: str
    content: str
    # Optional image attachments — presumably base64-encoded strings per the
    # Ollama multimodal API; not validated here. TODO(review): confirm format.
    images: Optional[List[str]] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class GenerateChatCompletionForm(BaseModel):
    """Request body forwarded to Ollama's /api/chat endpoint.

    Fields mirror the upstream Ollama chat API; fields left as None are
    dropped before forwarding (``model_dump_json(exclude_none=True)``).
    """

    model: str
    messages: List[ChatMessage]
    format: Optional[str] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.post("/api/chat") | 
					
						
							|  |  |  | @app.post("/api/chat/{url_idx}") | 
					
						
							| 
									
										
										
										
											2024-03-05 17:41:22 +08:00
										 |  |  | async def generate_chat_completion( | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     form_data: GenerateChatCompletionForm, | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, | 
					
						
							|  |  |  |     user=Depends(get_current_user), | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							| 
									
										
										
										
											2024-04-15 04:51:13 +08:00
										 |  |  |         model = form_data.model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if ":" not in model: | 
					
						
							|  |  |  |             model = f"{model}:latest" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if model in app.state.MODELS: | 
					
						
							| 
									
										
										
										
											2024-04-15 04:52:59 +08:00
										 |  |  |             url_idx = random.choice(app.state.MODELS[model]["urls"]) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=400, | 
					
						
							| 
									
										
										
										
											2024-03-05 17:07:59 +08:00
										 |  |  |                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     r = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-31 16:13:39 +08:00
										 |  |  |     log.debug( | 
					
						
							|  |  |  |         "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( | 
					
						
							|  |  |  |             form_data.model_dump_json(exclude_none=True).encode() | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_request(): | 
					
						
							|  |  |  |         nonlocal form_data | 
					
						
							|  |  |  |         nonlocal r | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         request_id = str(uuid.uuid4()) | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             REQUEST_POOL.append(request_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def stream_content(): | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     if form_data.stream: | 
					
						
							|  |  |  |                         yield json.dumps({"id": request_id, "done": False}) + "\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     for chunk in r.iter_content(chunk_size=8192): | 
					
						
							|  |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             yield chunk | 
					
						
							|  |  |  |                         else: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |                             log.warning("User: canceled request") | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |                             break | 
					
						
							|  |  |  |                 finally: | 
					
						
							|  |  |  |                     if hasattr(r, "close"): | 
					
						
							|  |  |  |                         r.close() | 
					
						
							|  |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             REQUEST_POOL.remove(request_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r = requests.request( | 
					
						
							|  |  |  |                 method="POST", | 
					
						
							|  |  |  |                 url=f"{url}/api/chat", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |                 data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |                 stream=True, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return StreamingResponse( | 
					
						
							|  |  |  |                 stream_content(), | 
					
						
							|  |  |  |                 status_code=r.status_code, | 
					
						
							|  |  |  |                 headers=dict(r.headers), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |             log.exception(e) | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |             raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return await run_in_threadpool(get_request) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							| 
									
										
										
										
											2024-03-05 17:41:22 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # TODO: we should update this part once Ollama supports other types | 
					
						
class OpenAIChatMessage(BaseModel):
    """An OpenAI-style chat message; unknown extra fields are kept as-is."""

    role: str
    content: str

    # Pass through any additional OpenAI message fields unvalidated.
    model_config = ConfigDict(extra="allow")
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class OpenAIChatCompletionForm(BaseModel):
    """OpenAI-compatible chat completion request body (extra fields allowed)."""

    model: str
    messages: List[OpenAIChatMessage]

    # Forward any additional OpenAI parameters (temperature, stream, ...)
    # without validating them here.
    model_config = ConfigDict(extra="allow")
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.post("/v1/chat/completions") | 
					
						
							|  |  |  | @app.post("/v1/chat/completions/{url_idx}") | 
					
						
							|  |  |  | async def generate_openai_chat_completion( | 
					
						
							|  |  |  |     form_data: OpenAIChatCompletionForm, | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, | 
					
						
							|  |  |  |     user=Depends(get_current_user), | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							| 
									
										
										
										
											2024-04-15 04:51:13 +08:00
										 |  |  |         model = form_data.model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if ":" not in model: | 
					
						
							|  |  |  |             model = f"{model}:latest" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if model in app.state.MODELS: | 
					
						
							| 
									
										
										
										
											2024-04-15 04:52:59 +08:00
										 |  |  |             url_idx = random.choice(app.state.MODELS[model]["urls"]) | 
					
						
							| 
									
										
										
										
											2024-03-05 17:41:22 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |                 status_code=400, | 
					
						
							|  |  |  |                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |     log.info(f"url: {url}") | 
					
						
							| 
									
										
										
										
											2024-03-05 17:41:22 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     r = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_request(): | 
					
						
							|  |  |  |         nonlocal form_data | 
					
						
							|  |  |  |         nonlocal r | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         request_id = str(uuid.uuid4()) | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             REQUEST_POOL.append(request_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def stream_content(): | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     if form_data.stream: | 
					
						
							|  |  |  |                         yield json.dumps( | 
					
						
							|  |  |  |                             {"request_id": request_id, "done": False} | 
					
						
							|  |  |  |                         ) + "\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     for chunk in r.iter_content(chunk_size=8192): | 
					
						
							|  |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             yield chunk | 
					
						
							|  |  |  |                         else: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |                             log.warning("User: canceled request") | 
					
						
							| 
									
										
										
										
											2024-03-05 17:41:22 +08:00
										 |  |  |                             break | 
					
						
							|  |  |  |                 finally: | 
					
						
							|  |  |  |                     if hasattr(r, "close"): | 
					
						
							|  |  |  |                         r.close() | 
					
						
							|  |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             REQUEST_POOL.remove(request_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r = requests.request( | 
					
						
							|  |  |  |                 method="POST", | 
					
						
							|  |  |  |                 url=f"{url}/v1/chat/completions", | 
					
						
							| 
									
										
										
										
											2024-03-07 10:37:40 +08:00
										 |  |  |                 data=form_data.model_dump_json(exclude_none=True).encode(), | 
					
						
							| 
									
										
										
										
											2024-03-05 17:41:22 +08:00
										 |  |  |                 stream=True, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return StreamingResponse( | 
					
						
							|  |  |  |                 stream_content(), | 
					
						
							|  |  |  |                 status_code=r.status_code, | 
					
						
							|  |  |  |                 headers=dict(r.headers), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise e | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         return await run_in_threadpool(get_request) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							|  |  |  |         if r is not None: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 res = r.json() | 
					
						
							|  |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=r.status_code if r else 500, | 
					
						
							|  |  |  |             detail=error_detail, | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
class UrlForm(BaseModel):
    """Request body carrying a single download URL."""

    url: str
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class UploadBlobForm(BaseModel):
    """Request body naming the file for an uploaded blob."""

    filename: str
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def parse_huggingface_url(hf_url): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         # Parse the URL | 
					
						
							|  |  |  |         parsed_url = urlparse(hf_url) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get the path and split it into components | 
					
						
							|  |  |  |         path_components = parsed_url.path.split("/") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Extract the desired output | 
					
						
							|  |  |  |         user_repo = "/".join(path_components[1:3]) | 
					
						
							|  |  |  |         model_file = path_components[-1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return model_file | 
					
						
							|  |  |  |     except ValueError: | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | async def download_file_stream( | 
					
						
							|  |  |  |     ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024 | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     done = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if os.path.exists(file_path): | 
					
						
							|  |  |  |         current_size = os.path.getsize(file_path) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         current_size = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async with aiohttp.ClientSession(timeout=timeout) as session: | 
					
						
							|  |  |  |         async with session.get(file_url, headers=headers) as response: | 
					
						
							|  |  |  |             total_size = int(response.headers.get("content-length", 0)) + current_size | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             with open(file_path, "ab+") as file: | 
					
						
							|  |  |  |                 async for data in response.content.iter_chunked(chunk_size): | 
					
						
							|  |  |  |                     current_size += len(data) | 
					
						
							|  |  |  |                     file.write(data) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     done = current_size == total_size | 
					
						
							|  |  |  |                     progress = round((current_size / total_size) * 100, 2) | 
					
						
							| 
									
										
										
										
											2024-03-22 15:10:55 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  |                     yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if done: | 
					
						
							|  |  |  |                     file.seek(0) | 
					
						
							|  |  |  |                     hashed = calculate_sha256(file) | 
					
						
							|  |  |  |                     file.seek(0) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     url = f"{ollama_url}/api/blobs/sha256:{hashed}" | 
					
						
							|  |  |  |                     response = requests.post(url, data=file) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if response.ok: | 
					
						
							|  |  |  |                         res = { | 
					
						
							|  |  |  |                             "done": done, | 
					
						
							|  |  |  |                             "blob": f"sha256:{hashed}", | 
					
						
							|  |  |  |                             "name": file_name, | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                         os.remove(file_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         yield f"data: {json.dumps(res)}\n\n" | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         raise "Ollama: Could not create blob, Please try again." | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 15:10:55 +08:00
										 |  |  | # def number_generator(): | 
					
						
							|  |  |  | #     for i in range(1, 101): | 
					
						
							|  |  |  | #         yield f"data: {i}\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  | # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" | 
					
						
							|  |  |  | @app.post("/models/download") | 
					
						
							|  |  |  | @app.post("/models/download/{url_idx}") | 
					
						
							|  |  |  | async def download_model( | 
					
						
							|  |  |  |     form_data: UrlForm, | 
					
						
							|  |  |  |     url_idx: Optional[int] = None, | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-04-02 05:01:05 +08:00
										 |  |  |     allowed_hosts = ["https://huggingface.co/", "https://github.com/"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not any(form_data.url.startswith(host) for host in allowed_hosts): | 
					
						
							|  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=400, | 
					
						
							|  |  |  |             detail="Invalid file_url. Only URLs from allowed hosts are permitted.", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  |     if url_idx == None: | 
					
						
							|  |  |  |         url_idx = 0 | 
					
						
							|  |  |  |     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     file_name = parse_huggingface_url(form_data.url) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if file_name: | 
					
						
							|  |  |  |         file_path = f"{UPLOAD_DIR}/{file_name}" | 
					
						
							| 
									
										
										
										
											2024-04-02 05:01:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  |         return StreamingResponse( | 
					
						
							| 
									
										
										
										
											2024-03-22 15:10:55 +08:00
										 |  |  |             download_file_stream(url, form_data.url, file_path, file_name), | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @app.post("/models/upload") | 
					
						
							|  |  |  | @app.post("/models/upload/{url_idx}") | 
					
						
							|  |  |  | def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): | 
					
						
							|  |  |  |     if url_idx == None: | 
					
						
							|  |  |  |         url_idx = 0 | 
					
						
							|  |  |  |     ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     file_path = f"{UPLOAD_DIR}/{file.filename}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Save file in chunks | 
					
						
							|  |  |  |     with open(file_path, "wb+") as f: | 
					
						
							|  |  |  |         for chunk in file.file: | 
					
						
							|  |  |  |             f.write(chunk) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def file_process_stream(): | 
					
						
							|  |  |  |         nonlocal ollama_url | 
					
						
							|  |  |  |         total_size = os.path.getsize(file_path) | 
					
						
							|  |  |  |         chunk_size = 1024 * 1024 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             with open(file_path, "rb") as f: | 
					
						
							|  |  |  |                 total = 0 | 
					
						
							|  |  |  |                 done = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 while not done: | 
					
						
							|  |  |  |                     chunk = f.read(chunk_size) | 
					
						
							|  |  |  |                     if not chunk: | 
					
						
							|  |  |  |                         done = True | 
					
						
							|  |  |  |                         continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     total += len(chunk) | 
					
						
							|  |  |  |                     progress = round((total / total_size) * 100, 2) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     res = { | 
					
						
							|  |  |  |                         "progress": progress, | 
					
						
							|  |  |  |                         "total": total_size, | 
					
						
							|  |  |  |                         "completed": total, | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                     yield f"data: {json.dumps(res)}\n\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if done: | 
					
						
							|  |  |  |                     f.seek(0) | 
					
						
							|  |  |  |                     hashed = calculate_sha256(f) | 
					
						
							|  |  |  |                     f.seek(0) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 15:58:38 +08:00
										 |  |  |                     url = f"{ollama_url}/api/blobs/sha256:{hashed}" | 
					
						
							| 
									
										
										
										
											2024-03-22 14:45:00 +08:00
										 |  |  |                     response = requests.post(url, data=f) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if response.ok: | 
					
						
							|  |  |  |                         res = { | 
					
						
							|  |  |  |                             "done": done, | 
					
						
							|  |  |  |                             "blob": f"sha256:{hashed}", | 
					
						
							|  |  |  |                             "name": file.filename, | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                         os.remove(file_path) | 
					
						
							|  |  |  |                         yield f"data: {json.dumps(res)}\n\n" | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         raise Exception( | 
					
						
							|  |  |  |                             "Ollama: Could not create blob, Please try again." | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             res = {"error": str(e)} | 
					
						
							|  |  |  |             yield f"data: {json.dumps(res)}\n\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return StreamingResponse(file_process_stream(), media_type="text/event-stream") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-22 15:55:59 +08:00
										 |  |  | # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): | 
					
						
							|  |  |  | #     if url_idx == None: | 
					
						
							|  |  |  | #         url_idx = 0 | 
					
						
							|  |  |  | #     url = app.state.OLLAMA_BASE_URLS[url_idx] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #     file_location = os.path.join(UPLOAD_DIR, file.filename) | 
					
						
							|  |  |  | #     total_size = file.size | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #     async def file_upload_generator(file): | 
					
						
							|  |  |  | #         print(file) | 
					
						
							|  |  |  | #         try: | 
					
						
							|  |  |  | #             async with aiofiles.open(file_location, "wb") as f: | 
					
						
							|  |  |  | #                 completed_size = 0 | 
					
						
							|  |  |  | #                 while True: | 
					
						
							|  |  |  | #                     chunk = await file.read(1024*1024) | 
					
						
							|  |  |  | #                     if not chunk: | 
					
						
							|  |  |  | #                         break | 
					
						
							|  |  |  | #                     await f.write(chunk) | 
					
						
							|  |  |  | #                     completed_size += len(chunk) | 
					
						
							|  |  |  | #                     progress = (completed_size / total_size) * 100 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #                     print(progress) | 
					
						
							|  |  |  | #                     yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n' | 
					
						
							|  |  |  | #         except Exception as e: | 
					
						
							|  |  |  | #             print(e) | 
					
						
							|  |  |  | #             yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n" | 
					
						
							|  |  |  | #         finally: | 
					
						
							|  |  |  | #             await file.close() | 
					
						
							|  |  |  | #             print("done") | 
					
						
							|  |  |  | #             yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #     return StreamingResponse( | 
					
						
							|  |  |  | #         file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream" | 
					
						
							|  |  |  | #     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-05 05:06:31 +08:00
										 |  |  | @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) | 
					
						
							| 
									
										
										
										
											2024-03-05 17:21:50 +08:00
										 |  |  | async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): | 
					
						
							| 
									
										
										
										
											2024-03-05 16:59:35 +08:00
										 |  |  |     url = app.state.OLLAMA_BASE_URLS[0] | 
					
						
							|  |  |  |     target_url = f"{url}/{path}" | 
					
						
							| 
									
										
										
										
											2024-01-05 05:06:31 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     body = await request.body() | 
					
						
							|  |  |  |     headers = dict(request.headers) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if user.role in ["user", "admin"]: | 
					
						
							|  |  |  |         if path in ["pull", "delete", "push", "copy", "create"]: | 
					
						
							|  |  |  |             if user.role != "admin": | 
					
						
							| 
									
										
										
										
											2024-01-05 17:25:34 +08:00
										 |  |  |                 raise HTTPException( | 
					
						
							| 
									
										
										
										
											2024-02-17 15:30:38 +08:00
										 |  |  |                     status_code=status.HTTP_401_UNAUTHORIZED, | 
					
						
							|  |  |  |                     detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | 
					
						
							| 
									
										
										
										
											2024-01-05 17:25:34 +08:00
										 |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-01-05 05:06:31 +08:00
										 |  |  |     else: | 
					
						
							| 
									
										
										
										
											2024-02-17 15:30:38 +08:00
										 |  |  |         raise HTTPException( | 
					
						
							|  |  |  |             status_code=status.HTTP_401_UNAUTHORIZED, | 
					
						
							|  |  |  |             detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-12-15 09:05:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
											  
											
												Fix bug: Header attributes (Host, Authorization, Origin, Referer) not sanitized
- Resolved an issue where header attributes Host, Authorization, Origin, and Referer were not being sanitized, resulting in two major issues:
  1. Ollama requests inadvertently exposed user information, leading to data leakage.
  2. When Ollama is deployed on different servers, and the intermediary proxy layer uses the host header to locate downstream services, it fails to find them.
Root Cause:
- In FastAPI, when accessing request.headers, all header names are converted to lowercase. This is because FastAPI, and its underlying framework Starlette, adhere to the HTTP/2 standard, which mandates lowercase header field names for performance and consistency.
- In HTTP/2, enforcing lowercase header field names reduces complexity in header processing as case sensitivity is no longer a concern. Thus, regardless of the case used in client-sent header fields, the server processes them uniformly in lowercase.
- This practice is adopted in FastAPI and other modern HTTP frameworks, even in an HTTP/1.1 context, to maintain consistency with HTTP/2 and improve overall performance. As a result, header field names are always presented in lowercase in FastAPI, even if the original request used capitalization or mixed case.
											
										 
											2024-01-11 14:36:34 +08:00
										 |  |  |     headers.pop("host", None) | 
					
						
							|  |  |  |     headers.pop("authorization", None) | 
					
						
							|  |  |  |     headers.pop("origin", None) | 
					
						
							|  |  |  |     headers.pop("referer", None) | 
					
						
							| 
									
										
										
										
											2023-12-27 05:40:03 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  |     r = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_request(): | 
					
						
							|  |  |  |         nonlocal r | 
					
						
							| 
									
										
										
										
											2024-01-18 11:19:44 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         request_id = str(uuid.uuid4()) | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-01-18 11:19:44 +08:00
										 |  |  |             REQUEST_POOL.append(request_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def stream_content(): | 
					
						
							|  |  |  |                 try: | 
					
						
							| 
									
										
										
										
											2024-03-02 19:01:44 +08:00
										 |  |  |                     if path == "generate": | 
					
						
							|  |  |  |                         data = json.loads(body.decode("utf-8")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         if not ("stream" in data and data["stream"] == False): | 
					
						
							|  |  |  |                             yield json.dumps({"id": request_id, "done": False}) + "\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     elif path == "chat": | 
					
						
							| 
									
										
										
										
											2024-01-18 11:19:44 +08:00
										 |  |  |                         yield json.dumps({"id": request_id, "done": False}) + "\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     for chunk in r.iter_content(chunk_size=8192): | 
					
						
							|  |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             yield chunk | 
					
						
							|  |  |  |                         else: | 
					
						
							| 
									
										
										
										
											2024-03-21 07:11:36 +08:00
										 |  |  |                             log.warning("User: canceled request") | 
					
						
							| 
									
										
										
										
											2024-01-18 11:19:44 +08:00
										 |  |  |                             break | 
					
						
							|  |  |  |                 finally: | 
					
						
							|  |  |  |                     if hasattr(r, "close"): | 
					
						
							|  |  |  |                         r.close() | 
					
						
							| 
									
										
										
										
											2024-03-02 19:01:44 +08:00
										 |  |  |                         if request_id in REQUEST_POOL: | 
					
						
							|  |  |  |                             REQUEST_POOL.remove(request_id) | 
					
						
							| 
									
										
										
										
											2024-01-18 11:19:44 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  |             r = requests.request( | 
					
						
							|  |  |  |                 method=request.method, | 
					
						
							|  |  |  |                 url=target_url, | 
					
						
							|  |  |  |                 data=body, | 
					
						
							|  |  |  |                 headers=headers, | 
					
						
							|  |  |  |                 stream=True, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             r.raise_for_status() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-18 11:19:44 +08:00
										 |  |  |             # r.close() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  |             return StreamingResponse( | 
					
						
							| 
									
										
										
										
											2024-01-18 11:19:44 +08:00
										 |  |  |                 stream_content(), | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  |                 status_code=r.status_code, | 
					
						
							|  |  |  |                 headers=dict(r.headers), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise e | 
					
						
							| 
									
										
										
										
											2023-12-14 09:37:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  |     try: | 
					
						
							|  |  |  |         return await run_in_threadpool(get_request) | 
					
						
							| 
									
										
										
										
											2023-12-14 09:37:29 +08:00
										 |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-02-17 15:30:38 +08:00
										 |  |  |         error_detail = "Open WebUI: Server Connection Error" | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  |         if r is not None: | 
					
						
							| 
									
										
										
										
											2024-01-05 05:06:31 +08:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  |                 res = r.json() | 
					
						
							| 
									
										
										
										
											2024-01-05 05:06:31 +08:00
										 |  |  |                 if "error" in res: | 
					
						
							|  |  |  |                     error_detail = f"Ollama: {res['error']}" | 
					
						
							|  |  |  |             except: | 
					
						
							|  |  |  |                 error_detail = f"Ollama: {e}" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-05 17:25:34 +08:00
										 |  |  |         raise HTTPException( | 
					
						
							| 
									
										
										
										
											2024-01-06 09:16:35 +08:00
										 |  |  |             status_code=r.status_code if r else 500, | 
					
						
							| 
									
										
										
										
											2024-01-05 17:25:34 +08:00
										 |  |  |             detail=error_detail, | 
					
						
							|  |  |  |         ) |