2024-10-16 22:32:57 +08:00
import asyncio
2024-08-28 06:10:27 +08:00
import inspect
2024-02-23 16:30:26 +08:00
import json
2024-03-21 07:11:36 +08:00
import logging
2024-05-22 06:04:00 +08:00
import mimetypes
2024-08-28 06:10:27 +08:00
import os
2024-06-06 04:57:48 +08:00
import shutil
2024-08-28 06:10:27 +08:00
import sys
import time
2024-10-22 18:16:48 +08:00
import random
2024-12-12 10:36:59 +08:00
2024-08-28 06:10:27 +08:00
from contextlib import asynccontextmanager
2024-12-10 16:54:13 +08:00
from urllib . parse import urlencode , parse_qs , urlparse
from pydantic import BaseModel
from sqlalchemy import text
2024-02-23 16:30:26 +08:00
2024-12-10 16:54:13 +08:00
from typing import Optional
2024-11-16 20:41:07 +08:00
from aiocache import cached
2024-08-28 06:10:27 +08:00
import aiohttp
import requests
2024-10-16 22:58:03 +08:00
from fastapi import (
Depends ,
FastAPI ,
File ,
Form ,
HTTPException ,
Request ,
UploadFile ,
status ,
)
from fastapi . middleware . cors import CORSMiddleware
2024-10-22 18:16:48 +08:00
from fastapi . responses import JSONResponse , RedirectResponse
2024-10-16 22:58:03 +08:00
from fastapi . staticfiles import StaticFiles
2024-12-10 16:54:13 +08:00
2024-10-16 22:58:03 +08:00
from starlette . exceptions import HTTPException as StarletteHTTPException
from starlette . middleware . base import BaseHTTPMiddleware
from starlette . middleware . sessions import SessionMiddleware
from starlette . responses import Response , StreamingResponse
2024-09-04 22:54:48 +08:00
2024-12-10 16:54:13 +08:00
2024-12-12 10:46:29 +08:00
from open_webui . socket . main import (
app as socket_app ,
periodic_usage_pool_cleanup ,
get_event_call ,
get_event_emitter ,
)
2024-12-10 16:54:13 +08:00
from open_webui . routers import (
audio ,
images ,
ollama ,
openai ,
retrieval ,
pipelines ,
tasks ,
2024-12-11 18:41:25 +08:00
auths ,
chats ,
folders ,
configs ,
groups ,
files ,
functions ,
memories ,
models ,
knowledge ,
prompts ,
evaluations ,
tools ,
users ,
utils ,
2024-09-04 22:54:48 +08:00
)
2024-12-12 12:39:55 +08:00
2024-12-12 10:08:55 +08:00
from open_webui . routers . retrieval import (
get_embedding_function ,
2024-12-12 10:46:29 +08:00
get_ef ,
get_rf ,
2024-12-12 10:08:55 +08:00
)
2024-12-12 11:52:46 +08:00
from open_webui . routers . pipelines import (
process_pipeline_inlet_filter ,
)
2024-12-12 10:46:29 +08:00
from open_webui . retrieval . utils import get_sources_from_files
2024-11-25 10:49:56 +08:00
2024-12-10 16:54:13 +08:00
from open_webui . internal . db import Session
from open_webui . models . functions import Functions
from open_webui . models . models import Models
from open_webui . models . users import UserModel , Users
from open_webui . constants import TASKS
2024-09-04 22:54:48 +08:00
from open_webui . config import (
2024-12-10 16:54:13 +08:00
# Ollama
2024-08-28 06:10:27 +08:00
ENABLE_OLLAMA_API ,
2024-12-10 16:54:13 +08:00
OLLAMA_BASE_URLS ,
OLLAMA_API_CONFIGS ,
# OpenAI
2024-08-28 06:10:27 +08:00
ENABLE_OPENAI_API ,
2024-12-10 16:54:13 +08:00
OPENAI_API_BASE_URLS ,
OPENAI_API_KEYS ,
OPENAI_API_CONFIGS ,
# Image
AUTOMATIC1111_API_AUTH ,
AUTOMATIC1111_BASE_URL ,
AUTOMATIC1111_CFG_SCALE ,
AUTOMATIC1111_SAMPLER ,
AUTOMATIC1111_SCHEDULER ,
COMFYUI_BASE_URL ,
COMFYUI_WORKFLOW ,
COMFYUI_WORKFLOW_NODES ,
ENABLE_IMAGE_GENERATION ,
IMAGE_GENERATION_ENGINE ,
IMAGE_GENERATION_MODEL ,
IMAGE_SIZE ,
IMAGE_STEPS ,
IMAGES_OPENAI_API_BASE_URL ,
IMAGES_OPENAI_API_KEY ,
# Audio
AUDIO_STT_ENGINE ,
AUDIO_STT_MODEL ,
AUDIO_STT_OPENAI_API_BASE_URL ,
AUDIO_STT_OPENAI_API_KEY ,
AUDIO_TTS_API_KEY ,
AUDIO_TTS_ENGINE ,
AUDIO_TTS_MODEL ,
AUDIO_TTS_OPENAI_API_BASE_URL ,
AUDIO_TTS_OPENAI_API_KEY ,
AUDIO_TTS_SPLIT_ON ,
AUDIO_TTS_VOICE ,
AUDIO_TTS_AZURE_SPEECH_REGION ,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT ,
WHISPER_MODEL ,
WHISPER_MODEL_AUTO_UPDATE ,
WHISPER_MODEL_DIR ,
2024-12-11 18:41:25 +08:00
# Retrieval
RAG_TEMPLATE ,
DEFAULT_RAG_TEMPLATE ,
RAG_EMBEDDING_MODEL ,
RAG_EMBEDDING_MODEL_AUTO_UPDATE ,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE ,
RAG_RERANKING_MODEL ,
RAG_RERANKING_MODEL_AUTO_UPDATE ,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE ,
RAG_EMBEDDING_ENGINE ,
RAG_EMBEDDING_BATCH_SIZE ,
RAG_RELEVANCE_THRESHOLD ,
RAG_FILE_MAX_COUNT ,
RAG_FILE_MAX_SIZE ,
RAG_OPENAI_API_BASE_URL ,
RAG_OPENAI_API_KEY ,
RAG_OLLAMA_BASE_URL ,
RAG_OLLAMA_API_KEY ,
CHUNK_OVERLAP ,
CHUNK_SIZE ,
CONTENT_EXTRACTION_ENGINE ,
TIKA_SERVER_URL ,
RAG_TOP_K ,
RAG_TEXT_SPLITTER ,
TIKTOKEN_ENCODING_NAME ,
PDF_EXTRACT_IMAGES ,
YOUTUBE_LOADER_LANGUAGE ,
YOUTUBE_LOADER_PROXY_URL ,
# Retrieval (Web Search)
RAG_WEB_SEARCH_ENGINE ,
RAG_WEB_SEARCH_RESULT_COUNT ,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS ,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST ,
JINA_API_KEY ,
SEARCHAPI_API_KEY ,
SEARCHAPI_ENGINE ,
SEARXNG_QUERY_URL ,
SERPER_API_KEY ,
SERPLY_API_KEY ,
SERPSTACK_API_KEY ,
SERPSTACK_HTTPS ,
TAVILY_API_KEY ,
BING_SEARCH_V7_ENDPOINT ,
BING_SEARCH_V7_SUBSCRIPTION_KEY ,
BRAVE_SEARCH_API_KEY ,
KAGI_SEARCH_API_KEY ,
MOJEEK_SEARCH_API_KEY ,
GOOGLE_PSE_API_KEY ,
GOOGLE_PSE_ENGINE_ID ,
ENABLE_RAG_HYBRID_SEARCH ,
ENABLE_RAG_LOCAL_WEB_FETCH ,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ,
ENABLE_RAG_WEB_SEARCH ,
UPLOAD_DIR ,
2024-12-10 16:54:13 +08:00
# WebUI
WEBUI_AUTH ,
WEBUI_NAME ,
WEBUI_BANNERS ,
WEBHOOK_URL ,
ADMIN_EMAIL ,
SHOW_ADMIN_DETAILS ,
JWT_EXPIRES_IN ,
ENABLE_SIGNUP ,
ENABLE_LOGIN_FORM ,
ENABLE_API_KEY ,
ENABLE_COMMUNITY_SHARING ,
ENABLE_MESSAGE_RATING ,
ENABLE_EVALUATION_ARENA_MODELS ,
USER_PERMISSIONS ,
DEFAULT_USER_ROLE ,
DEFAULT_PROMPT_SUGGESTIONS ,
DEFAULT_MODELS ,
DEFAULT_ARENA_MODEL ,
MODEL_ORDER_LIST ,
EVALUATION_ARENA_MODELS ,
# WebUI (OAuth)
ENABLE_OAUTH_ROLE_MANAGEMENT ,
OAUTH_ROLES_CLAIM ,
OAUTH_EMAIL_CLAIM ,
OAUTH_PICTURE_CLAIM ,
OAUTH_USERNAME_CLAIM ,
OAUTH_ALLOWED_ROLES ,
OAUTH_ADMIN_ROLES ,
# WebUI (LDAP)
ENABLE_LDAP ,
LDAP_SERVER_LABEL ,
LDAP_SERVER_HOST ,
LDAP_SERVER_PORT ,
LDAP_ATTRIBUTE_FOR_USERNAME ,
LDAP_SEARCH_FILTERS ,
LDAP_SEARCH_BASE ,
LDAP_APP_DN ,
LDAP_APP_PASSWORD ,
LDAP_USE_TLS ,
LDAP_CA_CERT_FILE ,
LDAP_CIPHERS ,
# Misc
2024-08-28 06:10:27 +08:00
ENV ,
2024-12-10 16:54:13 +08:00
CACHE_DIR ,
STATIC_DIR ,
2024-08-28 06:10:27 +08:00
FRONTEND_BUILD_DIR ,
2024-12-10 16:54:13 +08:00
CORS_ALLOW_ORIGIN ,
DEFAULT_LOCALE ,
2024-08-28 06:10:27 +08:00
OAUTH_PROVIDERS ,
2024-12-10 16:54:13 +08:00
# Admin
ENABLE_ADMIN_CHAT_ACCESS ,
ENABLE_ADMIN_EXPORT ,
# Tasks
2024-06-10 05:53:10 +08:00
TASK_MODEL ,
TASK_MODEL_EXTERNAL ,
2024-12-10 16:54:13 +08:00
ENABLE_TAGS_GENERATION ,
2024-11-19 18:24:32 +08:00
ENABLE_SEARCH_QUERY_GENERATION ,
ENABLE_RETRIEVAL_QUERY_GENERATION ,
2024-12-10 16:54:13 +08:00
ENABLE_AUTOCOMPLETE_GENERATION ,
2024-06-10 05:25:31 +08:00
TITLE_GENERATION_PROMPT_TEMPLATE ,
2024-10-20 12:27:10 +08:00
TAGS_GENERATION_PROMPT_TEMPLATE ,
2024-06-11 14:40:27 +08:00
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE ,
2024-12-10 16:54:13 +08:00
QUERY_GENERATION_PROMPT_TEMPLATE ,
AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE ,
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH ,
2024-08-28 06:10:27 +08:00
AppConfig ,
2024-09-25 07:06:11 +08:00
reset_config ,
2024-08-28 06:10:27 +08:00
)
2024-09-04 22:54:48 +08:00
from open_webui . env import (
2024-08-28 06:10:27 +08:00
CHANGELOG ,
GLOBAL_LOG_LEVEL ,
2024-06-24 10:28:33 +08:00
SAFE_MODE ,
2024-08-28 06:10:27 +08:00
SRC_LOG_LEVELS ,
VERSION ,
2024-12-10 16:54:13 +08:00
WEBUI_URL ,
2024-08-28 06:10:27 +08:00
WEBUI_BUILD_HASH ,
2024-05-28 01:07:38 +08:00
WEBUI_SECRET_KEY ,
2024-06-06 02:21:42 +08:00
WEBUI_SESSION_COOKIE_SAME_SITE ,
2024-06-07 16:13:42 +08:00
WEBUI_SESSION_COOKIE_SECURE ,
2024-12-10 16:54:13 +08:00
WEBUI_AUTH_TRUSTED_EMAIL_HEADER ,
WEBUI_AUTH_TRUSTED_NAME_HEADER ,
2024-12-02 10:25:44 +08:00
BYPASS_MODEL_ACCESS_CONTROL ,
2024-09-25 07:06:11 +08:00
RESET_CONFIG_ON_START ,
2024-10-08 13:13:49 +08:00
OFFLINE_MODE ,
2024-08-28 06:10:27 +08:00
)
2024-12-10 16:54:13 +08:00
2024-12-13 12:22:17 +08:00
from open_webui . utils . models import get_all_models , get_all_base_models
from open_webui . utils . chat import (
generate_chat_completion as chat_completion_handler ,
chat_completed as chat_completed_handler ,
chat_action as chat_action_handler ,
)
2024-12-12 10:46:29 +08:00
from open_webui . utils . plugin import load_function_module_by_id
2024-09-04 22:54:48 +08:00
from open_webui . utils . misc import (
2024-08-28 06:10:27 +08:00
add_or_update_system_message ,
get_last_user_message ,
prepend_to_first_user_message_content ,
2024-12-12 10:36:59 +08:00
openai_chat_chunk_message_template ,
openai_chat_completion_message_template ,
)
2024-12-13 12:22:17 +08:00
2024-12-10 16:54:13 +08:00
2024-10-16 22:32:57 +08:00
from open_webui . utils . payload import convert_payload_openai_to_ollama
from open_webui . utils . response import (
convert_response_ollama_to_openai ,
convert_streaming_response_ollama_to_openai ,
)
2024-12-10 16:54:13 +08:00
2024-09-04 22:54:48 +08:00
from open_webui . utils . task import (
2024-12-12 11:52:46 +08:00
get_task_model_id ,
2024-11-25 10:49:56 +08:00
rag_template ,
2024-08-28 06:10:27 +08:00
tools_function_calling_generation_template ,
)
2024-09-04 22:54:48 +08:00
from open_webui . utils . tools import get_tools
2024-12-10 16:54:13 +08:00
from open_webui . utils . access_control import has_access
2024-12-09 08:01:56 +08:00
from open_webui . utils . auth import (
2024-08-28 06:10:27 +08:00
decode_token ,
get_admin_user ,
get_current_user ,
get_http_authorization_cred ,
get_verified_user ,
2024-03-10 13:47:01 +08:00
)
2024-12-10 16:54:13 +08:00
from open_webui . utils . oauth import oauth_manager
from open_webui . utils . security_headers import SecurityHeadersMiddleware
2024-09-21 06:30:13 +08:00
2024-06-24 10:28:33 +08:00
# Safe mode: announce it on stdout and deactivate every stored function.
if SAFE_MODE:
    print("SAFE MODE ENABLED")
    Functions.deactivate_all_functions()

# Root logging is written to stdout at the global level; this module's own
# logger takes its level from the "MAIN" entry of SRC_LOG_LEVELS.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
2023-11-15 08:28:51 +08:00
2024-03-28 17:45:56 +08:00
2023-11-15 08:28:51 +08:00
class SPAStaticFiles(StaticFiles):
    """Static file handler that falls back to ``index.html`` on a 404.

    Any other HTTP error raised by the parent handler is propagated unchanged.
    """

    async def get_response(self, path: str, scope):
        """Return the response for ``path``, or for ``index.html`` if it is missing."""
        try:
            response = await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as http_error:
            # Only "not found" is rerouted to the single-page-app entry point.
            if http_error.status_code != 404:
                raise http_error
            response = await super().get_response("index.html", scope)
        return response
2024-04-02 18:03:55 +08:00
# Print the startup banner: ASCII-art logo, the running VERSION, the commit
# hash (left out when WEBUI_BUILD_HASH is the "dev-build" placeholder) and the
# project URL. The literal is a raw f-string so the backslashes in the art are
# kept as-is while the {…} fields are still interpolated.
print (
2024-05-04 05:23:38 +08:00
rf """
2024-10-08 13:13:49 +08:00
___ __ __ _ _ _ ___
2024-04-02 18:03:55 +08:00
/ _ \ _ __ ___ _ __ \ \ / / __ | | __ | | | | _ _ |
2024-10-08 13:13:49 +08:00
| | | | ' _ \ / _ \ ' _ \ \ \ / \ / / _ \ ' _ \ | | | || |
| | _ | | | _ ) | __ / | | | \ V V / __ / | _ ) | | _ | | | |
2024-04-02 18:03:55 +08:00
\___ / | . __ / \___ | _ | | _ | \_ / \_ / \___ | _ . __ / \___ / | ___ |
2024-10-08 13:13:49 +08:00
| _ |
2024-04-02 18:03:55 +08:00
2024-05-23 03:22:38 +08:00
v { VERSION } - building the best open - source AI user interface .
2024-05-26 15:49:30 +08:00
{ f " Commit: { WEBUI_BUILD_HASH } " if WEBUI_BUILD_HASH != " dev-build " else " " }
2024-04-02 18:03:55 +08:00
https : / / github . com / open - webui / open - webui
"""
)
2023-11-15 08:28:51 +08:00
2024-05-09 12:00:03 +08:00
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    On startup the persisted configuration is reset when
    ``RESET_CONFIG_ON_START`` is set, and the periodic usage-pool cleanup
    coroutine is scheduled as a background task. On shutdown that task is
    cancelled.
    """
    if RESET_CONFIG_ON_START:
        reset_config()

    # Keep a strong reference to the task: the event loop itself only holds a
    # weak one, so an unreferenced task may be garbage-collected mid-run.
    cleanup_task = asyncio.create_task(periodic_usage_pool_cleanup())
    try:
        yield
    finally:
        # Stop the background loop when the application shuts down.
        cleanup_task.cancel()
# The interactive docs and the OpenAPI schema are only exposed when ENV is
# "dev"; ReDoc is always disabled.
app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None,
    openapi_url="/openapi.json" if ENV == "dev" else None,
    redoc_url=None,
    lifespan=lifespan,
)

# Configuration container that the sections below populate.
app.state.config = AppConfig()
2024-05-24 16:40:48 +08:00
2024-12-10 16:54:13 +08:00
########################################
#
# OLLAMA
#
########################################

app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS

# Starts empty.
app.state.OLLAMA_MODELS = {}

########################################
#
# OPENAI
#
########################################

app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS

# Starts empty.
app.state.OPENAI_MODELS = {}
2024-12-10 16:54:13 +08:00
########################################
#
# WEBUI
#
########################################

# Sign-up / login switches
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY

# General settings and defaults
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
# Note the different attribute name: WEBUI_BANNERS is stored as BANNERS.
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST

# Community / evaluation features
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS

# OAuth
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES

# LDAP
app.state.config.ENABLE_LDAP = ENABLE_LDAP
app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
app.state.config.LDAP_APP_DN = LDAP_APP_DN
app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS

# Trusted-header names are stored directly on app.state, not on the config.
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER

# Both registries start empty.
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
########################################
#
# RETRIEVAL
#
########################################

# Query limits and file limits
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT

app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Content extraction and chunking
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL

app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME

app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking models and prompt template
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

# Backends used for embeddings
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY

# Loaders
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL

# Web search: engine selection, provider credentials, limits
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST

app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY

app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
2024-12-12 10:05:42 +08:00
# Retrieval runtime state. Everything starts as None so the attributes exist
# even when model loading below fails.
app.state.EMBEDDING_FUNCTION = None
app.state.ef = None
app.state.rf = None

app.state.YOUTUBE_LOADER_TRANSLATION = None

# Load the embedding and reranking models BEFORE building the embedding
# function: get_embedding_function receives app.state.ef as an argument, so
# calling it first would hand it the None placeholder instead of the model.
try:
    app.state.ef = get_ef(
        app.state.config.RAG_EMBEDDING_ENGINE,
        app.state.config.RAG_EMBEDDING_MODEL,
        RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    )

    app.state.rf = get_rf(
        app.state.config.RAG_RERANKING_MODEL,
        RAG_RERANKING_MODEL_AUTO_UPDATE,
    )
except Exception as e:
    # Best effort: a failed model load is logged and startup continues with
    # whichever of ef / rf could not be assigned left as None.
    log.error(f"Error updating models: {e}")

# The base URL and API key follow the selected engine: the OpenAI settings for
# "openai", the Ollama settings for any other engine value.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.ef,
    (
        app.state.config.RAG_OPENAI_API_BASE_URL
        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else app.state.config.RAG_OLLAMA_BASE_URL
    ),
    (
        app.state.config.RAG_OPENAI_API_KEY
        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else app.state.config.RAG_OLLAMA_API_KEY
    ),
    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
2024-12-12 10:08:55 +08:00
2024-12-10 16:54:13 +08:00
########################################
#
# IMAGES
#
########################################

app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION

app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL

# AUTOMATIC1111 backend
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER

# ComfyUI backend
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES

app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS

########################################
#
# AUDIO
#
########################################

# Speech-to-text (the AUDIO_ prefix is dropped on the config attribute names)
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL

app.state.config.WHISPER_MODEL = WHISPER_MODEL

# Text-to-speech
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON

app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT

# All three start unset.
app.state.faster_whisper_model = None
app.state.speech_synthesiser = None
app.state.speech_speaker_embeddings_dataset = None
########################################
#
# TASKS
#
########################################

# Models used for background tasks
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL

# Feature switches
app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION

# Prompt templates
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
    AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
)

########################################
#
# WEBUI
#
########################################

# Starts empty.
app.state.MODELS = {}
2024-06-20 16:51:39 +08:00
##################################
#
# ChatCompletion Middleware
#
##################################
2024-12-12 11:52:46 +08:00
async def chat_completion_filter_functions_handler(body, model, extra_params):
    """Pass a chat-completion request body through every applicable filter ``inlet``.

    The applicable filters are the global filter functions plus the ids listed
    under ``model["info"]["meta"]["filterIds"]``, restricted to filters that
    are currently active and ordered by the ``priority`` valve (ascending,
    default 0). Filters with equal priority keep the order in which they were
    collected: global filters first, then the model's own.

    Args:
        body: Request body (a dict); each inlet receives it as ``body`` and its
            return value becomes the body passed to the next filter.
        model: Model description dict; only ``info.meta.filterIds`` is read
            here, and the whole dict is offered to inlets as ``__model__``.
        extra_params: Extra keyword arguments offered to each inlet. An entry
            is passed only when the inlet's signature names it. When
            ``__user__`` is passed and the module defines ``UserValves``, the
            ``__user__`` dict is updated in place with a ``valves`` entry.

    Returns:
        A ``(body, {})`` tuple: the filtered body and an empty dict.

    Raises:
        Exception: Whatever an inlet (or its parameter preparation) raises is
            logged and re-raised unchanged.
    """
    skip_files = None

    def get_filter_function_ids(model):
        def get_priority(function_id):
            function = Functions.get_function_by_id(function_id)
            if function is not None and hasattr(function, "valves"):
                # TODO: Fix FunctionModel
                return (function.valves if function.valves else {}).get("priority", 0)
            return 0

        filter_ids = [
            function.id for function in Functions.get_global_filter_functions()
        ]
        if "info" in model and "meta" in model["info"]:
            filter_ids.extend(model["info"]["meta"].get("filterIds", []))

        # De-duplicate while keeping first-seen order. A set would make the
        # relative order of equal-priority filters depend on string hashing,
        # which varies between interpreter runs.
        filter_ids = list(dict.fromkeys(filter_ids))

        enabled_filter_ids = {
            function.id
            for function in Functions.get_functions_by_type("filter", active_only=True)
        }

        filter_ids = [
            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
        ]

        # list.sort is stable, so ties keep the collection order from above.
        filter_ids.sort(key=get_priority)
        return filter_ids

    for filter_id in get_filter_function_ids(model):
        filter_function = Functions.get_function_by_id(filter_id)
        if not filter_function:
            continue

        # Loaded modules are cached on app.state so each filter is only
        # loaded once per process.
        if filter_id in app.state.FUNCTIONS:
            function_module = app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            app.state.FUNCTIONS[filter_id] = function_module

        # Check if the function has a file_handler variable; the value from
        # the last filter that defines it wins.
        if hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        # Refresh the module's valves from storage before every run.
        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if not hasattr(function_module, "inlet"):
            continue

        try:
            inlet = function_module.inlet

            # Only pass the optional arguments that the inlet actually declares.
            sig = inspect.signature(inlet)
            params = {"body": body} | {
                k: v
                for k, v in {
                    **extra_params,
                    "__model__": model,
                    "__id__": filter_id,
                }.items()
                if k in sig.parameters
            }

            if "__user__" in params and hasattr(function_module, "UserValves"):
                try:
                    params["__user__"]["valves"] = function_module.UserValves(
                        **Functions.get_user_valves_by_id_and_user_id(
                            filter_id, params["__user__"]["id"]
                        )
                    )
                except Exception:
                    # Best effort: the filter still runs without user valves.
                    log.exception(
                        f"Failed to load user valves for filter {filter_id}"
                    )

            if inspect.iscoroutinefunction(inlet):
                body = await inlet(**params)
            else:
                body = inlet(**params)

        except Exception as e:
            log.exception(f"Error: {e}")
            raise

    # A filter that handles files itself asks for them to be dropped from the
    # request metadata.
    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {}
2024-06-11 16:10:24 +08:00
2024-06-20 17:06:10 +08:00
2024-12-13 12:22:17 +08:00
async def chat_completion_tools_handler (
request : Request , body : dict , user : UserModel , models , extra_params : dict
) - > tuple [ dict , dict ] :
async def get_content_from_response ( response ) - > Optional [ str ] :
content = None
if hasattr ( response , " body_iterator " ) :
async for chunk in response . body_iterator :
data = json . loads ( chunk . decode ( " utf-8 " ) )
content = data [ " choices " ] [ 0 ] [ " message " ] [ " content " ]
# Cleanup any remaining background tasks if necessary
if response . background is not None :
await response . background ( )
else :
content = response [ " choices " ] [ 0 ] [ " message " ] [ " content " ]
return content
def get_tools_function_calling_payload ( messages , task_model_id , content ) :
user_message = get_last_user_message ( messages )
history = " \n " . join (
f " { message [ ' role ' ] . upper ( ) } : \" \" \" { message [ ' content ' ] } \" \" \" "
for message in messages [ : : - 1 ] [ : 4 ]
)
2024-08-17 22:41:34 +08:00
2024-12-13 12:22:17 +08:00
prompt = f " History: \n { history } \n Query: { user_message } "
2024-08-17 22:24:11 +08:00
2024-12-13 12:22:17 +08:00
return {
" model " : task_model_id ,
" messages " : [
{ " role " : " system " , " content " : content } ,
{ " role " : " user " , " content " : f " Query: { prompt } " } ,
] ,
" stream " : False ,
" metadata " : { " task " : str ( TASKS . FUNCTION_CALLING ) } ,
}
2024-08-17 22:24:11 +08:00
2024-08-19 18:11:00 +08:00
# If tool_ids field is present, call the functions
2024-08-20 22:41:49 +08:00
metadata = body . get ( " metadata " , { } )
2024-08-22 22:02:29 +08:00
2024-08-20 22:41:49 +08:00
tool_ids = metadata . get ( " tool_ids " , None )
2024-08-22 22:02:29 +08:00
log . debug ( f " { tool_ids =} " )
2024-08-19 18:11:00 +08:00
if not tool_ids :
return body , { }
2024-08-12 21:48:57 +08:00
skip_files = False
2024-11-22 11:46:09 +08:00
sources = [ ]
2024-07-02 10:33:58 +08:00
2024-11-16 20:41:07 +08:00
task_model_id = get_task_model_id (
body [ " model " ] ,
2024-12-13 12:22:17 +08:00
request . app . state . config . TASK_MODEL ,
request . app . state . config . TASK_MODEL_EXTERNAL ,
2024-11-16 20:41:07 +08:00
models ,
)
2024-08-22 22:02:29 +08:00
tools = get_tools (
2024-12-13 12:22:17 +08:00
request ,
2024-08-22 22:02:29 +08:00
tool_ids ,
user ,
{
* * extra_params ,
2024-11-16 20:41:07 +08:00
" __model__ " : models [ task_model_id ] ,
2024-08-22 22:02:29 +08:00
" __messages__ " : body [ " messages " ] ,
" __files__ " : metadata . get ( " files " , [ ] ) ,
} ,
)
2024-08-17 22:24:11 +08:00
log . info ( f " { tools =} " )
2024-08-12 21:48:57 +08:00
2024-08-17 22:24:11 +08:00
specs = [ tool [ " spec " ] for tool in tools . values ( ) ]
2024-08-12 22:53:47 +08:00
tools_specs = json . dumps ( specs )
2024-08-17 22:27:11 +08:00
2024-09-07 11:50:29 +08:00
if app . state . config . TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != " " :
template = app . state . config . TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
else :
template = """ Available Tools: {{ TOOLS}} \n Return an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format { \" name \" : \" functionName \" , \" parameters \" : { \" requiredFunctionParamKey \" : \" requiredFunctionParamValue \" }} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text. """
2024-08-17 23:01:35 +08:00
tools_function_calling_prompt = tools_function_calling_generation_template (
2024-09-07 11:50:29 +08:00
template , tools_specs
2024-08-17 22:32:39 +08:00
)
2024-08-17 23:01:35 +08:00
log . info ( f " { tools_function_calling_prompt =} " )
2024-08-17 22:32:39 +08:00
payload = get_tools_function_calling_payload (
2024-08-17 23:01:35 +08:00
body [ " messages " ] , task_model_id , tools_function_calling_prompt
2024-08-17 22:32:39 +08:00
)
2024-08-17 22:24:11 +08:00
2024-08-12 22:53:47 +08:00
try :
2024-12-12 11:52:46 +08:00
payload = process_pipeline_inlet_filter ( request , payload , user , models )
2024-08-12 22:53:47 +08:00
except Exception as e :
raise e
2024-07-02 10:33:58 +08:00
2024-08-12 22:53:47 +08:00
try :
response = await generate_chat_completions ( form_data = payload , user = user )
log . debug ( f " { response =} " )
content = await get_content_from_response ( response )
log . debug ( f " { content =} " )
2024-08-17 22:27:11 +08:00
2024-08-20 00:04:57 +08:00
if not content :
2024-08-12 22:53:47 +08:00
return body , { }
2024-07-02 10:33:58 +08:00
2024-08-12 22:53:47 +08:00
try :
2024-09-29 01:51:28 +08:00
content = content [ content . find ( " { " ) : content . rfind ( " } " ) + 1 ]
if not content :
raise Exception ( " No JSON object found in the response " )
result = json . loads ( content )
tool_function_name = result . get ( " name " , None )
if tool_function_name not in tools :
return body , { }
tool_function_params = result . get ( " parameters " , { } )
try :
2024-10-27 03:21:05 +08:00
required_params = (
tools [ tool_function_name ]
. get ( " spec " , { } )
. get ( " parameters " , { } )
. get ( " required " , [ ] )
)
2024-10-26 05:36:44 +08:00
tool_function = tools [ tool_function_name ] [ " callable " ]
tool_function_params = {
2024-10-27 03:21:05 +08:00
k : v
for k , v in tool_function_params . items ( )
if k in required_params
2024-10-26 05:36:44 +08:00
}
tool_output = await tool_function ( * * tool_function_params )
2024-10-26 13:18:48 +08:00
2024-09-29 01:51:28 +08:00
except Exception as e :
tool_output = str ( e )
2024-11-22 11:46:09 +08:00
if isinstance ( tool_output , str ) :
if tools [ tool_function_name ] [ " citation " ] :
sources . append (
{
" source " : {
" name " : f " TOOL: { tools [ tool_function_name ] [ ' toolkit_id ' ] } / { tool_function_name } "
} ,
" document " : [ tool_output ] ,
" metadata " : [
{
" source " : f " TOOL: { tools [ tool_function_name ] [ ' toolkit_id ' ] } / { tool_function_name } "
}
] ,
}
)
else :
sources . append (
{
" source " : { } ,
" document " : [ tool_output ] ,
" metadata " : [
{
" source " : f " TOOL: { tools [ tool_function_name ] [ ' toolkit_id ' ] } / { tool_function_name } "
}
] ,
}
)
2024-11-22 10:26:38 +08:00
2024-11-22 11:46:09 +08:00
if tools [ tool_function_name ] [ " file_handler " ] :
skip_files = True
2024-09-29 01:51:28 +08:00
2024-08-10 18:58:18 +08:00
except Exception as e :
2024-09-29 01:51:28 +08:00
log . exception ( f " Error: { e } " )
content = None
2024-08-12 22:53:47 +08:00
except Exception as e :
2024-08-19 18:03:55 +08:00
log . exception ( f " Error: { e } " )
2024-08-12 22:53:47 +08:00
content = None
2024-08-10 18:58:18 +08:00
2024-11-22 11:46:09 +08:00
log . debug ( f " tool_contexts: { sources } " )
2024-07-02 10:33:58 +08:00
2024-08-20 22:41:49 +08:00
if skip_files and " files " in body . get ( " metadata " , { } ) :
del body [ " metadata " ] [ " files " ]
2024-07-02 10:33:58 +08:00
2024-11-22 11:46:09 +08:00
return body , { " sources " : sources }
2024-07-02 10:33:58 +08:00
2024-11-19 18:24:32 +08:00
def _extract_retrieval_queries(response_text: str) -> list:
    """Extract the search queries from a query-generation reply.

    The reply is expected to contain a JSON object with a ``"queries"`` list,
    possibly surrounded by extra text. When no JSON object can be located the
    whole reply is used as a single query; when the located span is not valid
    JSON that span is used as a single query.
    """
    bracket_start = response_text.find("{")
    # rfind() returns -1 when "}" is absent, so the exclusive end index is 0
    # in that case (never -1).
    bracket_end = response_text.rfind("}") + 1

    if bracket_start == -1 or bracket_end <= bracket_start:
        return [response_text]

    candidate = response_text[bracket_start:bracket_end]
    try:
        parsed = json.loads(candidate)
    except ValueError:
        return [candidate]

    queries = parsed.get("queries", [])
    return queries if isinstance(queries, list) else []


async def chat_completion_files_handler(
    request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    """Collect retrieval sources for the files attached to a chat request.

    When ``body["metadata"]["files"]`` is non-empty, search queries are
    generated via ``generate_queries`` (falling back to the last user message)
    and passed to ``get_sources_from_files`` together with the retrieval
    settings stored on ``request.app.state``.

    Returns:
        The unchanged ``body`` and a dict ``{"sources": [...]}``; the list is
        empty when the request has no files.
    """
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        try:
            queries_response = await generate_queries(
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            queries = _extract_retrieval_queries(
                queries_response["choices"][0]["message"]["content"]
            )
        except Exception as e:
            # Query generation is best-effort: any failure falls back to the
            # last user message below instead of failing the request.
            log.debug(f"Query generation failed: {e}")
            queries = []

        # Blank or non-string entries would produce meaningless searches.
        queries = [
            query for query in queries if isinstance(query, str) and query.strip()
        ]
        if not queries:
            queries = [get_last_user_message(body["messages"])]

        sources = get_sources_from_files(
            files=files,
            queries=queries,
            embedding_function=request.app.state.EMBEDDING_FUNCTION,
            k=request.app.state.config.TOP_K,
            reranking_function=request.app.state.rf,
            r=request.app.state.config.RELEVANCE_THRESHOLD,
            hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
        )

        log.debug(f"rag_contexts:sources: {sources}")

    return body, {"sources": sources}
2024-07-02 10:33:58 +08:00
class ChatCompletionMiddleware(BaseHTTPMiddleware):
    """Pre-process chat completion requests and prepend sources to streams.

    For POST requests whose path contains ``/ollama/api/chat`` or
    ``/chat/completions`` the middleware:

    1. resolves the requested model and the current user,
    2. enforces model access control for users with the ``"user"`` role,
    3. runs the filter-function, tool and file handlers on the body,
    4. injects the collected source context into the messages through the
       RAG template,
    5. replaces the request body with the modified payload, and
    6. for streaming responses, emits the citation sources before the
       original stream.

    Failures are reported with ``JSONResponse`` instead of raised exceptions
    because exceptions raised inside this middleware are not converted by the
    application's HTTP exception handlers.
    """

    @staticmethod
    def _build_context_string(sources: list) -> str:
        """Render every document of ``sources`` as ``<source>`` markup."""
        fragments = []
        for source in sources:
            source_id = source.get("source", {}).get("name", "")
            source_metadata = source.get("metadata") or []

            for doc_idx, doc_context in enumerate(source.get("document", [])):
                if not source_id:
                    # If there is no source_id, then do not include the source_id tag
                    fragments.append(
                        f"<source><source_context>{doc_context}</source_context></source>\n"
                    )
                    continue

                doc_source_id = None
                # The metadata list may be shorter than the document list.
                if doc_idx < len(source_metadata):
                    doc_source_id = source_metadata[doc_idx].get("source", source_id)
                if doc_source_id is None:
                    doc_source_id = source_id

                fragments.append(
                    f"<source><source_id>{doc_source_id}</source_id><source_context>{doc_context}</source_context></source>\n"
                )

        return "".join(fragments).strip()

    async def dispatch(self, request: Request, call_next):
        if not (
            request.method == "POST"
            and any(
                endpoint in request.url.path
                for endpoint in ["/ollama/api/chat", "/chat/completions"]
            )
        ):
            return await call_next(request)
        log.debug(f"request.url.path: {request.url.path}")

        await get_all_models(request)
        models = app.state.MODELS

        async def get_body_and_model_and_user(request, models):
            # Read the original request body
            body = await request.body()
            body_str = body.decode("utf-8")
            body = json.loads(body_str) if body_str else {}

            model_id = body["model"]
            if model_id not in models:
                raise Exception("Model not found")
            model = models[model_id]

            user = get_current_user(
                request,
                get_http_authorization_cred(request.headers.get("Authorization")),
            )

            return body, model, user

        try:
            body, model, user = await get_body_and_model_and_user(request, models)
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
            if model.get("arena"):
                if not has_access(
                    user.id,
                    type="read",
                    access_control=model.get("info", {})
                    .get("meta", {})
                    .get("access_control", {}),
                ):
                    # Returned rather than raised, matching the other denial
                    # branches: a raised HTTPException would not be converted
                    # into a 403 response at this layer.
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={"detail": "Model not found"},
                    )
            else:
                # Only look the model up when its access rules are needed.
                model_info = Models.get_model_by_id(model["id"])
                if not model_info:
                    return JSONResponse(
                        status_code=status.HTTP_404_NOT_FOUND,
                        content={"detail": "Model not found"},
                    )
                elif not (
                    user.id == model_info.user_id
                    or has_access(
                        user.id, type="read", access_control=model_info.access_control
                    )
                ):
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={"detail": "User does not have access to the model"},
                    )

        metadata = {
            "chat_id": body.pop("chat_id", None),
            "message_id": body.pop("id", None),
            "session_id": body.pop("session_id", None),
            "tool_ids": body.get("tool_ids", None),
            "files": body.get("files", None),
        }
        body["metadata"] = metadata

        extra_params = {
            "__event_emitter__": get_event_emitter(metadata),
            "__event_call__": get_event_call(metadata),
            "__user__": {
                "id": user.id,
                "email": user.email,
                "name": user.name,
                "role": user.role,
            },
            "__metadata__": metadata,
        }

        # data_items holds additional data to be sent to the client ahead of
        # the stream; sources accumulates tool and file contexts.
        data_items = []
        sources = []

        try:
            body, flags = await chat_completion_filter_functions_handler(
                body, model, extra_params
            )
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        # Filter functions may have changed these keys, so re-read them and
        # move them from the body into the metadata.
        tool_ids = body.pop("tool_ids", None)
        files = body.pop("files", None)

        metadata = {
            **metadata,
            "tool_ids": tool_ids,
            "files": files,
        }
        body["metadata"] = metadata

        # Tool and file handling are best-effort: failures are logged and the
        # request continues without their sources.
        try:
            body, flags = await chat_completion_tools_handler(
                request, body, user, models, extra_params
            )
            sources.extend(flags.get("sources", []))
        except Exception as e:
            log.exception(e)

        try:
            body, flags = await chat_completion_files_handler(request, body, user)
            sources.extend(flags.get("sources", []))
        except Exception as e:
            log.exception(e)

        # If context is not empty, insert it into the messages
        if len(sources) > 0:
            context_string = self._build_context_string(sources)
            prompt = get_last_user_message(body["messages"])

            if prompt is None:
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content={"detail": "No user message found"},
                )

            if app.state.config.RELEVANCE_THRESHOLD == 0 and context_string == "":
                log.debug(
                    "With a 0 relevancy threshold for RAG, the context cannot be empty"
                )

            rag_message = rag_template(
                app.state.config.RAG_TEMPLATE, context_string, prompt
            )

            # Workaround for Ollama 2.0+ system prompt issue
            # TODO: replace with add_or_update_system_message
            if model["owned_by"] == "ollama":
                body["messages"] = prepend_to_first_user_message_content(
                    rag_message,
                    body["messages"],
                )
            else:
                body["messages"] = add_or_update_system_message(
                    rag_message,
                    body["messages"],
                )

        # If there are citations, add them to the data_items
        sources = [
            source for source in sources if source.get("source", {}).get("name", "")
        ]
        if len(sources) > 0:
            data_items.append({"sources": sources})

        modified_body_bytes = json.dumps(body).encode("utf-8")
        # Replace the request body with the modified one
        request._body = modified_body_bytes
        # Set custom header to ensure content-length matches new body length
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]

        response = await call_next(request)
        if not isinstance(response, StreamingResponse):
            return response

        content_type = response.headers.get("Content-Type", "")
        is_openai = "text/event-stream" in content_type
        is_ollama = "application/x-ndjson" in content_type
        if not is_openai and not is_ollama:
            return response

        def wrap_item(item):
            # Event-stream framing for OpenAI-style streams, one JSON document
            # per line for ndjson streams.
            return f"data: {item}\n\n" if is_openai else f"{item}\n"

        async def stream_wrapper(original_generator, data_items):
            for item in data_items:
                yield wrap_item(json.dumps(item))

            async for data in original_generator:
                yield data

        return StreamingResponse(
            stream_wrapper(response.body_iterator, data_items),
            status_code=response.status_code,
            headers=dict(response.headers),
        )

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(ChatCompletionMiddleware)
2024-03-09 14:34:47 +08:00
2024-10-11 05:00:05 +08:00
2024-05-28 10:03:26 +08:00
class PipelineMiddleware(BaseHTTPMiddleware):
    """Run the pipeline inlet filter over chat completion request bodies.

    Applies to POST requests whose path contains ``/ollama/api/chat`` or
    ``/chat/completions``. The JSON body is passed through
    ``process_pipeline_inlet_filter`` and the (possibly modified) payload
    replaces the request body before the request continues. Authentication,
    parsing and filter failures are answered with a ``JSONResponse``.
    """

    async def dispatch(self, request: Request, call_next):
        if not (
            request.method == "POST"
            and any(
                endpoint in request.url.path
                for endpoint in ["/ollama/api/chat", "/chat/completions"]
            )
        ):
            return await call_next(request)

        log.debug(f"request.url.path: {request.url.path}")

        # Read the original request body
        body = await request.body()
        try:
            # Decode body to string
            body_str = body.decode("utf-8")
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}
        except ValueError as e:
            # Covers both invalid UTF-8 and invalid JSON; answer 400 instead
            # of letting the error escape the middleware.
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        try:
            user = get_current_user(
                request,
                get_http_authorization_cred(request.headers["Authorization"]),
            )
        except KeyError as e:
            # A missing Authorization header raises KeyError with a single
            # argument and is reported as 401.
            if len(e.args) > 1:
                return JSONResponse(
                    status_code=e.args[0],
                    content={"detail": e.args[1]},
                )
            else:
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "Not authenticated"},
                )
        except HTTPException as e:
            return JSONResponse(
                status_code=e.status_code,
                content={"detail": e.detail},
            )

        await get_all_models(request)
        models = app.state.MODELS

        try:
            data = process_pipeline_inlet_filter(request, data, user, models)
        except Exception as e:
            # Exceptions with two or more arguments are read as
            # (status_code, detail).
            if len(e.args) > 1:
                return JSONResponse(
                    status_code=e.args[0],
                    content={"detail": e.args[1]},
                )
            else:
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content={"detail": str(e)},
                )

        modified_body_bytes = json.dumps(data).encode("utf-8")
        # Replace the request body with the modified one
        request._body = modified_body_bytes
        # Set custom header to ensure content-length matches new body length
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(PipelineMiddleware)
2024-10-08 09:19:13 +08:00
class RedirectMiddleware(BaseHTTPMiddleware):
    """Redirect ``GET .../watch?v=<id>`` requests to ``/?youtube=<id>``."""

    async def dispatch(self, request: Request, call_next):
        is_watch_request = request.method == "GET" and request.url.path.endswith(
            "/watch"
        )
        if is_watch_request:
            parsed_query = parse_qs(urlparse(str(request.url)).query)
            video_ids = parsed_query.get("v")
            if video_ids is not None:
                # Only the first 'v' value is forwarded.
                target = "/?" + urlencode({"youtube": video_ids[0]})
                return RedirectResponse(url=target)

        # Everything else continues through the normal flow.
        return await call_next(request)


# Add the middleware to the app
app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
2024-05-29 00:50:17 +08:00
2024-06-24 19:06:15 +08:00
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
    """Commit ``Session`` once the downstream handlers have produced a response."""
    downstream_response = await call_next(request)
    Session.commit()
    return downstream_response
2024-05-29 00:50:17 +08:00
2023-11-15 08:28:51 +08:00
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Expose the API-key flag on the request and report the processing time.

    The elapsed time is the difference of two whole-second timestamps and is
    sent in the ``X-Process-Time`` response header.
    """
    started_at = int(time.time())
    request.state.enable_api_key = app.state.config.ENABLE_API_KEY

    response = await call_next(request)

    finished_at = int(time.time())
    response.headers["X-Process-Time"] = str(finished_at - started_at)
    return response
2024-09-08 18:54:56 +08:00
@app.middleware("http")
async def inspect_websocket(request: Request, call_next):
    """Reject socket.io websocket requests that lack proper upgrade headers.

    Applies only to paths containing ``/ws/socket.io`` with the query
    parameter ``transport=websocket``; such requests must carry
    ``Upgrade: websocket`` and list ``upgrade`` in their ``Connection``
    header, otherwise a 400 response is returned.
    """
    if (
        "/ws/socket.io" in request.url.path
        and request.query_params.get("transport") == "websocket"
    ):
        upgrade = (request.headers.get("Upgrade") or "").strip().lower()
        # Connection is a comma-separated list whose items may be padded with
        # whitespace (e.g. "keep-alive, Upgrade"), so strip every token.
        connection = [
            token.strip()
            for token in (request.headers.get("Connection") or "").lower().split(",")
        ]
        # Check that there's the correct headers for an upgrade, else reject the connection
        # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367
        if upgrade != "websocket" or "upgrade" not in connection:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": "Invalid WebSocket upgrade request"},
            )

    return await call_next(request)
2024-12-10 16:54:13 +08:00
# Cross-origin access: allowed origins come from CORS_ALLOW_ORIGIN; all
# methods and headers are allowed and credentials are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGIN,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# The socket application is served under /ws.
app.mount("/ws", socket_app)

# Routers are registered in the order listed: (router module, URL prefix, tag).
for _router_module, _router_prefix, _router_tag in (
    (ollama, "/ollama", "ollama"),
    (openai, "/openai", "openai"),
    (pipelines, "/api/v1/pipelines", "pipelines"),
    (tasks, "/api/v1/tasks", "tasks"),
    (images, "/api/v1/images", "images"),
    (audio, "/api/v1/audio", "audio"),
    (retrieval, "/api/v1/retrieval", "retrieval"),
    (configs, "/api/v1/configs", "configs"),
    (auths, "/api/v1/auths", "auths"),
    (users, "/api/v1/users", "users"),
    (chats, "/api/v1/chats", "chats"),
    (models, "/api/v1/models", "models"),
    (knowledge, "/api/v1/knowledge", "knowledge"),
    (prompts, "/api/v1/prompts", "prompts"),
    (tools, "/api/v1/tools", "tools"),
    (memories, "/api/v1/memories", "memories"),
    (folders, "/api/v1/folders", "folders"),
    (groups, "/api/v1/groups", "groups"),
    (files, "/api/v1/files", "files"),
    (functions, "/api/v1/functions", "functions"),
    (evaluations, "/api/v1/evaluations", "evaluations"),
    (utils, "/api/v1/utils", "utils"),
):
    app.include_router(
        _router_module.router, prefix=_router_prefix, tags=[_router_tag]
    )
2024-05-19 23:00:07 +08:00
2024-04-01 04:59:39 +08:00
2024-12-12 10:36:59 +08:00
##################################
#
# Chat Endpoints
#
##################################
2024-12-13 12:22:17 +08:00
@app.get("/api/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    # Lists the models for the current user: filter pipelines are removed,
    # the configured ordering is applied, and users with the "user" role only
    # see models they can read (unless access control is bypassed).

    def is_readable(model):
        # Arena models carry their access rules inline under info.meta.
        if model.get("arena"):
            inline_rules = (
                model.get("info", {}).get("meta", {}).get("access_control", {})
            )
            return has_access(user.id, type="read", access_control=inline_rules)

        # Other models are readable by their owner or through the access
        # rules of their stored record; models without a record are hidden.
        record = Models.get_model_by_id(model["id"])
        if not record:
            return False
        return user.id == record.user_id or has_access(
            user.id, type="read", access_control=record.access_control
        )

    all_models = await get_all_models(request)

    # Filter out filter pipelines
    models = []
    for candidate in all_models:
        if "pipeline" in candidate and candidate["pipeline"].get("type", None) == "filter":
            continue
        models.append(candidate)

    model_order_list = request.app.state.config.MODEL_ORDER_LIST
    if model_order_list:
        rank_by_id = {}
        for position, model_id in enumerate(model_order_list):
            rank_by_id[model_id] = position

        def order_key(model):
            # Listed models come first in list order; the rest follow by name.
            return (rank_by_id.get(model["id"], float("inf")), model["name"])

        models = sorted(models, key=order_key)

    # Filter out models that the user does not have access to
    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        models = [model for model in models if is_readable(model)]

    log.debug(
        f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}"
    )

    return {"data": models}
2024-11-16 11:14:24 +08:00
@app.get("/api/models/base")
async def get_base_models(request: Request, user=Depends(get_admin_user)):
    # Admin-only: returns whatever get_all_base_models yields, wrapped in "data".
    base_models = await get_all_base_models(request)
    return {"data": base_models}
2024-12-12 10:05:42 +08:00
@app.post("/api/chat/completions")
async def chat_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user),
    bypass_filter: bool = False,
):
    # Delegates to chat_completion_handler; any exception is reported to the
    # client as a 400 whose detail is the exception text.
    # NOTE(review): bypass_filter is a plain parameter of a route function, so
    # it is presumably also settable by clients as a query parameter — confirm
    # that this is intended.
    try:
        completion = await chat_completion_handler(
            request, form_data, user, bypass_filter
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        )
    return completion


# Aliases of the same endpoint function.
generate_chat_completions = chat_completion
generate_chat_completion = chat_completion
2024-12-12 10:05:42 +08:00
@app.post("/api/chat/completed")
async def chat_completed(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    # Delegates to chat_completed_handler; any exception is reported to the
    # client as a 400 whose detail is the exception text.
    try:
        outcome = await chat_completed_handler(request, form_data, user)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        )
    return outcome
2024-12-12 10:05:42 +08:00
@app.post("/api/chat/actions/{action_id}")
async def chat_action(
    request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
):
    # Delegates to chat_action_handler for the action named in the path; any
    # exception is reported to the client as a 400 with the exception text.
    try:
        outcome = await chat_action_handler(request, action_id, form_data, user)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        )
    return outcome
2024-06-20 16:51:39 +08:00
##################################
#
# Config Endpoints
#
##################################
2024-02-22 10:12:01 +08:00
@app.get("/api/config")
async def get_app_config(request: Request):
    # Returns the client-facing configuration. Anonymous callers receive the
    # basic fields; callers whose "token" cookie resolves to a user also
    # receive the extended feature flags and defaults. An undecodable token
    # is answered with 401.
    user = None
    cookies = request.cookies
    if "token" in cookies:
        try:
            claims = decode_token(cookies.get("token"))
        except Exception as e:
            log.debug(e)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token",
            )

        if claims is not None and "id" in claims:
            user = Users.get_user_by_id(claims["id"])

    is_known_user = user is not None

    # Onboarding is flagged only for callers without a user while no users
    # exist at all.
    onboarding = (not is_known_user) and Users.get_num_users() == 0

    features = {
        "auth": WEBUI_AUTH,
        "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
        "enable_ldap": app.state.config.ENABLE_LDAP,
        "enable_api_key": app.state.config.ENABLE_API_KEY,
        "enable_signup": app.state.config.ENABLE_SIGNUP,
        "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
    }
    if is_known_user:
        features.update(
            {
                "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
                "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
                "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
                "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
                "enable_admin_export": ENABLE_ADMIN_EXPORT,
                "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
            }
        )

    oauth_provider_names = {}
    for provider_key, provider_config in OAUTH_PROVIDERS.items():
        oauth_provider_names[provider_key] = provider_config.get("name", provider_key)

    # Keys are inserted in the same order as the response has always had.
    payload = {"onboarding": True} if onboarding else {}
    payload.update(
        {
            "status": True,
            "name": WEBUI_NAME,
            "version": VERSION,
            "default_locale": str(DEFAULT_LOCALE),
            "oauth": {"providers": oauth_provider_names},
            "features": features,
        }
    )

    if is_known_user:
        payload.update(
            {
                "default_models": app.state.config.DEFAULT_MODELS,
                "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
                "audio": {
                    "tts": {
                        "engine": app.state.config.TTS_ENGINE,
                        "voice": app.state.config.TTS_VOICE,
                        "split_on": app.state.config.TTS_SPLIT_ON,
                    },
                    "stt": {
                        "engine": app.state.config.STT_ENGINE,
                    },
                },
                "file": {
                    "max_size": app.state.config.FILE_MAX_SIZE,
                    "max_count": app.state.config.FILE_MAX_COUNT,
                },
                "permissions": {**app.state.config.USER_PERMISSIONS},
            }
        )

    return payload
2024-12-10 16:54:13 +08:00
# Request body carrying a single URL; used as the payload of the webhook
# update endpoint.
class UrlForm(BaseModel):
    url: str  # the URL supplied by the client
2024-06-20 16:51:39 +08:00
2024-03-21 09:35:02 +08:00
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
    # Admin-only: reports the currently configured webhook URL.
    current_url = app.state.config.WEBHOOK_URL
    return {"url": current_url}
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
    # Admin-only: stores the new webhook URL in the config, mirrors the stored
    # value onto app.state, and echoes it back.
    config = app.state.config
    config.WEBHOOK_URL = form_data.url
    app.state.WEBHOOK_URL = config.WEBHOOK_URL
    return {"url": config.WEBHOOK_URL}
2024-05-27 00:23:24 +08:00
2024-03-05 16:59:35 +08:00
@app.get("/api/version")
async def get_app_version():
    # Reports the running application version.
    return dict(version=VERSION)
2024-02-26 03:26:58 +08:00
@app.get("/api/version/updates")
async def get_app_latest_release_version():
    # Reports the current version next to the latest released version, read
    # from the GitHub releases API. In offline mode, or when the lookup fails
    # for any reason (including the one-second timeout), the current version
    # is reported as the latest.
    if OFFLINE_MODE:
        log.debug(
            f"Offline mode is enabled, returning current version as latest version"
        )
        return {"current": VERSION, "latest": VERSION}

    try:
        timeout = aiohttp.ClientTimeout(total=1)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
            ) as response:
                response.raise_for_status()
                data = await response.json()
                latest_version = data["tag_name"]

                # Strip only an actual leading "v"; slicing off the first
                # character would corrupt a tag published without it.
                return {
                    "current": VERSION,
                    "latest": latest_version.removeprefix("v"),
                }
    except Exception as e:
        log.debug(e)
        return {"current": VERSION, "latest": VERSION}
2024-02-26 03:26:58 +08:00
2024-04-10 16:27:19 +08:00
2024-12-10 16:54:13 +08:00
# Return at most the first five entries of CHANGELOG, in its iteration order,
# as a {key: entry} mapping.
@app.get("/api/changelog")
async def get_app_changelog():
    leading_keys = list(CHANGELOG)[:5]
    return {entry_key: CHANGELOG[entry_key] for entry_key in leading_keys}
2024-05-28 01:07:38 +08:00
############################
# OAuth Login & Callback
############################
# authlib relies on SessionMiddleware for OAuth, so the middleware is only
# registered when at least one OAuth provider is configured.
if len(OAUTH_PROVIDERS) > 0:
    oauth_session_options = {
        "secret_key": WEBUI_SECRET_KEY,
        "session_cookie": "oui-session",
        "same_site": WEBUI_SESSION_COOKIE_SAME_SITE,
        "https_only": WEBUI_SESSION_COOKIE_SECURE,
    }
    app.add_middleware(SessionMiddleware, **oauth_session_options)
# Entry point of the OAuth flow for `provider`; all work is delegated to
# oauth_manager.handle_login and its result is returned unchanged.
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
    login_result = await oauth_manager.handle_login(provider, request)
    return login_result
2024-05-28 01:07:38 +08:00
2024-06-22 01:25:19 +08:00
# OAuth login logic is as follows:
# 1. Attempt to find a user with matching subject ID, tied to the provider
# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
2024-10-14 15:13:26 +08:00
# - Email addresses are considered unique, so we fail registration if the email address is already taken
2024-05-28 01:07:38 +08:00
# OAuth callback for `provider`; all work is delegated to
# oauth_manager.handle_callback, which receives the request and the
# framework-provided response object. Its result is returned unchanged.
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
    callback_result = await oauth_manager.handle_callback(provider, request, response)
    return callback_result
2024-05-28 01:07:38 +08:00
2024-04-03 02:55:00 +08:00
# Serve the web app manifest. Name and short name both come from WEBUI_NAME;
# the same logo file is listed twice, once per icon purpose.
@app.get("/manifest.json")
async def get_manifest_json():
    icons = [
        {
            "src": "/static/logo.png",
            "type": "image/png",
            "sizes": "500x500",
            "purpose": purpose,
        }
        for purpose in ("any", "maskable")
    ]
    manifest = {
        "name": WEBUI_NAME,
        "short_name": WEBUI_NAME,
        "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
        "start_url": "/",
        "display": "standalone",
        "background_color": "#343541",
        "orientation": "natural",
        "icons": icons,
    }
    return manifest
2024-04-10 16:27:19 +08:00
2024-05-07 08:29:16 +08:00
# Serve an OpenSearch description document for this instance.
#
# WEBUI_NAME and WEBUI_URL are configuration values placed into XML text and
# attribute positions, so they are XML-escaped first; otherwise a name such
# as "R&D Chat" would make the document malformed.
@app.get("/opensearch.xml")
async def get_opensearch_xml():
    # Function-scope import keeps this block self-contained.
    from xml.sax.saxutils import escape

    # str() mirrors what f-string interpolation did implicitly, in case the
    # values are not plain strings; escape() itself requires a str.
    name = escape(str(WEBUI_NAME))
    # The URL is also used inside a double-quoted attribute, so quotes are
    # escaped in addition to &, < and >.
    base_url = escape(str(WEBUI_URL), {'"': "&quot;"})

    # "{{searchTerms}}" renders the literal OpenSearch placeholder
    # "{searchTerms}".
    xml_content = rf"""
    <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
    <ShortName>{name}</ShortName>
    <Description>Search {name}</Description>
    <InputEncoding>UTF-8</InputEncoding>
    <Image width="16" height="16" type="image/x-icon">{base_url}/static/favicon.png</Image>
    <Url type="text/html" method="get" template="{base_url}/?q={{searchTerms}}"/>
    <moz:SearchForm>{base_url}</moz:SearchForm>
    </OpenSearchDescription>
    """
    return Response(content=xml_content, media_type="application/xml")
2024-05-16 02:17:18 +08:00
# Liveness probe: always answers {"status": True} without touching any
# other component.
@app.get("/health")
async def healthcheck():
    payload = {"status": True}
    return payload
2024-06-18 21:03:31 +08:00
# Health probe that also exercises the database: runs "SELECT 1;" through
# Session and fetches the rows before answering {"status": True}.
# Nothing is caught here, so any error raised by the query propagates out of
# the handler.
# NOTE(review): Session.execute is called without await inside an async
# handler; if Session is a synchronous session this blocks the event loop
# for the duration of the query. Confirm.
@app.get("/health/db")
async def healthcheck_with_db():
    probe = text("SELECT 1;")
    Session.execute(probe).all()
    return {"status": True}
2024-04-09 18:32:28 +08:00
# Static assets and the cache directory are always served.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")

# The built frontend is mounted at "/" only when its build directory exists;
# otherwise the app runs as an API-only server and logs a warning.
if not os.path.exists(FRONTEND_BUILD_DIR):
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )
else:
    # Register .js as text/javascript before the SPA files are mounted.
    mimetypes.add_type("text/javascript", ".js")
    spa_files = SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True)
    app.mount("/", spa_files, name="spa-static-files")