1518 lines
		
	
	
		
			42 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			1518 lines
		
	
	
		
			42 KiB
		
	
	
	
		
			Python
		
	
	
	
| import json
 | |
| import logging
 | |
| import os
 | |
| import shutil
 | |
| from datetime import datetime
 | |
| from pathlib import Path
 | |
| from typing import Generic, Optional, TypeVar
 | |
| from urllib.parse import urlparse
 | |
| 
 | |
| import chromadb
 | |
| import requests
 | |
| import yaml
 | |
| from open_webui.apps.webui.internal.db import Base, get_db
 | |
| from open_webui.env import (
 | |
|     OPEN_WEBUI_DIR,
 | |
|     DATA_DIR,
 | |
|     ENV,
 | |
|     FRONTEND_BUILD_DIR,
 | |
|     WEBUI_AUTH,
 | |
|     WEBUI_FAVICON_URL,
 | |
|     WEBUI_NAME,
 | |
|     log,
 | |
| )
 | |
| from pydantic import BaseModel
 | |
| from sqlalchemy import JSON, Column, DateTime, Integer, func
 | |
| 
 | |
| 
 | |
class EndpointFilter(logging.Filter):
    """Logging filter that drops records whose message mentions ``/health``."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Decide whether ``record`` should be emitted.

        Args:
            record: The log record under consideration.

        Returns:
            ``False`` when the formatted message contains ``"/health"``
            (the record is suppressed), ``True`` otherwise.
        """
        return "/health" not in record.getMessage()
 | |
| 
 | |
| 
 | |
# Suppress uvicorn access-log records that mention "/health" (see EndpointFilter)
logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
 | |
| 
 | |
| ####################################
 | |
| # Config helpers
 | |
| ####################################
 | |
| 
 | |
| 
 | |
# Function to run the alembic migrations
def run_migrations():
    """Upgrade the database schema to the alembic ``head`` revision.

    Any failure is logged together with its traceback and then swallowed, so
    that calling this function never raises.
    """
    log.info("Running migrations")
    try:
        # Imported lazily, and aliased so it is not mistaken for the
        # SQLAlchemy ``Config`` model declared in this module.
        from alembic import command
        from alembic.config import Config as AlembicConfig

        alembic_cfg = AlembicConfig(OPEN_WEBUI_DIR / "alembic.ini")

        # Set the script location dynamically
        migrations_path = OPEN_WEBUI_DIR / "migrations"
        alembic_cfg.set_main_option("script_location", str(migrations_path))

        command.upgrade(alembic_cfg, "head")
    except Exception as e:
        # log.exception keeps the traceback that a bare print() discarded.
        log.exception(f"Error running migrations: {e}")
 | |
| 
 | |
| 
 | |
# Executed at import time, before the config table is queried further down this module
run_migrations()
 | |
| 
 | |
| 
 | |
class Config(Base):
    """SQLAlchemy model for the ``config`` table, which stores configuration as JSON."""

    __tablename__ = "config"

    id = Column(Integer, primary_key=True)
    # Configuration payload, stored as a JSON document
    data = Column(JSON, nullable=False)
    version = Column(Integer, nullable=False, default=0)
    # Filled in by the database server on insert
    created_at = Column(DateTime, nullable=False, server_default=func.now())
    # Nullable with no insert default; refreshed via onupdate when the row is updated
    updated_at = Column(DateTime, nullable=True, onupdate=func.now())
 | |
| 
 | |
| 
 | |
def load_json_config():
    """Read and parse ``config.json`` located in ``DATA_DIR``.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's
    # default encoding, which is not UTF-8 everywhere.
    with open(f"{DATA_DIR}/config.json", encoding="utf-8") as file:
        return json.load(file)
 | |
| 
 | |
| 
 | |
def save_to_db(data):
    """Store ``data`` as the configuration payload and commit.

    When the ``config`` table is empty a new row (version 0) is inserted;
    otherwise the first row returned by the query has its payload replaced
    and its ``updated_at`` set to the current local time.
    """
    with get_db() as db:
        row = db.query(Config).first()
        if row:
            row.data = data
            row.updated_at = datetime.now()
        else:
            row = Config(data=data, version=0)
        db.add(row)
        db.commit()
 | |
| 
 | |
| 
 | |
def reset_config():
    """Delete every row of the ``config`` table and commit the change."""
    with get_db() as session:
        session.query(Config).delete()
        session.commit()
 | |
| 
 | |
| 
 | |
# When initializing, check if config.json exists and migrate it to the database
if os.path.exists(f"{DATA_DIR}/config.json"):
    data = load_json_config()
    save_to_db(data)
    # os.replace overwrites a leftover old_config.json on every platform;
    # os.rename raises FileExistsError on Windows when the target exists.
    os.replace(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json")
 | |
| 
 | |
# Configuration returned by get_config() when the database holds no config row
DEFAULT_CONFIG = {
    "version": 0,
    "ui": {
        "default_locale": "",
        # Each suggestion has a two-part "title" and the prompt "content"
        "prompt_suggestions": [
            {
                "title": [
                    "Help me study",
                    "vocabulary for a college entrance exam",
                ],
                "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
            },
            {
                "title": [
                    "Give me ideas",
                    "for what to do with my kids' art",
                ],
                "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
            },
            {
                "title": ["Tell me a fun fact", "about the Roman Empire"],
                "content": "Tell me a random fun fact about the Roman Empire",
            },
            {
                "title": [
                    "Show me a code snippet",
                    "of a website's sticky header",
                ],
                "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
            },
            {
                "title": [
                    "Explain options trading",
                    "if I'm familiar with buying and selling stocks",
                ],
                "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
            },
            {
                "title": ["Overcome procrastination", "give me tips"],
                "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
            },
            {
                "title": [
                    "Grammar check",
                    "rewrite it for better readability ",
                ],
                "content": 'Check the following sentence for grammar and clarity: "[sentence]". Rewrite it for better readability while maintaining its original meaning.',
            },
        ],
    },
}
 | |
| 
 | |
| 
 | |
def get_config():
    """Return the stored configuration payload, or ``DEFAULT_CONFIG`` if none exists.

    The row with the highest ``id`` is the one consulted.
    """
    with get_db() as db:
        latest = db.query(Config).order_by(Config.id.desc()).first()
        if latest:
            return latest.data
        return DEFAULT_CONFIG
 | |
| 
 | |
| 
 | |
# Active configuration: loaded once at import time and rebound by save_config()
CONFIG_DATA = get_config()
 | |
| 
 | |
| 
 | |
def get_config_value(config_path: str):
    """Look up a dotted path (e.g. ``"ui.default_locale"``) in ``CONFIG_DATA``.

    Args:
        config_path: Keys joined by ``.``, walked from the top of ``CONFIG_DATA``.

    Returns:
        The value found at the path, or ``None`` when a key is missing or an
        intermediate value is not a dict.
    """
    cur_config = CONFIG_DATA
    for key in config_path.split("."):
        # Only descend into dicts: on a str, ``key in value`` is a substring
        # test and ``value[key]`` would then raise TypeError.
        if not isinstance(cur_config, dict) or key not in cur_config:
            return None
        cur_config = cur_config[key]
    return cur_config
 | |
| 
 | |
| 
 | |
# Every PersistentConfig appends itself here on creation; save_config() iterates it
PERSISTENT_CONFIG_REGISTRY = []
 | |
| 
 | |
| 
 | |
def save_config(config):
    """Persist ``config`` and make it the active configuration.

    After a successful save, ``CONFIG_DATA`` is rebound to ``config`` and every
    entry in ``PERSISTENT_CONFIG_REGISTRY`` is asked to refresh itself.

    Returns:
        ``True`` on success; ``False`` if any step raised (the exception is logged).
    """
    global CONFIG_DATA
    try:
        save_to_db(config)
        CONFIG_DATA = config

        # Trigger updates on all registered PersistentConfig entries
        for entry in PERSISTENT_CONFIG_REGISTRY:
            entry.update()
    except Exception as e:
        log.exception(e)
        return False
    else:
        return True
 | |
| 
 | |
| 
 | |
T = TypeVar("T")


class PersistentConfig(Generic[T]):
    """A setting read from the stored configuration, with an environment fallback.

    On creation the instance looks up ``config_path`` in ``CONFIG_DATA``; a
    non-``None`` result wins over ``env_value``. Each instance appends itself
    to ``PERSISTENT_CONFIG_REGISTRY`` so ``save_config`` can refresh it later.
    """

    def __init__(self, env_name: str, config_path: str, env_value: T):
        """Create the setting and register it.

        Args:
            env_name: Name used to identify the setting in log messages.
            config_path: Dotted path of the setting inside ``CONFIG_DATA``.
            env_value: Value used when the path is absent from ``CONFIG_DATA``.
        """
        self.env_name = env_name
        self.config_path = config_path
        self.env_value = env_value
        self.config_value = get_config_value(config_path)
        if self.config_value is not None:
            log.info(f"'{env_name}' loaded from the latest database entry")
            self.value = self.config_value
        else:
            self.value = env_value

        PERSISTENT_CONFIG_REGISTRY.append(self)

    def __str__(self):
        return str(self.value)

    @property
    def __dict__(self):
        raise TypeError(
            "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
        )

    def __getattribute__(self, item):
        # Block ``obj.__dict__`` explicitly so accidental dict conversion of the
        # wrapper (instead of its ``.value``) fails loudly.
        if item == "__dict__":
            raise TypeError(
                "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
            )
        return super().__getattribute__(item)

    def update(self):
        """Reload ``value`` from ``CONFIG_DATA``; keep the current one if the path is absent."""
        new_value = get_config_value(self.config_path)
        if new_value is not None:
            self.value = new_value
            # The value itself is deliberately not logged: instances of this
            # class hold client secrets and API keys.
            log.info(f"Updated {self.env_name} to new value from the database")

    def save(self):
        """Write ``value`` into ``CONFIG_DATA`` at ``config_path`` and persist it.

        Missing intermediate dicts along the path are created.
        """
        log.info(f"Saving '{self.env_name}' to the database")
        *parent_keys, leaf_key = self.config_path.split(".")
        sub_config = CONFIG_DATA
        for key in parent_keys:
            sub_config = sub_config.setdefault(key, {})
        sub_config[leaf_key] = self.value
        save_to_db(CONFIG_DATA)
        self.config_value = self.value
 | |
| 
 | |
| 
 | |
class _UnknownConfigKeyError(AttributeError, KeyError):
    """Raised by ``AppConfig`` for a name that was never registered.

    It is an ``AttributeError`` so that ``hasattr`` and ``getattr(obj, name,
    default)`` behave normally, and also a ``KeyError`` so that existing
    ``except KeyError`` handlers keep working.
    """


class AppConfig:
    """Attribute-style access to a set of ``PersistentConfig`` entries.

    Assigning a ``PersistentConfig`` registers it under that attribute name.
    Assigning any other value to a registered name updates the entry's value
    and saves it. Reading an attribute returns the entry's current value.
    """

    _state: "dict[str, PersistentConfig]"

    def __init__(self):
        # Bypass our own __setattr__, which only deals with config entries.
        super().__setattr__("_state", {})

    def __setattr__(self, key, value):
        """Register ``value`` if it is a ``PersistentConfig``, else update and save ``key``.

        Raises:
            AttributeError: (also a ``KeyError``) if a plain value is assigned
                to a name that has not been registered.
        """
        if isinstance(value, PersistentConfig):
            self._state[key] = value
            return
        try:
            entry = self._state[key]
        except KeyError:
            raise _UnknownConfigKeyError(f"Config key '{key}' not found") from None
        entry.value = value
        entry.save()

    def __getattr__(self, key):
        """Return the current value of the entry registered as ``key``.

        Raises:
            AttributeError: (also a ``KeyError``) if ``key`` is not registered.
        """
        # Read _state from the instance dict directly: ``self._state`` would
        # re-enter __getattr__ without end on an instance whose __init__ never
        # ran (e.g. one created through ``__new__``).
        state = self.__dict__.get("_state", {})
        try:
            entry = state[key]
        except KeyError:
            raise _UnknownConfigKeyError(f"Config key '{key}' not found") from None
        return entry.value
 | |
| 
 | |
| 
 | |
####################################
# WEBUI_AUTH (Required for security)
####################################

# Falls back to the string "-1" when the environment variable is unset
JWT_EXPIRES_IN = PersistentConfig(
    "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
 | |
| 
 | |
####################################
# OAuth config
####################################

# Boolean settings in this section are true only for the (case-insensitive) string "true"
ENABLE_OAUTH_SIGNUP = PersistentConfig(
    "ENABLE_OAUTH_SIGNUP",
    "oauth.enable_signup",
    os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)

OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig(
    "OAUTH_MERGE_ACCOUNTS_BY_EMAIL",
    "oauth.merge_accounts_by_email",
    os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true",
)

# Filled in place by load_oauth_providers()
OAUTH_PROVIDERS = {}

# --- Google (stored under oauth.google.*) ---
GOOGLE_CLIENT_ID = PersistentConfig(
    "GOOGLE_CLIENT_ID",
    "oauth.google.client_id",
    os.environ.get("GOOGLE_CLIENT_ID", ""),
)

GOOGLE_CLIENT_SECRET = PersistentConfig(
    "GOOGLE_CLIENT_SECRET",
    "oauth.google.client_secret",
    os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)

GOOGLE_OAUTH_SCOPE = PersistentConfig(
    "GOOGLE_OAUTH_SCOPE",
    "oauth.google.scope",
    os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
)

GOOGLE_REDIRECT_URI = PersistentConfig(
    "GOOGLE_REDIRECT_URI",
    "oauth.google.redirect_uri",
    os.environ.get("GOOGLE_REDIRECT_URI", ""),
)

# --- Microsoft (stored under oauth.microsoft.*) ---
MICROSOFT_CLIENT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_ID",
    "oauth.microsoft.client_id",
    os.environ.get("MICROSOFT_CLIENT_ID", ""),
)

MICROSOFT_CLIENT_SECRET = PersistentConfig(
    "MICROSOFT_CLIENT_SECRET",
    "oauth.microsoft.client_secret",
    os.environ.get("MICROSOFT_CLIENT_SECRET", ""),
)

MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_TENANT_ID",
    "oauth.microsoft.tenant_id",
    os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)

MICROSOFT_OAUTH_SCOPE = PersistentConfig(
    "MICROSOFT_OAUTH_SCOPE",
    "oauth.microsoft.scope",
    os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
)

MICROSOFT_REDIRECT_URI = PersistentConfig(
    "MICROSOFT_REDIRECT_URI",
    "oauth.microsoft.redirect_uri",
    os.environ.get("MICROSOFT_REDIRECT_URI", ""),
)

# --- OpenID Connect provider (stored under oauth.oidc.*) ---
OAUTH_CLIENT_ID = PersistentConfig(
    "OAUTH_CLIENT_ID",
    "oauth.oidc.client_id",
    os.environ.get("OAUTH_CLIENT_ID", ""),
)

OAUTH_CLIENT_SECRET = PersistentConfig(
    "OAUTH_CLIENT_SECRET",
    "oauth.oidc.client_secret",
    os.environ.get("OAUTH_CLIENT_SECRET", ""),
)

OPENID_PROVIDER_URL = PersistentConfig(
    "OPENID_PROVIDER_URL",
    "oauth.oidc.provider_url",
    os.environ.get("OPENID_PROVIDER_URL", ""),
)

OPENID_REDIRECT_URI = PersistentConfig(
    "OPENID_REDIRECT_URI",
    "oauth.oidc.redirect_uri",
    os.environ.get("OPENID_REDIRECT_URI", ""),
)

OAUTH_SCOPES = PersistentConfig(
    "OAUTH_SCOPES",
    "oauth.oidc.scopes",
    os.environ.get("OAUTH_SCOPES", "openid email profile"),
)

OAUTH_PROVIDER_NAME = PersistentConfig(
    "OAUTH_PROVIDER_NAME",
    "oauth.oidc.provider_name",
    os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)

OAUTH_USERNAME_CLAIM = PersistentConfig(
    "OAUTH_USERNAME_CLAIM",
    "oauth.oidc.username_claim",
    os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)

# Note: the env variable says PICTURE while the config path says avatar_claim
OAUTH_PICTURE_CLAIM = PersistentConfig(
    "OAUTH_PICTURE_CLAIM",
    "oauth.oidc.avatar_claim",
    os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)

OAUTH_EMAIL_CLAIM = PersistentConfig(
    "OAUTH_EMAIL_CLAIM",
    "oauth.oidc.email_claim",
    os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)

# --- Role settings ---
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
    "ENABLE_OAUTH_ROLE_MANAGEMENT",
    "oauth.enable_role_mapping",
    os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)

OAUTH_ROLES_CLAIM = PersistentConfig(
    "OAUTH_ROLES_CLAIM",
    "oauth.roles_claim",
    os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
)

# Comma-separated env values are split into a list with whitespace stripped
OAUTH_ALLOWED_ROLES = PersistentConfig(
    "OAUTH_ALLOWED_ROLES",
    "oauth.allowed_roles",
    [role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "pending,user,admin").split(",")],
)

OAUTH_ADMIN_ROLES = PersistentConfig(
    "OAUTH_ADMIN_ROLES",
    "oauth.admin_roles",
    [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)
 | |
| 
 | |
def load_oauth_providers():
    """Rebuild ``OAUTH_PROVIDERS`` in place from the current OAuth settings.

    A provider entry is added only when all of its required values are set:
    client id and secret for Google; client id, secret and tenant id for
    Microsoft; client id, secret and provider URL for the OIDC provider.
    """
    OAUTH_PROVIDERS.clear()

    google_ready = GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value
    if google_ready:
        OAUTH_PROVIDERS["google"] = {
            "client_id": GOOGLE_CLIENT_ID.value,
            "client_secret": GOOGLE_CLIENT_SECRET.value,
            "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
            "scope": GOOGLE_OAUTH_SCOPE.value,
            "redirect_uri": GOOGLE_REDIRECT_URI.value,
        }

    microsoft_ready = (
        MICROSOFT_CLIENT_ID.value
        and MICROSOFT_CLIENT_SECRET.value
        and MICROSOFT_CLIENT_TENANT_ID.value
    )
    if microsoft_ready:
        tenant_id = MICROSOFT_CLIENT_TENANT_ID.value
        OAUTH_PROVIDERS["microsoft"] = {
            "client_id": MICROSOFT_CLIENT_ID.value,
            "client_secret": MICROSOFT_CLIENT_SECRET.value,
            "server_metadata_url": f"https://login.microsoftonline.com/{tenant_id}/v2.0/.well-known/openid-configuration",
            "scope": MICROSOFT_OAUTH_SCOPE.value,
            "redirect_uri": MICROSOFT_REDIRECT_URI.value,
        }

    oidc_ready = (
        OAUTH_CLIENT_ID.value
        and OAUTH_CLIENT_SECRET.value
        and OPENID_PROVIDER_URL.value
    )
    if oidc_ready:
        OAUTH_PROVIDERS["oidc"] = {
            "client_id": OAUTH_CLIENT_ID.value,
            "client_secret": OAUTH_CLIENT_SECRET.value,
            "server_metadata_url": OPENID_PROVIDER_URL.value,
            "scope": OAUTH_SCOPES.value,
            "name": OAUTH_PROVIDER_NAME.value,
            "redirect_uri": OPENID_REDIRECT_URI.value,
        }
 | |
| 
 | |
| 
 | |
# Populate OAUTH_PROVIDERS once at import time
load_oauth_providers()
 | |
| 
 | |
####################################
# Static DIR
####################################

# STATIC_DIR env override, else the "static" folder under OPEN_WEBUI_DIR
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()

# Copy the built frontend's favicon into STATIC_DIR; a failed copy is logged, not raised
frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"

if frontend_favicon.exists():
    try:
        shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
    except Exception as e:
        logging.error(f"An error occurred: {e}")
else:
    logging.warning(f"Frontend favicon not found at {frontend_favicon}")

# Same treatment for the splash image
frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"

if frontend_splash.exists():
    try:
        shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
    except Exception as e:
        logging.error(f"An error occurred: {e}")
else:
    logging.warning(f"Frontend splash not found at {frontend_splash}")
 | |
| 
 | |
| 
 | |
####################################
# CUSTOM_NAME
####################################

CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")

# When CUSTOM_NAME is set, fetch its branding (name, logo, splash) from
# api.openwebui.com and overwrite the favicon/splash files in STATIC_DIR.
# Any failure is logged and the defaults stay in effect.
if CUSTOM_NAME:
    try:
        # requests has no default timeout; this runs at import time, so an
        # unresponsive server must not be allowed to block startup forever.
        r = requests.get(
            f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}", timeout=10
        )
        data = r.json()
        if r.ok:
            if "logo" in data:
                # A leading "/" marks a path relative to the API host.
                WEBUI_FAVICON_URL = url = (
                    f"https://api.openwebui.com{data['logo']}"
                    if data["logo"][0] == "/"
                    else data["logo"]
                )

                # The context manager releases the streamed connection.
                with requests.get(url, stream=True, timeout=10) as r:
                    if r.status_code == 200:
                        with open(f"{STATIC_DIR}/favicon.png", "wb") as f:
                            r.raw.decode_content = True
                            shutil.copyfileobj(r.raw, f)

            if "splash" in data:
                url = (
                    f"https://api.openwebui.com{data['splash']}"
                    if data["splash"][0] == "/"
                    else data["splash"]
                )

                with requests.get(url, stream=True, timeout=10) as r:
                    if r.status_code == 200:
                        with open(f"{STATIC_DIR}/splash.png", "wb") as f:
                            r.raw.decode_content = True
                            shutil.copyfileobj(r.raw, f)

            WEBUI_NAME = data["name"]
    except Exception as e:
        log.exception(e)
 | |
| 
 | |
| 
 | |
# Each directory below is created at import time if it does not already exist.

####################################
# File Upload DIR
####################################

UPLOAD_DIR = f"{DATA_DIR}/uploads"
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Cache DIR
####################################

CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Docs DIR
####################################

# Overridable through the DOCS_DIR environment variable
DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Tools DIR
####################################

# Overridable through the TOOLS_DIR environment variable
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Functions DIR
####################################

# Overridable through the FUNCTIONS_DIR environment variable
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
 | |
| 
 | |
####################################
# OLLAMA_BASE_URL
####################################


ENABLE_OLLAMA_API = PersistentConfig(
    "ENABLE_OLLAMA_API",
    "ollama.enable",
    os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true",
)

OLLAMA_API_BASE_URL = os.environ.get(
    "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)

OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")

# Unset or empty -> None; a value that int() cannot parse -> 300
if AIOHTTP_CLIENT_TIMEOUT == "":
    AIOHTTP_CLIENT_TIMEOUT = None
else:
    try:
        AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT = 300


K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")

# Without an explicit OLLAMA_BASE_URL, derive it from OLLAMA_API_BASE_URL
# by dropping a trailing "/api"
if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
    OLLAMA_BASE_URL = (
        OLLAMA_API_BASE_URL[:-4]
        if OLLAMA_API_BASE_URL.endswith("/api")
        else OLLAMA_API_BASE_URL
    )

# In prod, the placeholder "/ollama" is replaced by a concrete host, and
# K8S_FLAG forces the in-cluster service URL
if ENV == "prod":
    if OLLAMA_BASE_URL == "/ollama" and not K8S_FLAG:
        if USE_OLLAMA_DOCKER.lower() == "true":
            # if you use all-in-one docker container (Open WebUI + Ollama)
            # with the docker build arg USE_OLLAMA=true (--build-arg="USE_OLLAMA=true") this only works with http://localhost:11434
            OLLAMA_BASE_URL = "http://localhost:11434"
        else:
            OLLAMA_BASE_URL = "http://host.docker.internal:11434"
    elif K8S_FLAG:
        OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"


# OLLAMA_BASE_URLS (";"-separated) takes precedence over the single OLLAMA_BASE_URL
OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL

OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = PersistentConfig(
    "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
 | |
| 
 | |
####################################
# OPENAI_API
####################################


ENABLE_OPENAI_API = PersistentConfig(
    "ENABLE_OPENAI_API",
    "openai.enable",
    os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true",
)


OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")


if OPENAI_API_BASE_URL == "":
    OPENAI_API_BASE_URL = "https://api.openai.com/v1"

# The ";"-separated plural variables take precedence; each falls back to
# its singular counterpart when unset or empty.
OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY

OPENAI_API_KEYS = [key.strip() for key in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = PersistentConfig(
    "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
)

OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
    OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)

# Strip first, then substitute the default, so an entry made only of
# whitespace also becomes the default URL instead of an empty string.
OPENAI_API_BASE_URLS = [
    url.strip() or "https://api.openai.com/v1"
    for url in OPENAI_API_BASE_URLS.split(";")
]
OPENAI_API_BASE_URLS = PersistentConfig(
    "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)

# Re-derive the singular key: it is the key at the same position as the
# official OpenAI URL, or "" when that URL is not configured.
OPENAI_API_KEY = ""

try:
    OPENAI_API_KEY = OPENAI_API_KEYS.value[
        OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
    ]
except Exception:
    # Best effort: the URL may be absent or the two lists may differ in
    # length; keep the empty key in that case.
    pass

OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 | |
| 
 | |
####################################
# WEBUI
####################################

# Signup is forced off when WEBUI_AUTH is falsy, regardless of the environment
ENABLE_SIGNUP = PersistentConfig(
    "ENABLE_SIGNUP",
    "ui.enable_signup",
    (
        False
        if not WEBUI_AUTH
        else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
    ),
)

# Note: unlike its neighbours, this config path uses an upper-case final key
ENABLE_LOGIN_FORM = PersistentConfig(
    "ENABLE_LOGIN_FORM",
    "ui.ENABLE_LOGIN_FORM",
    os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
)

DEFAULT_LOCALE = PersistentConfig(
    "DEFAULT_LOCALE",
    "ui.default_locale",
    os.environ.get("DEFAULT_LOCALE", ""),
)

DEFAULT_MODELS = PersistentConfig(
    "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)

# Fallback suggestions used when "ui.prompt_suggestions" is absent from the stored config
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
    "DEFAULT_PROMPT_SUGGESTIONS",
    "ui.prompt_suggestions",
    [
        {
            "title": ["Help me study", "vocabulary for a college entrance exam"],
            "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
        },
        {
            "title": ["Give me ideas", "for what to do with my kids' art"],
            "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
        },
        {
            "title": ["Tell me a fun fact", "about the Roman Empire"],
            "content": "Tell me a random fun fact about the Roman Empire",
        },
        {
            "title": ["Show me a code snippet", "of a website's sticky header"],
            "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
        },
        {
            "title": [
                "Explain options trading",
                "if I'm familiar with buying and selling stocks",
            ],
            "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
        },
        {
            "title": ["Overcome procrastination", "give me tips"],
            "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
        },
    ],
)

DEFAULT_USER_ROLE = PersistentConfig(
    "DEFAULT_USER_ROLE",
    "ui.default_user_role",
    os.getenv("DEFAULT_USER_ROLE", "pending"),
)

# Per-permission env flags, combined into the USER_PERMISSIONS default below
USER_PERMISSIONS_CHAT_DELETION = (
    os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_EDITING = (
    os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_TEMPORARY = (
    os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true"
)

USER_PERMISSIONS = PersistentConfig(
    "USER_PERMISSIONS",
    "ui.user_permissions",
    {
        "chat": {
            "deletion": USER_PERMISSIONS_CHAT_DELETION,
            "editing": USER_PERMISSIONS_CHAT_EDITING,
            "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
        }
    },
)

ENABLE_MODEL_FILTER = PersistentConfig(
    "ENABLE_MODEL_FILTER",
    "model_filter.enable",
    os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)
# NOTE: an unset or empty MODEL_FILTER_LIST produces [""] (one empty string), not []
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = PersistentConfig(
    "MODEL_FILTER_LIST",
    "model_filter.list",
    [model.strip() for model in MODEL_FILTER_LIST.split(";")],
)

WEBHOOK_URL = PersistentConfig(
    "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)

# Plain booleans read from the environment only (not PersistentConfig entries)
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"

ENABLE_ADMIN_CHAT_ACCESS = (
    os.environ.get("ENABLE_ADMIN_CHAT_ACCESS", "True").lower() == "true"
)

ENABLE_COMMUNITY_SHARING = PersistentConfig(
    "ENABLE_COMMUNITY_SHARING",
    "ui.enable_community_sharing",
    os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
)

ENABLE_MESSAGE_RATING = PersistentConfig(
    "ENABLE_MESSAGE_RATING",
    "ui.enable_message_rating",
    os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true",
)
 | |
| 
 | |
| 
 | |
def validate_cors_origins(origins):
    """Validate every entry of ``origins``, letting the wildcard ``"*"`` through.

    Raises:
        ValueError: Propagated from ``validate_cors_origin`` for the first
            entry that is not acceptable.
    """
    for candidate in origins:
        if candidate == "*":
            continue
        validate_cors_origin(candidate)
 | |
| 
 | |
| 
 | |
def validate_cors_origin(origin):
    """Check that ``origin`` is an http(s) URL that includes a host part.

    Raises:
        ValueError: If the scheme is neither ``http`` nor ``https``, or if the
            URL has no network location.
    """
    parts = urlparse(origin)
    allowed_schemes = ("http", "https")

    # Check if the scheme is either http or https
    scheme_ok = parts.scheme in allowed_schemes
    if not scheme_ok:
        raise ValueError(
            f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed."
        )

    # Ensure that the netloc (domain + port) is present, indicating it's a valid URL
    if parts.netloc:
        return
    raise ValueError(f"Invalid URL structure in CORS_ALLOW_ORIGIN: '{origin}'.")
 | |
| 
 | |
| 
 | |
# For production, you should only need one host as
# fastapi serves the svelte-kit built frontend and backend from the same host and port.
# To test CORS_ALLOW_ORIGIN locally, you can set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file depending on your frontend port, 5173 in this case.
# Entries are split on ";" as-is (no whitespace stripping); the default is ["*"].
CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")

if "*" in CORS_ALLOW_ORIGIN:
    log.warning(
        "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
    )

# Raises ValueError at import time if any configured origin is invalid
validate_cors_origins(CORS_ALLOW_ORIGIN)
 | |
| 
 | |
| 
 | |
class BannerModel(BaseModel):
    """Pydantic schema for one entry of the ``WEBUI_BANNERS`` JSON array."""

    id: str
    type: str
    # The only optional field; defaults to None when omitted
    title: Optional[str] = None
    content: str
    dismissible: bool
    timestamp: int
 | |
| 
 | |
| 
 | |
# Parse WEBUI_BANNERS (a JSON array of banner objects). Invalid JSON or an
# entry that fails BannerModel validation results in no banners at all.
try:
    banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
    banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
    # Report through the module logger, as the rest of this file does,
    # instead of printing to stdout.
    log.error(f"Error loading WEBUI_BANNERS: {e}")
    banners = []

WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
 | |
| 
 | |
| 
 | |
# Enabled unless SHOW_ADMIN_DETAILS is set to something other than "true" (case-insensitive)
SHOW_ADMIN_DETAILS = PersistentConfig(
    "SHOW_ADMIN_DETAILS",
    "auth.admin.show",
    os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)

# None when the environment variable is unset
ADMIN_EMAIL = PersistentConfig(
    "ADMIN_EMAIL",
    "auth.admin.email",
    os.environ.get("ADMIN_EMAIL", None),
)
 | |
| 
 | |
| 
 | |
####################################
# TASKS
####################################

# String settings in this section default to "" when the variable is unset.

TASK_MODEL = PersistentConfig(
    "TASK_MODEL", "task.model.default", os.getenv("TASK_MODEL", "")
)

TASK_MODEL_EXTERNAL = PersistentConfig(
    "TASK_MODEL_EXTERNAL",
    "task.model.external",
    os.getenv("TASK_MODEL_EXTERNAL", ""),
)

TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TITLE_GENERATION_PROMPT_TEMPLATE",
    "task.title.prompt_template",
    os.getenv("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)

# On by default; any value other than "true" (case-insensitive) disables it.
ENABLE_SEARCH_QUERY = PersistentConfig(
    "ENABLE_SEARCH_QUERY",
    "task.search.enable",
    os.getenv("ENABLE_SEARCH_QUERY", "True").lower() == "true",
)

SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
    "task.search.prompt_template",
    os.getenv("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
)

TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    "task.tools.prompt_template",
    os.getenv("TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", ""),
)
 | |
| 
 | |
| 
 | |
####################################
# Vector Database
####################################

# Name of the vector store backend; "chroma" unless overridden.
VECTOR_DB = os.getenv("VECTOR_DB", "chroma")

# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.getenv("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.getenv("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.getenv("CHROMA_HTTP_HOST", "")
# Converted eagerly, so a non-numeric value fails at import time.
CHROMA_HTTP_PORT = int(os.getenv("CHROMA_HTTP_PORT", "8000"))
 | |
# Comma-separated list of header=value pairs, e.g. "X-Token=abc,X-Org=42".
# Only the first "=" of each pair separates the header name from its value, so
# values that themselves contain "=" (base64 padding, for instance) stay intact.
# Empty segments, such as the one left behind by a trailing comma, are skipped.
# The result is a dict of header name -> value, or None when nothing is set.
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
    CHROMA_HTTP_HEADERS = (
        dict(pair.split("=", 1) for pair in CHROMA_HTTP_HEADERS.split(",") if pair)
        or None
    )
else:
    CHROMA_HTTP_HEADERS = None
 | |
# Off unless CHROMA_HTTP_SSL is "true" (case-insensitive).
CHROMA_HTTP_SSL = os.getenv("CHROMA_HTTP_SSL", "false").lower() == "true"
# NOTE: the embedding model is taken from the ENV variable defined in the
# Dockerfile. If you don't use Docker or Docker-based deployments such as k8s,
# the default embedding model (sentence-transformers/all-MiniLM-L6-v2) is used.

# Milvus

MILVUS_URI = os.getenv("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
 | |
| 
 | |
####################################
# RAG
####################################

# RAG Content Extraction
# The engine name is lower-cased; it is "" when the variable is unset.
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
    "CONTENT_EXTRACTION_ENGINE",
    "rag.CONTENT_EXTRACTION_ENGINE",
    os.getenv("CONTENT_EXTRACTION_ENGINE", "").lower(),
)

TIKA_SERVER_URL = PersistentConfig(
    "TIKA_SERVER_URL",
    "rag.tika_server_url",
    # Default for sidecar deployment
    os.environ.get("TIKA_SERVER_URL", "http://tika:9998"),
)

RAG_TOP_K = PersistentConfig(
    "RAG_TOP_K",
    "rag.top_k",
    int(os.getenv("RAG_TOP_K", "3")),
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
    "RAG_RELEVANCE_THRESHOLD",
    "rag.relevance_threshold",
    float(os.getenv("RAG_RELEVANCE_THRESHOLD", "0.0")),
)

# Off unless explicitly set to "true".
ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
    "ENABLE_RAG_HYBRID_SEARCH",
    "rag.enable_hybrid_search",
    os.getenv("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)

# The two file limits below are optional: an unset or empty variable yields
# None, anything else is converted with int().
RAG_FILE_MAX_COUNT = PersistentConfig(
    "RAG_FILE_MAX_COUNT",
    "rag.file.max_count",
    (
        int(os.getenv("RAG_FILE_MAX_COUNT"))
        if os.getenv("RAG_FILE_MAX_COUNT")
        else None
    ),
)

RAG_FILE_MAX_SIZE = PersistentConfig(
    "RAG_FILE_MAX_SIZE",
    "rag.file.max_size",
    (
        int(os.getenv("RAG_FILE_MAX_SIZE"))
        if os.getenv("RAG_FILE_MAX_SIZE")
        else None
    ),
)

# On by default; any value other than "true" (case-insensitive) turns it off.
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
    "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
    "rag.enable_web_loader_ssl_verification",
    os.getenv("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)
 | |
| 
 | |
RAG_EMBEDDING_ENGINE = PersistentConfig(
    "RAG_EMBEDDING_ENGINE",
    "rag.embedding_engine",
    os.getenv("RAG_EMBEDDING_ENGINE", ""),
)

PDF_EXTRACT_IMAGES = PersistentConfig(
    "PDF_EXTRACT_IMAGES",
    "rag.pdf_extract_images",
    os.getenv("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)

RAG_EMBEDDING_MODEL = PersistentConfig(
    "RAG_EMBEDDING_MODEL",
    "rag.embedding_model",
    os.getenv("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")

# Plain booleans (not PersistentConfig); both are off unless set to "true".
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.getenv("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.getenv("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)

RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
    "rag.embedding_openai_batch_size",
    int(os.getenv("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
)

RAG_RERANKING_MODEL = PersistentConfig(
    "RAG_RERANKING_MODEL",
    "rag.reranking_model",
    os.getenv("RAG_RERANKING_MODEL", ""),
)
# Only log the reranking model when its value is not the empty string.
if RAG_RERANKING_MODEL.value != "":
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")

# Plain booleans (not PersistentConfig); both are off unless set to "true".
RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.getenv("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.getenv("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)

CHUNK_SIZE = PersistentConfig(
    "CHUNK_SIZE",
    "rag.chunk_size",
    int(os.getenv("CHUNK_SIZE", "1000")),
)
CHUNK_OVERLAP = PersistentConfig(
    "CHUNK_OVERLAP",
    "rag.chunk_overlap",
    int(os.getenv("CHUNK_OVERLAP", "100")),
)
 | |
| 
 | |
# Built-in RAG prompt. It contains the literal markers "[context]" and
# "[query]" inside <context> and <user_query> tags respectively.
DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.

<context>
[context]
</context>

<rules>
- If you don't know, just say so.
- If you are not sure, ask for clarification.
- Answer in the same language as the user query.
- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
- Answer directly and without using xml tags.
</rules>

<user_query>
[query]
</user_query>
"""

# RAG_TEMPLATE from the environment wins; otherwise the built-in prompt is used.
RAG_TEMPLATE = PersistentConfig(
    "RAG_TEMPLATE", "rag.template", os.getenv("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
)
 | |
| 
 | |
# The RAG-specific OpenAI endpoint and key fall back to the module-level
# OPENAI_API_BASE_URL / OPENAI_API_KEY values when not set explicitly.
RAG_OPENAI_API_BASE_URL = PersistentConfig(
    "RAG_OPENAI_API_BASE_URL",
    "rag.openai_api_base_url",
    os.environ.get("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = PersistentConfig(
    "RAG_OPENAI_API_KEY",
    "rag.openai_api_key",
    os.environ.get("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)

# Plain boolean (not PersistentConfig); off unless set to "true".
ENABLE_RAG_LOCAL_WEB_FETCH = (
    os.environ.get("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)

# Comma-separated language codes, split into a list; defaults to ["en"].
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
    "YOUTUBE_LOADER_LANGUAGE",
    "rag.youtube_loader_language",
    os.environ.get("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
 | |
| 
 | |
| 
 | |
# Web search is off unless ENABLE_RAG_WEB_SEARCH is "true" (case-insensitive).
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
    "ENABLE_RAG_WEB_SEARCH",
    "rag.web.search.enable",
    os.environ.get("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)

RAG_WEB_SEARCH_ENGINE = PersistentConfig(
    "RAG_WEB_SEARCH_ENGINE",
    "rag.web.search.engine",
    os.environ.get("RAG_WEB_SEARCH_ENGINE", ""),
)

# Optional list of your own trusted websites used to filter results after a
# web search has been performed. It is empty by default; the commented entries
# are examples only.
# NOTE(review): the config path below starts with "rag.rag." while its
# siblings use "rag.web.search.*". It is left untouched here because changing
# the path could orphan values already stored under it - confirm before fixing.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
    "RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
    "rag.rag.web.search.domain.filter_list",
    [
        # "wikipedia.com",
        # "wikimedia.org",
        # "wikidata.org",
    ],
)

# Provider-specific settings; each defaults to "" when unset unless noted.

SEARXNG_QUERY_URL = PersistentConfig(
    "SEARXNG_QUERY_URL",
    "rag.web.search.searxng_query_url",
    os.environ.get("SEARXNG_QUERY_URL", ""),
)

GOOGLE_PSE_API_KEY = PersistentConfig(
    "GOOGLE_PSE_API_KEY",
    "rag.web.search.google_pse_api_key",
    os.environ.get("GOOGLE_PSE_API_KEY", ""),
)

GOOGLE_PSE_ENGINE_ID = PersistentConfig(
    "GOOGLE_PSE_ENGINE_ID",
    "rag.web.search.google_pse_engine_id",
    os.environ.get("GOOGLE_PSE_ENGINE_ID", ""),
)

BRAVE_SEARCH_API_KEY = PersistentConfig(
    "BRAVE_SEARCH_API_KEY",
    "rag.web.search.brave_search_api_key",
    os.environ.get("BRAVE_SEARCH_API_KEY", ""),
)

SERPSTACK_API_KEY = PersistentConfig(
    "SERPSTACK_API_KEY",
    "rag.web.search.serpstack_api_key",
    os.environ.get("SERPSTACK_API_KEY", ""),
)

# On by default; any value other than "true" (case-insensitive) turns it off.
SERPSTACK_HTTPS = PersistentConfig(
    "SERPSTACK_HTTPS",
    "rag.web.search.serpstack_https",
    os.environ.get("SERPSTACK_HTTPS", "True").lower() == "true",
)

SERPER_API_KEY = PersistentConfig(
    "SERPER_API_KEY",
    "rag.web.search.serper_api_key",
    os.environ.get("SERPER_API_KEY", ""),
)

SERPLY_API_KEY = PersistentConfig(
    "SERPLY_API_KEY",
    "rag.web.search.serply_api_key",
    os.environ.get("SERPLY_API_KEY", ""),
)

TAVILY_API_KEY = PersistentConfig(
    "TAVILY_API_KEY",
    "rag.web.search.tavily_api_key",
    os.environ.get("TAVILY_API_KEY", ""),
)

SEARCHAPI_API_KEY = PersistentConfig(
    "SEARCHAPI_API_KEY",
    "rag.web.search.searchapi_api_key",
    os.environ.get("SEARCHAPI_API_KEY", ""),
)

SEARCHAPI_ENGINE = PersistentConfig(
    "SEARCHAPI_ENGINE",
    "rag.web.search.searchapi_engine",
    os.environ.get("SEARCHAPI_ENGINE", ""),
)

# Integer settings; converted eagerly, so a non-numeric value fails at import.
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
    "RAG_WEB_SEARCH_RESULT_COUNT",
    "rag.web.search.result_count",
    int(os.environ.get("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)

RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
    "RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
    "rag.web.search.concurrent_requests",
    int(os.environ.get("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
 | |
| 
 | |
| 
 | |
####################################
# Transcribe
####################################

# Plain module constants (not PersistentConfig).
WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.environ.get(
    "WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models"
)
# Off unless WHISPER_MODEL_AUTO_UPDATE is "true" (case-insensitive).
WHISPER_MODEL_AUTO_UPDATE = (
    os.getenv("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
 | |
| 
 | |
| 
 | |
####################################
# Images
####################################

# Engine name; "openai" unless overridden.
IMAGE_GENERATION_ENGINE = PersistentConfig(
    "IMAGE_GENERATION_ENGINE",
    "image_generation.engine",
    os.environ.get("IMAGE_GENERATION_ENGINE", "openai"),
)

# Off unless ENABLE_IMAGE_GENERATION is "true" (case-insensitive).
ENABLE_IMAGE_GENERATION = PersistentConfig(
    "ENABLE_IMAGE_GENERATION",
    "image_generation.enable",
    os.getenv("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)
AUTOMATIC1111_BASE_URL = PersistentConfig(
    "AUTOMATIC1111_BASE_URL",
    "image_generation.automatic1111.base_url",
    os.environ.get("AUTOMATIC1111_BASE_URL", ""),
)
AUTOMATIC1111_API_AUTH = PersistentConfig(
    "AUTOMATIC1111_API_AUTH",
    "image_generation.automatic1111.api_auth",
    os.environ.get("AUTOMATIC1111_API_AUTH", ""),
)

# Optional float: None when the variable is unset or empty.
AUTOMATIC1111_CFG_SCALE = PersistentConfig(
    "AUTOMATIC1111_CFG_SCALE",
    "image_generation.automatic1111.cfg_scale",
    (
        float(os.getenv("AUTOMATIC1111_CFG_SCALE"))
        if os.getenv("AUTOMATIC1111_CFG_SCALE")
        else None
    ),
)
 | |
| 
 | |
| 
 | |
# Optional sampler name: None when AUTOMATIC1111_SAMPLER is unset or empty.
# The first argument previously read "AUTOMATIC1111_SAMPLERE"; it now matches
# the environment variable actually read below, like every sibling setting.
# The config path is unchanged.
AUTOMATIC1111_SAMPLER = PersistentConfig(
    "AUTOMATIC1111_SAMPLER",
    "image_generation.automatic1111.sampler",
    os.environ.get("AUTOMATIC1111_SAMPLER") or None,
)
 | |
| 
 | |
# Optional scheduler name: None when the variable is unset or empty.
AUTOMATIC1111_SCHEDULER = PersistentConfig(
    "AUTOMATIC1111_SCHEDULER",
    "image_generation.automatic1111.scheduler",
    os.getenv("AUTOMATIC1111_SCHEDULER") or None,
)

COMFYUI_BASE_URL = PersistentConfig(
    "COMFYUI_BASE_URL",
    "image_generation.comfyui.base_url",
    os.environ.get("COMFYUI_BASE_URL", ""),
)
 | |
| 
 | |
# Default ComfyUI workflow, kept as a JSON string (not a parsed object).
# Nodes are keyed "3" through "9": KSampler, CheckpointLoaderSimple,
# EmptyLatentImage, two CLIPTextEncode nodes, VAEDecode and SaveImage.
# Inputs given as ["<node id>", <index>] refer to another node in the graph.
COMFYUI_DEFAULT_WORKFLOW = """
{
  "3": {
    "inputs": {
      "seed": 0,
      "steps": 20,
      "cfg": 8,
      "sampler_name": "euler",
      "scheduler": "normal",
      "denoise": 1,
      "model": [
        "4",
        0
      ],
      "positive": [
        "6",
        0
      ],
      "negative": [
        "7",
        0
      ],
      "latent_image": [
        "5",
        0
      ]
    },
    "class_type": "KSampler",
    "_meta": {
      "title": "KSampler"
    }
  },
  "4": {
    "inputs": {
      "ckpt_name": "model.safetensors"
    },
    "class_type": "CheckpointLoaderSimple",
    "_meta": {
      "title": "Load Checkpoint"
    }
  },
  "5": {
    "inputs": {
      "width": 512,
      "height": 512,
      "batch_size": 1
    },
    "class_type": "EmptyLatentImage",
    "_meta": {
      "title": "Empty Latent Image"
    }
  },
  "6": {
    "inputs": {
      "text": "Prompt",
      "clip": [
        "4",
        1
      ]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {
      "title": "CLIP Text Encode (Prompt)"
    }
  },
  "7": {
    "inputs": {
      "text": "",
      "clip": [
        "4",
        1
      ]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {
      "title": "CLIP Text Encode (Prompt)"
    }
  },
  "8": {
    "inputs": {
      "samples": [
        "3",
        0
      ],
      "vae": [
        "4",
        2
      ]
    },
    "class_type": "VAEDecode",
    "_meta": {
      "title": "VAE Decode"
    }
  },
  "9": {
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": [
        "8",
        0
      ]
    },
    "class_type": "SaveImage",
    "_meta": {
      "title": "Save Image"
    }
  }
}
"""
 | |
| 
 | |
| 
 | |
# Workflow JSON string; COMFYUI_WORKFLOW from the environment wins, otherwise
# the built-in COMFYUI_DEFAULT_WORKFLOW is used.
COMFYUI_WORKFLOW = PersistentConfig(
    "COMFYUI_WORKFLOW",
    "image_generation.comfyui.workflow",
    os.getenv("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW),
)

# Node list for the workflow; empty by default and not read from the
# environment. It was previously registered under the copy-pasted name
# "COMFYUI_WORKFLOW", which made the two settings indistinguishable by name.
# The config path is unchanged.
COMFYUI_WORKFLOW_NODES = PersistentConfig(
    "COMFYUI_WORKFLOW_NODES",
    "image_generation.comfyui.nodes",
    [],
)
 | |
| 
 | |
# The image-specific OpenAI endpoint and key fall back to the module-level
# OPENAI_API_BASE_URL / OPENAI_API_KEY values when not set explicitly.
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
    "IMAGES_OPENAI_API_BASE_URL",
    "image_generation.openai.api_base_url",
    os.environ.get("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = PersistentConfig(
    "IMAGES_OPENAI_API_KEY",
    "image_generation.openai.api_key",
    os.environ.get("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)

# Kept as a "<width>x<height>" string.
IMAGE_SIZE = PersistentConfig(
    "IMAGE_SIZE",
    "image_generation.size",
    os.environ.get("IMAGE_SIZE", "512x512"),
)

# int() accepts both the string from the environment and the integer default.
IMAGE_STEPS = PersistentConfig(
    "IMAGE_STEPS",
    "image_generation.steps",
    int(os.environ.get("IMAGE_STEPS", 50)),
)

IMAGE_GENERATION_MODEL = PersistentConfig(
    "IMAGE_GENERATION_MODEL",
    "image_generation.model",
    os.environ.get("IMAGE_GENERATION_MODEL", ""),
)
 | |
| 
 | |
####################################
# Audio
####################################

# Speech-to-text. The OpenAI endpoint and key fall back to the module-level
# OPENAI_API_BASE_URL / OPENAI_API_KEY values when not set explicitly.
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
    "audio.stt.openai.api_base_url",
    os.environ.get("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)

AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_STT_OPENAI_API_KEY",
    "audio.stt.openai.api_key",
    os.environ.get("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)

AUDIO_STT_ENGINE = PersistentConfig(
    "AUDIO_STT_ENGINE",
    "audio.stt.engine",
    os.environ.get("AUDIO_STT_ENGINE", ""),
)

AUDIO_STT_MODEL = PersistentConfig(
    "AUDIO_STT_MODEL",
    "audio.stt.model",
    os.environ.get("AUDIO_STT_MODEL", "whisper-1"),
)

# Text-to-speech. Same fallback rule for the OpenAI endpoint and key.
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_BASE_URL",
    "audio.tts.openai.api_base_url",
    os.environ.get("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_KEY",
    "audio.tts.openai.api_key",
    os.environ.get("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)

AUDIO_TTS_API_KEY = PersistentConfig(
    "AUDIO_TTS_API_KEY",
    "audio.tts.api_key",
    os.environ.get("AUDIO_TTS_API_KEY", ""),
)

AUDIO_TTS_ENGINE = PersistentConfig(
    "AUDIO_TTS_ENGINE",
    "audio.tts.engine",
    os.environ.get("AUDIO_TTS_ENGINE", ""),
)

AUDIO_TTS_MODEL = PersistentConfig(
    "AUDIO_TTS_MODEL",
    "audio.tts.model",
    # OpenAI default model
    os.environ.get("AUDIO_TTS_MODEL", "tts-1"),
)

AUDIO_TTS_VOICE = PersistentConfig(
    "AUDIO_TTS_VOICE",
    "audio.tts.voice",
    # OpenAI default voice
    os.environ.get("AUDIO_TTS_VOICE", "alloy"),
)

AUDIO_TTS_SPLIT_ON = PersistentConfig(
    "AUDIO_TTS_SPLIT_ON",
    "audio.tts.split_on",
    os.environ.get("AUDIO_TTS_SPLIT_ON", "punctuation"),
)

# Azure speech settings.
AUDIO_TTS_AZURE_SPEECH_REGION = PersistentConfig(
    "AUDIO_TTS_AZURE_SPEECH_REGION",
    "audio.tts.azure.speech_region",
    os.environ.get("AUDIO_TTS_AZURE_SPEECH_REGION", "eastus"),
)

AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig(
    "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT",
    "audio.tts.azure.speech_output_format",
    os.environ.get(
        "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
    ),
)
 |