1422 lines
		
	
	
		
			40 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			1422 lines
		
	
	
		
			40 KiB
		
	
	
	
		
			Python
		
	
	
	
| import os
 | |
| import sys
 | |
| import logging
 | |
| import importlib.metadata
 | |
| import pkgutil
 | |
| import chromadb
 | |
| from chromadb import Settings
 | |
| from bs4 import BeautifulSoup
 | |
| from typing import TypeVar, Generic
 | |
| from pydantic import BaseModel
 | |
| from typing import Optional
 | |
| 
 | |
| from pathlib import Path
 | |
| import json
 | |
| import yaml
 | |
| 
 | |
| import markdown
 | |
| import requests
 | |
| import shutil
 | |
| 
 | |
| from constants import ERROR_MESSAGES
 | |
| 
 | |
####################################
# Load .env file
####################################

# This file lives in backend/; the project root is the directory above it.
BACKEND_DIR = Path(__file__).parent
BASE_DIR = BACKEND_DIR.parent

print(BASE_DIR)

try:
    # python-dotenv is optional: without it only the real process environment is used.
    from dotenv import find_dotenv, load_dotenv

    load_dotenv(dotenv_path=find_dotenv(filename=str(BASE_DIR / ".env")))
except ImportError:
    print("dotenv not installed, skipping...")
 | |
| 
 | |
| 
 | |
####################################
# LOGGING
####################################

log_levels = "CRITICAL ERROR WARNING INFO DEBUG".split()

GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL not in log_levels:
    # Unknown or unset: report INFO, but leave the root logger unconfigured.
    GLOBAL_LOG_LEVEL = "INFO"
else:
    logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)

log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")

# Subsystems that accept a dedicated <SOURCE>_LOG_LEVEL override.
log_sources = (
    "AUDIO COMFYUI CONFIG DB IMAGES MAIN MODELS OLLAMA OPENAI RAG WEBHOOK"
).split()


def _resolve_log_level(raw: str, fallback: str) -> str:
    """Return *raw* upper-cased when it names a known level, otherwise *fallback*."""
    level = raw.upper()
    return level if level in log_levels else fallback


SRC_LOG_LEVELS = {}

for source in log_sources:
    log_env_var = f"{source}_LOG_LEVEL"
    SRC_LOG_LEVELS[source] = _resolve_log_level(
        os.environ.get(log_env_var, ""), GLOBAL_LOG_LEVEL
    )
    log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")

log.setLevel(SRC_LOG_LEVELS["CONFIG"])
 | |
| 
 | |
| 
 | |
class EndpointFilter(logging.Filter):
    """Reject log records whose rendered message mentions ``/health``."""

    def filter(self, record: logging.LogRecord) -> bool:
        return "/health" not in record.getMessage()


# Keep uvicorn's access log free of health-check noise.
logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
 | |
| 
 | |
| 
 | |
# Branding: any custom name keeps an "(Open WebUI)" attribution suffix.
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
WEBUI_NAME = (
    WEBUI_NAME if WEBUI_NAME == "Open WebUI" else f"{WEBUI_NAME} (Open WebUI)"
)

WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")

# Default favicon URL; the CUSTOM_NAME lookup further down may replace it.
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
 | |
| 
 | |
| 
 | |
####################################
# ENV (dev,test,prod)
####################################

ENV = os.environ.get("ENV", "dev")

# Version metadata: prefer the repository's package.json (source checkout), then
# the installed "open-webui" distribution, and finally a placeholder version.
try:
    PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
except (OSError, ValueError):
    # OSError: missing/unreadable file. ValueError: invalid JSON or undecodable
    # bytes (JSONDecodeError and UnicodeDecodeError both subclass ValueError).
    try:
        PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
    except importlib.metadata.PackageNotFoundError:
        PACKAGE_DATA = {"version": "0.0.0"}

VERSION = PACKAGE_DATA["version"]
 | |
| 
 | |
| 
 | |
# Function to parse each section
def parse_section(section):
    """Turn the ``<li>`` items of a changelog ``<ul>`` into dicts.

    Each item becomes ``{"title": ..., "content": ..., "raw": ...}``:

    - ``raw`` is the item's HTML.
    - If the item's text contains ``": "``, the part before the first
      occurrence is the title and the rest is the content, both stripped.
    - Otherwise the title is ``""`` and the whole text is the content.

    ``section`` may be ``None`` (a heading with no following list). In that
    case an empty list is returned instead of raising ``AttributeError``.
    """
    if section is None:
        return []

    items = []
    for li in section.find_all("li"):
        # Text without HTML tags; nested elements are joined by a single space.
        text = li.get_text(separator=" ", strip=True)
        head, separator, rest = text.partition(": ")
        if separator:
            title, content = head.strip(), rest.strip()
        else:
            title, content = "", text
        items.append({"title": title, "content": content, "raw": str(li)})
    return items
 | |
| 
 | |
| 
 | |
try:
    changelog_path = BASE_DIR / "CHANGELOG.md"
    with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
        changelog_content = file.read()
except (OSError, UnicodeDecodeError):
    # Not running from a source checkout: fall back to the copy bundled with the
    # installed package. A missing changelog must not abort startup.
    try:
        changelog_content = (
            pkgutil.get_data("open_webui", "CHANGELOG.md") or b""
        ).decode()
    except Exception as e:
        log.warning(f"Unable to load CHANGELOG.md: {e}")
        changelog_content = ""


# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)

# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")

# Maps version number -> {"date": ..., "<section title>": [items, ...]}
changelog_json = {}

# Each "## [x.y.z] - date" heading (an <h2> after rendering) starts a version entry.
for version in soup.find_all("h2"):
    version_parts = version.get_text().strip().split(" - ")

    # Remove the surrounding brackets, if present.
    version_number = version_parts[0]
    if version_number.startswith("[") and version_number.endswith("]"):
        version_number = version_number[1:-1]

    # A heading without a " - date" part (e.g. "## [Unreleased]") gets an empty
    # date instead of raising IndexError at import time.
    date = version_parts[1] if len(version_parts) > 1 else ""

    version_data = {"date": date}

    # Walk the following siblings up to the next version heading. Each <h3>
    # section title is paired with the next <ul> list.
    current = version.find_next_sibling()

    while current and current.name != "h2":
        if current.name == "h3":
            section_title = current.get_text().lower()  # e.g., "added", "fixed"
            section_list = current.find_next_sibling("ul")
            # A section heading with no list after it contributes no items.
            section_items = (
                parse_section(section_list) if section_list is not None else []
            )
            version_data[section_title] = section_items

        # Move to the next element
        current = current.find_next_sibling()

    changelog_json[version_number] = version_data


CHANGELOG = changelog_json
 | |
| 
 | |
| 
 | |
####################################
# SAFE_MODE
####################################

# Enabled only by the (case-insensitive) value "true".
SAFE_MODE = os.getenv("SAFE_MODE", "false").lower() == "true"

####################################
# WEBUI_BUILD_HASH
####################################

WEBUI_BUILD_HASH = os.getenv("WEBUI_BUILD_HASH", "dev-build")
 | |
| 
 | |
####################################
# DATA/FRONTEND BUILD DIR
####################################

DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()

RESET_CONFIG_ON_START = (
    os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
    # Opening with "w" truncates, so this replaces any persisted config with an
    # empty object without a separate delete step.
    try:
        with open(f"{DATA_DIR}/config.json", "w") as f:
            f.write("{}")
    except OSError as e:
        log.warning(f"Failed to reset {DATA_DIR}/config.json: {e}")

try:
    CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except FileNotFoundError:
    # First start: nothing has been persisted yet.
    CONFIG_DATA = {}
except (OSError, ValueError) as e:
    # Unreadable or malformed file. Start empty, but say so: the next
    # save_config() call rewrites this file from CONFIG_DATA.
    log.warning(f"Ignoring unreadable {DATA_DIR / 'config.json'}: {e}")
    CONFIG_DATA = {}

if not isinstance(CONFIG_DATA, dict):
    # The helpers below index CONFIG_DATA by string keys, so the root must be an object.
    log.warning(
        f"Ignoring {DATA_DIR / 'config.json'}: expected a JSON object, "
        f"got {type(CONFIG_DATA).__name__}"
    )
    CONFIG_DATA = {}
 | |
| 
 | |
| 
 | |
####################################
# Config helpers
####################################


def save_config():
    """Write CONFIG_DATA to ``<DATA_DIR>/config.json`` as tab-indented JSON.

    The data is serialized *before* the file is opened. Opening with ``"w"``
    truncates the file, so a value that fails to serialize part-way (e.g. a
    non-JSON-serializable object) would otherwise leave a truncated,
    unparseable config.json behind. Failures are logged, not raised.
    """
    try:
        serialized = json.dumps(CONFIG_DATA, indent="\t")
        with open(f"{DATA_DIR}/config.json", "w") as f:
            f.write(serialized)
    except Exception as e:
        log.exception(e)
 | |
| 
 | |
| 
 | |
def get_config_value(config_path: str, config: Optional[dict] = None):
    """Look up a dotted path such as ``"oauth.google.client_id"``.

    Args:
        config_path: Dot-separated keys, outermost first.
        config: Mapping to search. Defaults to the module-level CONFIG_DATA
            loaded from config.json.

    Returns:
        The stored value, or ``None`` when a segment is missing or an
        intermediate value is not a dict.
    """
    cur_config = CONFIG_DATA if config is None else config
    for key in config_path.split("."):
        # The isinstance guard matters: for a string intermediate, ``key in s``
        # is a substring test and the indexing that follows would raise TypeError.
        if not isinstance(cur_config, dict) or key not in cur_config:
            return None
        cur_config = cur_config[key]
    return cur_config
 | |
| 
 | |
| 
 | |
T = TypeVar("T")

# Message for every attempt to view a PersistentConfig as a plain dict.
_PERSISTENT_CONFIG_NO_DICT = (
    "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)


class PersistentConfig(Generic[T]):
    """A setting read from config.json when stored there, else taken from the environment.

    Attributes:
        env_name: Setting name, used in log messages.
        config_path: Dotted location of the value inside config.json.
        env_value: Value derived from the environment (the fallback).
        config_value: Value last read from / written to config.json, or None.
        value: The effective value.

    Reading ``__dict__`` (and therefore ``vars()``) raises ``TypeError``;
    use ``.value`` instead.
    """

    def __init__(self, env_name: str, config_path: str, env_value: T):
        self.env_name = env_name
        self.config_path = config_path
        self.env_value = env_value
        self.config_value = get_config_value(config_path)
        if self.config_value is None:
            self.value = env_value
        else:
            log.info(f"'{env_name}' loaded from config.json")
            self.value = self.config_value

    def __str__(self):
        return str(self.value)

    @property
    def __dict__(self):
        raise TypeError(_PERSISTENT_CONFIG_NO_DICT)

    def __getattribute__(self, item):
        if item != "__dict__":
            return super().__getattribute__(item)
        raise TypeError(_PERSISTENT_CONFIG_NO_DICT)

    def save(self):
        """Store ``value`` at ``config_path`` in CONFIG_DATA and write config.json.

        Skipped when the value equals both the environment value and the value
        already recorded from config.json.
        """
        if self.env_value == self.value and self.config_value == self.value:
            return

        log.info(f"Saving '{self.env_name}' to config.json")
        *parents, leaf = self.config_path.split(".")
        node = CONFIG_DATA
        for key in parents:
            if key not in node:
                node[key] = {}
            node = node[key]
        node[leaf] = self.value
        save_config()
        self.config_value = self.value
 | |
| 
 | |
| 
 | |
class _UnknownConfigKeyError(AttributeError, KeyError):
    """Raised for names that are not registered on an AppConfig.

    Subclasses AttributeError so that ``hasattr``, ``getattr(obj, name,
    default)`` and copy/pickle protocol probing behave normally. Also
    subclasses KeyError so callers that caught the previous KeyError keep
    working.
    """


class AppConfig:
    """Attribute-style facade over named PersistentConfig entries.

    - Assigning a PersistentConfig registers it under that name.
    - Assigning any other value to a registered name updates the entry's
      ``.value`` and persists it via ``save()``.
    - Reading a registered name returns the entry's ``.value``.
    """

    # String annotation: evaluated lazily, so the class does not need
    # PersistentConfig at definition time.
    _state: "dict[str, PersistentConfig]"

    def __init__(self):
        # Bypass our own __setattr__, which expects _state to exist already.
        super().__setattr__("_state", {})

    def __setattr__(self, key, value):
        if isinstance(value, PersistentConfig):
            self._state[key] = value
            return
        try:
            entry = self._state[key]
        except KeyError:
            raise _UnknownConfigKeyError(key) from None
        entry.value = value
        entry.save()

    def __getattr__(self, key):
        # Only called when normal lookup fails. Read _state straight from the
        # instance dict: on an instance built without __init__ (copy/pickle),
        # going through self._state would re-enter __getattr__ forever.
        state = self.__dict__.get("_state")
        if state is None or key not in state:
            raise _UnknownConfigKeyError(key)
        return state[key].value
 | |
| 
 | |
| 
 | |
####################################
# WEBUI_AUTH (Required for security)
####################################

WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"

# Names of request headers carrying the user's email / name (None when unset).
# NOTE(review): presumably set by an authenticating reverse proxy; confirm in the auth router.
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_EMAIL_HEADER")
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER")

JWT_EXPIRES_IN = PersistentConfig(
    env_name="JWT_EXPIRES_IN",
    config_path="auth.jwt_expiry",
    env_value=os.environ.get("JWT_EXPIRES_IN", "-1"),
)
 | |
| 
 | |
####################################
# OAuth config
####################################

ENABLE_OAUTH_SIGNUP = PersistentConfig(
    "ENABLE_OAUTH_SIGNUP",
    "oauth.enable_signup",
    os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)

OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig(
    "OAUTH_MERGE_ACCOUNTS_BY_EMAIL",
    "oauth.merge_accounts_by_email",
    os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true",
)

# Provider name -> client settings; filled in by load_oauth_providers().
OAUTH_PROVIDERS = {}

# --- Google: registered when both client id and secret are non-empty ---

GOOGLE_CLIENT_ID = PersistentConfig(
    "GOOGLE_CLIENT_ID",
    "oauth.google.client_id",
    os.environ.get("GOOGLE_CLIENT_ID", ""),
)

GOOGLE_CLIENT_SECRET = PersistentConfig(
    "GOOGLE_CLIENT_SECRET",
    "oauth.google.client_secret",
    os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)

GOOGLE_OAUTH_SCOPE = PersistentConfig(
    "GOOGLE_OAUTH_SCOPE",
    "oauth.google.scope",
    os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
)

GOOGLE_REDIRECT_URI = PersistentConfig(
    "GOOGLE_REDIRECT_URI",
    "oauth.google.redirect_uri",
    os.environ.get("GOOGLE_REDIRECT_URI", ""),
)

# --- Microsoft: additionally requires a tenant id ---

MICROSOFT_CLIENT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_ID",
    "oauth.microsoft.client_id",
    os.environ.get("MICROSOFT_CLIENT_ID", ""),
)

MICROSOFT_CLIENT_SECRET = PersistentConfig(
    "MICROSOFT_CLIENT_SECRET",
    "oauth.microsoft.client_secret",
    os.environ.get("MICROSOFT_CLIENT_SECRET", ""),
)

MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_TENANT_ID",
    "oauth.microsoft.tenant_id",
    os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)

MICROSOFT_OAUTH_SCOPE = PersistentConfig(
    "MICROSOFT_OAUTH_SCOPE",
    "oauth.microsoft.scope",
    os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
)

MICROSOFT_REDIRECT_URI = PersistentConfig(
    "MICROSOFT_REDIRECT_URI",
    "oauth.microsoft.redirect_uri",
    os.environ.get("MICROSOFT_REDIRECT_URI", ""),
)

# --- Generic OpenID Connect: requires client id, secret and provider URL ---

OAUTH_CLIENT_ID = PersistentConfig(
    "OAUTH_CLIENT_ID",
    "oauth.oidc.client_id",
    os.environ.get("OAUTH_CLIENT_ID", ""),
)

OAUTH_CLIENT_SECRET = PersistentConfig(
    "OAUTH_CLIENT_SECRET",
    "oauth.oidc.client_secret",
    os.environ.get("OAUTH_CLIENT_SECRET", ""),
)

OPENID_PROVIDER_URL = PersistentConfig(
    "OPENID_PROVIDER_URL",
    "oauth.oidc.provider_url",
    os.environ.get("OPENID_PROVIDER_URL", ""),
)

OPENID_REDIRECT_URI = PersistentConfig(
    "OPENID_REDIRECT_URI",
    "oauth.oidc.redirect_uri",
    os.environ.get("OPENID_REDIRECT_URI", ""),
)

OAUTH_SCOPES = PersistentConfig(
    "OAUTH_SCOPES",
    "oauth.oidc.scopes",
    os.environ.get("OAUTH_SCOPES", "openid email profile"),
)

OAUTH_PROVIDER_NAME = PersistentConfig(
    "OAUTH_PROVIDER_NAME",
    "oauth.oidc.provider_name",
    os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)

OAUTH_USERNAME_CLAIM = PersistentConfig(
    "OAUTH_USERNAME_CLAIM",
    "oauth.oidc.username_claim",
    os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)

OAUTH_PICTURE_CLAIM = PersistentConfig(
    # Previously registered as "OAUTH_USERNAME_CLAIM" (copy-paste slip), which
    # mislabelled this setting in the PersistentConfig log messages.
    "OAUTH_PICTURE_CLAIM",
    "oauth.oidc.avatar_claim",
    os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
 | |
| 
 | |
| 
 | |
def load_oauth_providers():
    """Rebuild OAUTH_PROVIDERS in place from the current OAuth settings.

    The dict is cleared and refilled rather than replaced, so existing
    references to it see the update. A provider is added only when all of
    its required settings are truthy:

    - ``google``: client id and secret.
    - ``microsoft``: client id, secret and tenant id.
    - ``oidc``: client id, secret and provider URL.
    """
    OAUTH_PROVIDERS.clear()

    if all((GOOGLE_CLIENT_ID.value, GOOGLE_CLIENT_SECRET.value)):
        OAUTH_PROVIDERS["google"] = dict(
            client_id=GOOGLE_CLIENT_ID.value,
            client_secret=GOOGLE_CLIENT_SECRET.value,
            server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
            scope=GOOGLE_OAUTH_SCOPE.value,
            redirect_uri=GOOGLE_REDIRECT_URI.value,
        )

    tenant_id = MICROSOFT_CLIENT_TENANT_ID.value
    if all((MICROSOFT_CLIENT_ID.value, MICROSOFT_CLIENT_SECRET.value, tenant_id)):
        OAUTH_PROVIDERS["microsoft"] = dict(
            client_id=MICROSOFT_CLIENT_ID.value,
            client_secret=MICROSOFT_CLIENT_SECRET.value,
            server_metadata_url=f"https://login.microsoftonline.com/{tenant_id}/v2.0/.well-known/openid-configuration",
            scope=MICROSOFT_OAUTH_SCOPE.value,
            redirect_uri=MICROSOFT_REDIRECT_URI.value,
        )

    if all((OAUTH_CLIENT_ID.value, OAUTH_CLIENT_SECRET.value, OPENID_PROVIDER_URL.value)):
        OAUTH_PROVIDERS["oidc"] = dict(
            client_id=OAUTH_CLIENT_ID.value,
            client_secret=OAUTH_CLIENT_SECRET.value,
            server_metadata_url=OPENID_PROVIDER_URL.value,
            scope=OAUTH_SCOPES.value,
            name=OAUTH_PROVIDER_NAME.value,
            redirect_uri=OPENID_REDIRECT_URI.value,
        )


load_oauth_providers()
 | |
| 
 | |
####################################
# Static DIR
####################################

STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()


def _copy_frontend_asset(source: Path, label: str) -> None:
    """Copy a frontend build asset into STATIC_DIR under its own file name.

    A missing source is logged as a warning. A failed copy is logged as an
    error and does not abort startup. Both go through the module logger,
    which honours CONFIG_LOG_LEVEL.
    """
    if not source.exists():
        log.warning(f"Frontend {label} not found at {source}")
        return
    try:
        shutil.copyfile(source, STATIC_DIR / source.name)
    except OSError as e:
        log.error(f"Failed to copy frontend {label} from {source} to {STATIC_DIR}: {e}")


frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
_copy_frontend_asset(frontend_favicon, "favicon")

frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
_copy_frontend_asset(frontend_splash, "splash")
 | |
| 
 | |
| 
 | |
####################################
# CUSTOM_NAME
####################################

CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")

# Timeout passed to every requests call below. Without one, requests can wait
# indefinitely, blocking this module's import and therefore startup.
_CUSTOM_NAME_REQUEST_TIMEOUT = 10


def _custom_asset_url(path: str) -> str:
    """Resolve an asset reference from the branding API into an absolute URL."""
    return f"https://api.openwebui.com{path}" if path.startswith("/") else path


def _download_custom_asset(url: str, filename: str) -> None:
    """Stream *url* into ``STATIC_DIR/filename`` when the server answers 200."""
    # The context manager releases the streamed connection even on failure.
    with requests.get(url, stream=True, timeout=_CUSTOM_NAME_REQUEST_TIMEOUT) as response:
        if response.status_code == 200:
            with open(f"{STATIC_DIR}/{filename}", "wb") as f:
                # Ask the raw stream to decode any Content-Encoding while copying.
                response.raw.decode_content = True
                shutil.copyfileobj(response.raw, f)


if CUSTOM_NAME:
    try:
        r = requests.get(
            f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}",
            timeout=_CUSTOM_NAME_REQUEST_TIMEOUT,
        )
        data = r.json()
        if r.ok:
            if "logo" in data:
                WEBUI_FAVICON_URL = url = _custom_asset_url(data["logo"])
                _download_custom_asset(url, "favicon.png")

            if "splash" in data:
                _download_custom_asset(_custom_asset_url(data["splash"]), "splash.png")

            WEBUI_NAME = data["name"]
    except Exception as e:
        # Branding is best-effort: on any failure keep the defaults and log why.
        log.exception(e)
 | |
| 
 | |
| 
 | |
def _ensure_dir(path: str) -> str:
    """Create *path* (with any missing parents) if needed and return it unchanged."""
    Path(path).mkdir(parents=True, exist_ok=True)
    return path


####################################
# File Upload DIR
####################################

UPLOAD_DIR = _ensure_dir(f"{DATA_DIR}/uploads")

####################################
# Cache DIR
####################################

CACHE_DIR = _ensure_dir(f"{DATA_DIR}/cache")

####################################
# Docs DIR
####################################

DOCS_DIR = _ensure_dir(os.getenv("DOCS_DIR", f"{DATA_DIR}/docs"))

####################################
# Tools DIR
####################################

TOOLS_DIR = _ensure_dir(os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools"))

####################################
# Functions DIR
####################################

FUNCTIONS_DIR = _ensure_dir(os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions"))
 | |
| 
 | |
| 
 | |
####################################
# LITELLM_CONFIG
####################################


def create_config_file(file_path):
    """Write an empty LiteLLM YAML skeleton to *file_path*.

    The skeleton has empty ``general_settings``, ``litellm_settings``,
    ``model_list`` and ``router_settings`` entries. Missing parent
    directories are created first.
    """
    directory = os.path.dirname(file_path)

    # A bare file name has no directory part; os.makedirs("") would raise
    # FileNotFoundError. exist_ok=True also avoids a race where the directory
    # appears between an existence check and the create.
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Data to write into the YAML file
    config_data = {
        "general_settings": {},
        "litellm_settings": {},
        "model_list": [],
        "router_settings": {},
    }

    # Write data to YAML file
    with open(file_path, "w") as file:
        yaml.dump(config_data, file)


LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"

# if not os.path.exists(LITELLM_CONFIG_PATH):
#     log.info("Config file doesn't exist. Creating...")
#     create_config_file(LITELLM_CONFIG_PATH)
#     log.info("Config file created successfully.")
 | |
| 
 | |
| 
 | |
####################################
# OLLAMA_BASE_URL
####################################


ENABLE_OLLAMA_API = PersistentConfig(
    "ENABLE_OLLAMA_API",
    "ollama.enable",
    os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true",
)

OLLAMA_API_BASE_URL = os.environ.get(
    "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)

OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")

# Handling of the raw value:
# - unset/empty -> None (presumably "no timeout"; confirm at the aiohttp call site);
# - not an integer -> 300, with a warning.
if AIOHTTP_CLIENT_TIMEOUT == "":
    AIOHTTP_CLIENT_TIMEOUT = None
else:
    try:
        AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
    except ValueError:
        log.warning(
            f"Invalid AIOHTTP_CLIENT_TIMEOUT {AIOHTTP_CLIENT_TIMEOUT!r}, falling back to 300"
        )
        AIOHTTP_CLIENT_TIMEOUT = 300


K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")

# Without an explicit OLLAMA_BASE_URL, derive it from OLLAMA_API_BASE_URL
# minus its trailing "/api".
if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
    OLLAMA_BASE_URL = OLLAMA_API_BASE_URL.removesuffix("/api")

# In production, the relative "/ollama" placeholder is replaced by a concrete
# address. When K8S_FLAG is set, the Kubernetes service address is always used.
if ENV == "prod":
    if OLLAMA_BASE_URL == "/ollama" and not K8S_FLAG:
        if USE_OLLAMA_DOCKER.lower() == "true":
            # if you use all-in-one docker container (Open WebUI + Ollama)
            # with the docker build arg USE_OLLAMA=true (--build-arg="USE_OLLAMA=true") this only works with http://localhost:11434
            OLLAMA_BASE_URL = "http://localhost:11434"
        else:
            OLLAMA_BASE_URL = "http://host.docker.internal:11434"
    elif K8S_FLAG:
        OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"


# Semicolon-separated list; defaults to the single OLLAMA_BASE_URL.
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL

OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = PersistentConfig(
    "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
 | |
| 
 | |
####################################
# OPENAI_API
####################################


ENABLE_OPENAI_API = PersistentConfig(
    "ENABLE_OPENAI_API",
    "openai.enable",
    os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true",
)


OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")


if OPENAI_API_BASE_URL == "":
    OPENAI_API_BASE_URL = "https://api.openai.com/v1"

# Semicolon-separated keys, positionally paired with OPENAI_API_BASE_URLS below;
# defaults to the single OPENAI_API_KEY.
OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY

OPENAI_API_KEYS = [key.strip() for key in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = PersistentConfig(
    "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
)

OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
    OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)

# Blank entries, including whitespace-only ones (previously kept as ""),
# fall back to the official endpoint.
OPENAI_API_BASE_URLS = [
    url.strip() or "https://api.openai.com/v1"
    for url in OPENAI_API_BASE_URLS.split(";")
]
OPENAI_API_BASE_URLS = PersistentConfig(
    "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)

# The key paired (by position) with the official OpenAI endpoint, if configured.
OPENAI_API_KEY = ""

try:
    OPENAI_API_KEY = OPENAI_API_KEYS.value[
        OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
    ]
except (ValueError, IndexError, TypeError, AttributeError):
    # Possible failures:
    # - ValueError: the official endpoint is not in the list;
    # - IndexError: fewer keys than URLs;
    # - TypeError / AttributeError: malformed values loaded from config.json.
    pass

OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 | |
| 
 | |
####################################
# WEBUI
####################################

# Self-service sign-up is always off when authentication is disabled.
ENABLE_SIGNUP = PersistentConfig(
    env_name="ENABLE_SIGNUP",
    config_path="ui.enable_signup",
    env_value=WEBUI_AUTH and os.environ.get("ENABLE_SIGNUP", "True").lower() == "true",
)

# NOTE: unlike its siblings this config path is upper-case. Changing it would
# orphan values already stored in config.json.
ENABLE_LOGIN_FORM = PersistentConfig(
    env_name="ENABLE_LOGIN_FORM",
    config_path="ui.ENABLE_LOGIN_FORM",
    env_value=os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
)

DEFAULT_LOCALE = PersistentConfig(
    env_name="DEFAULT_LOCALE",
    config_path="ui.default_locale",
    env_value=os.environ.get("DEFAULT_LOCALE", ""),
)

DEFAULT_MODELS = PersistentConfig(
    env_name="DEFAULT_MODELS",
    config_path="ui.default_models",
    env_value=os.environ.get("DEFAULT_MODELS"),
)

# Starter prompts. Each "title" is a two-string list, "content" the full prompt.
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
    env_name="DEFAULT_PROMPT_SUGGESTIONS",
    config_path="ui.prompt_suggestions",
    env_value=[
        {
            "title": ["Help me study", "vocabulary for a college entrance exam"],
            "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
        },
        {
            "title": ["Give me ideas", "for what to do with my kids' art"],
            "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
        },
        {
            "title": ["Tell me a fun fact", "about the Roman Empire"],
            "content": "Tell me a random fun fact about the Roman Empire",
        },
        {
            "title": ["Show me a code snippet", "of a website's sticky header"],
            "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
        },
        {
            "title": [
                "Explain options trading",
                "if I'm familiar with buying and selling stocks",
            ],
            "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
        },
        {
            "title": ["Overcome procrastination", "give me tips"],
            "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
        },
    ],
)

DEFAULT_USER_ROLE = PersistentConfig(
    env_name="DEFAULT_USER_ROLE",
    config_path="ui.default_user_role",
    env_value=os.getenv("DEFAULT_USER_ROLE", "pending"),
)

USER_PERMISSIONS_CHAT_DELETION = (
    os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)

USER_PERMISSIONS = PersistentConfig(
    env_name="USER_PERMISSIONS",
    config_path="ui.user_permissions",
    env_value={"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
)

ENABLE_MODEL_FILTER = PersistentConfig(
    env_name="ENABLE_MODEL_FILTER",
    config_path="model_filter.enable",
    env_value=os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)

# Semicolon-separated entries. The name is first bound to the raw string,
# then rebound to its PersistentConfig.
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = PersistentConfig(
    env_name="MODEL_FILTER_LIST",
    config_path="model_filter.list",
    env_value=list(map(str.strip, MODEL_FILTER_LIST.split(";"))),
)

WEBHOOK_URL = PersistentConfig(
    env_name="WEBHOOK_URL",
    config_path="webhook_url",
    env_value=os.environ.get("WEBHOOK_URL", ""),
)

ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"

ENABLE_COMMUNITY_SHARING = PersistentConfig(
    env_name="ENABLE_COMMUNITY_SHARING",
    config_path="ui.enable_community_sharing",
    env_value=os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
)
 | |
| 
 | |
class BannerModel(BaseModel):
    """Schema for one UI banner as supplied through WEBUI_BANNERS."""

    id: str
    type: str
    title: Optional[str] = None
    content: str
    dismissible: bool
    timestamp: int


# WEBUI_BANNERS holds a JSON array of BannerModel-shaped objects. Invalid input
# is logged and ignored so a bad banner cannot prevent startup. (The original
# print() bypassed the configured logging.)
try:
    banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
    banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
    log.error(f"Error loading WEBUI_BANNERS: {e}")
    banners = []

# NOTE(review): banners loaded from config.json are plain dicts, while
# env-provided ones are BannerModel instances. Calling save() with BannerModel
# values would fail JSON serialization in save_config; confirm that callers
# store dicts.
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)


SHOW_ADMIN_DETAILS = PersistentConfig(
    "SHOW_ADMIN_DETAILS",
    "auth.admin.show",
    os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)

ADMIN_EMAIL = PersistentConfig(
    "ADMIN_EMAIL",
    "auth.admin.email",
    os.environ.get("ADMIN_EMAIL", None),
)
 | |
| 
 | |
| 
 | |
####################################
# TASKS
####################################


TASK_MODEL = PersistentConfig(
    env_name="TASK_MODEL",
    config_path="task.model.default",
    env_value=os.environ.get("TASK_MODEL", ""),
)

TASK_MODEL_EXTERNAL = PersistentConfig(
    env_name="TASK_MODEL_EXTERNAL",
    config_path="task.model.external",
    env_value=os.environ.get("TASK_MODEL_EXTERNAL", ""),
)

# Default prompt templates (each overridable by the env var of the same name).
# NOTE(review): the {{...}} placeholders are presumably substituted by the task
# code before use; that substitution is not visible in this module.
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    env_name="TITLE_GENERATION_PROMPT_TEMPLATE",
    config_path="task.title.prompt_template",
    env_value=os.environ.get(
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        """Here is the query:
{{prompt:middletruncate:8000}}

Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.

Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
    ),
)


SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    env_name="SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
    config_path="task.search.prompt_template",
    env_value=os.environ.get(
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
        """You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
        
Question:
{{prompt:end:4000}}""",
    ),
)

SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
    env_name="SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
    config_path="task.search.prompt_length_threshold",
    env_value=int(os.environ.get("SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", "100")),
)

TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
    env_name="TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    config_path="task.tools.prompt_template",
    env_value=os.environ.get(
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
        """Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks.  Only return the object. Do not return any other text.""",
    ),
)
 | |
| 
 | |
| 
 | |
####################################
# WEBUI_SECRET_KEY
####################################

# Built-in fallback secret. It is public (it ships with the source), so it must
# never be relied on outside local development.
_INSECURE_DEFAULT_SECRET_KEY = "t0p-s3cr3t"

# Secret key for the web UI. WEBUI_SECRET_KEY takes precedence over the legacy
# WEBUI_JWT_SECRET_KEY; when neither is set the insecure fallback is used.
WEBUI_SECRET_KEY = os.environ.get(
    "WEBUI_SECRET_KEY",
    os.environ.get(
        "WEBUI_JWT_SECRET_KEY", _INSECURE_DEFAULT_SECRET_KEY
    ),  # DEPRECATED: remove at next major version
)

if WEBUI_SECRET_KEY == _INSECURE_DEFAULT_SECRET_KEY:
    # Warn rather than fail so local setups still start. If this key signs
    # tokens (the legacy WEBUI_JWT_SECRET_KEY name suggests JWTs — confirm),
    # anyone who knows the public default could forge them.
    log.warning(
        "WEBUI_SECRET_KEY is not set; using a built-in insecure default. "
        "Set WEBUI_SECRET_KEY to a random secret in production."
    )
 | |
| 
 | |
# SameSite attribute for the session cookie; defaults to "lax".
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
    "WEBUI_SESSION_COOKIE_SAME_SITE", "lax"
)

# Secure flag for the session cookie, always parsed to a bool. Returning the
# raw env string here would make values such as "false" truthy and
# accidentally enable the flag.
WEBUI_SESSION_COOKIE_SECURE = (
    os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true"
)
 | |
| 
 | |
# With authentication enabled, an empty secret key is a fatal config error.
if WEBUI_AUTH and not WEBUI_SECRET_KEY:
    raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
 | |
| 
 | |
####################################
# RAG document content extraction
####################################

# Extraction backend name, lower-cased; "" when unset.
# NOTE(review): the upper-case config path differs from its siblings but is
# kept as-is so that persisted values remain readable.
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
    "CONTENT_EXTRACTION_ENGINE",
    "rag.CONTENT_EXTRACTION_ENGINE",
    os.getenv("CONTENT_EXTRACTION_ENGINE", "").lower(),
)

# Tika server URL; the default targets a sidecar deployment reachable as "tika".
TIKA_SERVER_URL = PersistentConfig(
    "TIKA_SERVER_URL",
    "rag.tika_server_url",
    os.environ.get("TIKA_SERVER_URL", "http://tika:9998"),
)
 | |
| 
 | |
####################################
# RAG
####################################

# Chroma vector store: on-disk location plus tenant/database selection
# (defaulting to chromadb's own defaults) and the optional HTTP server address.
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.getenv("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.getenv("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.getenv("CHROMA_HTTP_HOST", "")
CHROMA_HTTP_PORT = int(os.getenv("CHROMA_HTTP_PORT", "8000"))
 | |
def _parse_chroma_http_headers(raw: str) -> Optional[dict]:
    """Parse a comma-separated ``header=value`` list into a dict.

    Only the first ``=`` in each entry separates the name from the value, so
    values may themselves contain ``=`` (e.g. base64 padding in auth tokens).
    Surrounding whitespace is stripped and empty entries (such as the one left
    by a trailing comma) are ignored.

    Args:
        raw: The raw CHROMA_HTTP_HEADERS value.

    Returns:
        A header-name -> value dict, or None when no headers are configured.

    Raises:
        ValueError: If a non-empty entry contains no ``=``.
    """
    headers = {}
    for entry in raw.split(","):
        entry = entry.strip()
        if not entry:
            continue
        name, sep, value = entry.partition("=")
        if not sep:
            # Deliberately not echoing the entry: it may contain a credential.
            raise ValueError(
                "Invalid CHROMA_HTTP_HEADERS entry: expected comma-separated "
                "'header=value' pairs"
            )
        headers[name.strip()] = value.strip()
    return headers or None


# Comma-separated list of header=value pairs passed to the Chroma HttpClient.
CHROMA_HTTP_HEADERS = _parse_chroma_http_headers(
    os.environ.get("CHROMA_HTTP_HEADERS", "")
)
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
 | |
# RAG_EMBEDDING_MODEL (below) uses the model named by the Dockerfile's ENV variable. If you don't use Docker or a Docker-based deployment such as k8s, the default embedding model (sentence-transformers/all-MiniLM-L6-v2) is used.
 | |
| 
 | |
# Retrieval tuning: number of results (int) and minimum relevance (float).
RAG_TOP_K = PersistentConfig(
    "RAG_TOP_K",
    "rag.top_k",
    int(os.getenv("RAG_TOP_K", "5")),
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
    "RAG_RELEVANCE_THRESHOLD",
    "rag.relevance_threshold",
    float(os.getenv("RAG_RELEVANCE_THRESHOLD", "0.0")),
)


# Boolean switches: enabled only by the (case-insensitive) string "true".
# Hybrid search is off by default.
ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
    "ENABLE_RAG_HYBRID_SEARCH",
    "rag.enable_hybrid_search",
    os.getenv("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)


# SSL verification for the web loader is on by default.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
    "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
    "rag.enable_web_loader_ssl_verification",
    os.getenv("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)
 | |
| 
 | |
# Embedding backend name; "" when unset.
RAG_EMBEDDING_ENGINE = PersistentConfig(
    "RAG_EMBEDDING_ENGINE",
    "rag.embedding_engine",
    os.environ.get("RAG_EMBEDDING_ENGINE", ""),
)


# Whether images are extracted from PDFs during ingestion (off by default).
PDF_EXTRACT_IMAGES = PersistentConfig(
    "PDF_EXTRACT_IMAGES",
    "rag.pdf_extract_images",
    os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)


RAG_EMBEDDING_MODEL = PersistentConfig(
    "RAG_EMBEDDING_MODEL",
    "rag.embedding_model",
    os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")


# Model maintenance flags; enabled only by the (case-insensitive) string "true".
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)


# Batch size (presumably for OpenAI embedding requests), always parsed to an
# int. Leaving it as os.environ.get(..., 1) made it a str when the env var was
# set but the int 1 otherwise.
RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
    "rag.embedding_openai_batch_size",
    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
)
 | |
| 
 | |
# Optional reranking model; "" (the default) means none is configured.
RAG_RERANKING_MODEL = PersistentConfig(
    "RAG_RERANKING_MODEL",
    "rag.reranking_model",
    os.getenv("RAG_RERANKING_MODEL", ""),
)
# Compared with "" explicitly: .value may be a non-str loaded from persisted config.
if RAG_RERANKING_MODEL.value != "":
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")


# Reranking-model maintenance flags; only "true" (case-insensitive) enables them.
RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.getenv("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.getenv("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
 | |
| 
 | |
| 
 | |
# Choose the Chroma client: an embedded on-disk store at CHROMA_DATA_PATH when
# no HTTP host is configured, otherwise a remote HTTP client. Both allow reset
# and disable anonymized telemetry.
if not CHROMA_HTTP_HOST:
    CHROMA_CLIENT = chromadb.PersistentClient(
        path=CHROMA_DATA_PATH,
        settings=Settings(allow_reset=True, anonymized_telemetry=False),
        tenant=CHROMA_TENANT,
        database=CHROMA_DATABASE,
    )
else:
    CHROMA_CLIENT = chromadb.HttpClient(
        host=CHROMA_HTTP_HOST,
        port=CHROMA_HTTP_PORT,
        headers=CHROMA_HTTP_HEADERS,
        ssl=CHROMA_HTTP_SSL,
        tenant=CHROMA_TENANT,
        database=CHROMA_DATABASE,
        settings=Settings(allow_reset=True, anonymized_telemetry=False),
    )
 | |
| 
 | |
| 
 | |
# Device for local embedding models. This block only selects "cuda" (when
# USE_CUDA_DOCKER is "true", case-insensitive; NVIDIA GPU required) or "cpu"
# (default). Choosing correctly can noticeably affect performance.
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
DEVICE_TYPE = "cuda" if USE_CUDA.lower() == "true" else "cpu"
 | |
| 
 | |
# Document splitting parameters, both parsed as ints (defaults 1500 / 100).
CHUNK_SIZE = PersistentConfig(
    "CHUNK_SIZE",
    "rag.chunk_size",
    int(os.getenv("CHUNK_SIZE", "1500")),
)
CHUNK_OVERLAP = PersistentConfig(
    "CHUNK_OVERLAP",
    "rag.chunk_overlap",
    int(os.getenv("CHUNK_OVERLAP", "100")),
)
 | |
| 
 | |
# Default RAG prompt. "[context]" and "[query]" are placeholders, presumably
# substituted by the retrieval pipeline — confirm against the RAG utils.
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
    [context]
</context>

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.

Given the context information, answer the query.
Query: [query]"""

# Overridable through the RAG_TEMPLATE environment variable.
RAG_TEMPLATE = PersistentConfig(
    "RAG_TEMPLATE",
    "rag.template",
    os.getenv("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
)
 | |
| 
 | |
# OpenAI-compatible endpoint/key for RAG; fall back to the global OpenAI values.
RAG_OPENAI_API_BASE_URL = PersistentConfig(
    "RAG_OPENAI_API_BASE_URL",
    "rag.openai_api_base_url",
    os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = PersistentConfig(
    "RAG_OPENAI_API_KEY",
    "rag.openai_api_key",
    os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)


# Off by default; only "true" (case-insensitive) enables it.
ENABLE_RAG_LOCAL_WEB_FETCH = (
    os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)


# Comma-separated language codes for the YouTube loader, e.g. "en,de".
# Entries are stripped so "en, de" yields ["en", "de"] rather than
# ["en", " de"]; empty entries are dropped, and a value with no codes at all
# falls back to ["en"].
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
    "YOUTUBE_LOADER_LANGUAGE",
    "rag.youtube_loader_language",
    [
        lang.strip()
        for lang in os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",")
        if lang.strip()
    ]
    or ["en"],
)
 | |
| 
 | |
| 
 | |
# Web search switch (off by default) and engine name ("" when unset).
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
    "ENABLE_RAG_WEB_SEARCH",
    "rag.web.search.enable",
    os.environ.get("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)

RAG_WEB_SEARCH_ENGINE = PersistentConfig(
    "RAG_WEB_SEARCH_ENGINE",
    "rag.web.search.engine",
    os.environ.get("RAG_WEB_SEARCH_ENGINE", ""),
)


# Optional list of your own trusted domains used to filter web search
# results, limiting answers to sources you trust. Empty by default and not
# read from the environment. Example entries:
#   "wikipedia.com", "wikimedia.org", "wikidata.org"
# NOTE(review): the config path has a doubled "rag." prefix unlike its
# siblings; presumably a typo, but renaming it would orphan persisted values.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
    "RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
    "rag.rag.web.search.domain.filter_list",
    [],
)
 | |
| 
 | |
# Endpoints and credentials for the web search providers. Each one is read from
# the environment variable of the same name and defaults to "".
SEARXNG_QUERY_URL = PersistentConfig(
    "SEARXNG_QUERY_URL",
    "rag.web.search.searxng_query_url",
    os.environ.get("SEARXNG_QUERY_URL", ""),
)

GOOGLE_PSE_API_KEY = PersistentConfig(
    "GOOGLE_PSE_API_KEY",
    "rag.web.search.google_pse_api_key",
    os.environ.get("GOOGLE_PSE_API_KEY", ""),
)

GOOGLE_PSE_ENGINE_ID = PersistentConfig(
    "GOOGLE_PSE_ENGINE_ID",
    "rag.web.search.google_pse_engine_id",
    os.environ.get("GOOGLE_PSE_ENGINE_ID", ""),
)

BRAVE_SEARCH_API_KEY = PersistentConfig(
    "BRAVE_SEARCH_API_KEY",
    "rag.web.search.brave_search_api_key",
    os.environ.get("BRAVE_SEARCH_API_KEY", ""),
)

SERPSTACK_API_KEY = PersistentConfig(
    "SERPSTACK_API_KEY",
    "rag.web.search.serpstack_api_key",
    os.environ.get("SERPSTACK_API_KEY", ""),
)

# Serpstack HTTPS toggle: on unless SERPSTACK_HTTPS is set to something other
# than "true" (case-insensitive).
SERPSTACK_HTTPS = PersistentConfig(
    "SERPSTACK_HTTPS",
    "rag.web.search.serpstack_https",
    os.environ.get("SERPSTACK_HTTPS", "True").lower() == "true",
)

SERPER_API_KEY = PersistentConfig(
    "SERPER_API_KEY",
    "rag.web.search.serper_api_key",
    os.environ.get("SERPER_API_KEY", ""),
)

SERPLY_API_KEY = PersistentConfig(
    "SERPLY_API_KEY",
    "rag.web.search.serply_api_key",
    os.environ.get("SERPLY_API_KEY", ""),
)

TAVILY_API_KEY = PersistentConfig(
    "TAVILY_API_KEY",
    "rag.web.search.tavily_api_key",
    os.environ.get("TAVILY_API_KEY", ""),
)


# Integer tuning knobs for web search (defaults 3 results / 10 requests).
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
    "RAG_WEB_SEARCH_RESULT_COUNT",
    "rag.web.search.result_count",
    int(os.environ.get("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)

RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
    "RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
    "rag.web.search.concurrent_requests",
    int(os.environ.get("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
 | |
| 
 | |
| 
 | |
####################################
# Transcribe
####################################

# Whisper model name, its download directory (under CACHE_DIR by default), and
# whether to auto-update it (only "true", case-insensitive, enables this).
WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.environ.get(
    "WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
)
WHISPER_MODEL_AUTO_UPDATE = (
    os.getenv("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
 | |
| 
 | |
| 
 | |
####################################
# Images
####################################

# Image-generation backend ("" when unset) and feature switch (off by default).
IMAGE_GENERATION_ENGINE = PersistentConfig(
    "IMAGE_GENERATION_ENGINE",
    "image_generation.engine",
    os.environ.get("IMAGE_GENERATION_ENGINE", ""),
)

ENABLE_IMAGE_GENERATION = PersistentConfig(
    "ENABLE_IMAGE_GENERATION",
    "image_generation.enable",
    os.getenv("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)

# AUTOMATIC1111 connection settings; both default to "".
AUTOMATIC1111_BASE_URL = PersistentConfig(
    "AUTOMATIC1111_BASE_URL",
    "image_generation.automatic1111.base_url",
    os.environ.get("AUTOMATIC1111_BASE_URL", ""),
)
AUTOMATIC1111_API_AUTH = PersistentConfig(
    "AUTOMATIC1111_API_AUTH",
    "image_generation.automatic1111.api_auth",
    os.environ.get("AUTOMATIC1111_API_AUTH", ""),
)
 | |
| 
 | |
# ComfyUI settings. The types are mixed: COMFYUI_SD3 and COMFYUI_FLUX are
# parsed to bools, while every other value here (including COMFYUI_CFG_SCALE
# and COMFYUI_FLUX_FP8_CLIP) is kept as the raw string, "" when unset.
COMFYUI_BASE_URL = PersistentConfig(
    "COMFYUI_BASE_URL",
    "image_generation.comfyui.base_url",
    os.environ.get("COMFYUI_BASE_URL", ""),
)

COMFYUI_CFG_SCALE = PersistentConfig(
    "COMFYUI_CFG_SCALE",
    "image_generation.comfyui.cfg_scale",
    os.environ.get("COMFYUI_CFG_SCALE", ""),
)

COMFYUI_SAMPLER = PersistentConfig(
    "COMFYUI_SAMPLER",
    "image_generation.comfyui.sampler",
    os.environ.get("COMFYUI_SAMPLER", ""),
)

COMFYUI_SCHEDULER = PersistentConfig(
    "COMFYUI_SCHEDULER",
    "image_generation.comfyui.scheduler",
    os.environ.get("COMFYUI_SCHEDULER", ""),
)

COMFYUI_SD3 = PersistentConfig(
    "COMFYUI_SD3",
    "image_generation.comfyui.sd3",
    os.getenv("COMFYUI_SD3", "").lower() == "true",
)

COMFYUI_FLUX = PersistentConfig(
    "COMFYUI_FLUX",
    "image_generation.comfyui.flux",
    os.getenv("COMFYUI_FLUX", "").lower() == "true",
)

COMFYUI_FLUX_WEIGHT_DTYPE = PersistentConfig(
    "COMFYUI_FLUX_WEIGHT_DTYPE",
    "image_generation.comfyui.flux_weight_dtype",
    os.environ.get("COMFYUI_FLUX_WEIGHT_DTYPE", ""),
)

COMFYUI_FLUX_FP8_CLIP = PersistentConfig(
    "COMFYUI_FLUX_FP8_CLIP",
    "image_generation.comfyui.flux_fp8_clip",
    os.environ.get("COMFYUI_FLUX_FP8_CLIP", ""),
)
 | |
| 
 | |
# OpenAI-compatible image endpoint/key; fall back to the global OpenAI values.
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
    "IMAGES_OPENAI_API_BASE_URL",
    "image_generation.openai.api_base_url",
    os.environ.get("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = PersistentConfig(
    "IMAGES_OPENAI_API_KEY",
    "image_generation.openai.api_key",
    os.environ.get("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)

# Output size as a "WIDTHxHEIGHT" string (default "512x512").
IMAGE_SIZE = PersistentConfig(
    "IMAGE_SIZE",
    "image_generation.size",
    os.environ.get("IMAGE_SIZE", "512x512"),
)

# Step count, parsed as an int (default 50).
IMAGE_STEPS = PersistentConfig(
    "IMAGE_STEPS",
    "image_generation.steps",
    int(os.environ.get("IMAGE_STEPS", 50)),
)

# Image model identifier; "" when unset.
IMAGE_GENERATION_MODEL = PersistentConfig(
    "IMAGE_GENERATION_MODEL",
    "image_generation.model",
    os.environ.get("IMAGE_GENERATION_MODEL", ""),
)
 | |
| 
 | |
####################################
# Audio
####################################

# Speech-to-text: OpenAI-compatible endpoint/key (falling back to the global
# OpenAI values), engine name ("" when unset) and model (default "whisper-1").
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
    "audio.stt.openai.api_base_url",
    os.environ.get("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)

AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_STT_OPENAI_API_KEY",
    "audio.stt.openai.api_key",
    os.environ.get("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)

AUDIO_STT_ENGINE = PersistentConfig(
    "AUDIO_STT_ENGINE",
    "audio.stt.engine",
    os.environ.get("AUDIO_STT_ENGINE", ""),
)

AUDIO_STT_MODEL = PersistentConfig(
    "AUDIO_STT_MODEL",
    "audio.stt.model",
    os.environ.get("AUDIO_STT_MODEL", "whisper-1"),
)
 | |
| 
 | |
# Text-to-speech: OpenAI-compatible endpoint/key (falling back to the global
# OpenAI values), a separate generic API key, the engine name, and the
# model/voice defaults ("tts-1" / "alloy").
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_BASE_URL",
    "audio.tts.openai.api_base_url",
    os.environ.get("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_KEY",
    "audio.tts.openai.api_key",
    os.environ.get("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)

AUDIO_TTS_API_KEY = PersistentConfig(
    "AUDIO_TTS_API_KEY",
    "audio.tts.api_key",
    os.environ.get("AUDIO_TTS_API_KEY", ""),
)

AUDIO_TTS_ENGINE = PersistentConfig(
    "AUDIO_TTS_ENGINE",
    "audio.tts.engine",
    os.environ.get("AUDIO_TTS_ENGINE", ""),
)

AUDIO_TTS_MODEL = PersistentConfig(
    "AUDIO_TTS_MODEL",
    "audio.tts.model",
    os.environ.get("AUDIO_TTS_MODEL", "tts-1"),
)

AUDIO_TTS_VOICE = PersistentConfig(
    "AUDIO_TTS_VOICE",
    "audio.tts.voice",
    os.environ.get("AUDIO_TTS_VOICE", "alloy"),
)
 | |
| 
 | |
| 
 | |
####################################
# Database
####################################

# Database connection URL; defaults to a SQLite file inside DATA_DIR.
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")

# Normalise the legacy "postgres://" scheme to "postgresql://". Presumably the
# URL is consumed by SQLAlchemy, which (since 1.4) rejects "postgres://".
# Only the scheme prefix is rewritten, so credentials or query parameters that
# happen to contain the text "postgres://" are left untouched.
if DATABASE_URL.startswith("postgres://"):
    DATABASE_URL = "postgresql://" + DATABASE_URL[len("postgres://"):]
 |