Merge branch 'dev' into vector-search-branch

This commit is contained in:
hiwylee 2025-08-01 04:23:38 +09:00 committed by GitHub
commit bd215a1b96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
367 changed files with 23637 additions and 9204 deletions

View File

@ -32,7 +32,7 @@ jobs:
node-version: '22' node-version: '22'
- name: Install Dependencies - name: Install Dependencies
run: npm install run: npm install --force
- name: Format Frontend - name: Format Frontend
run: npm run format run: npm run format
@ -59,7 +59,7 @@ jobs:
node-version: '22' node-version: '22'
- name: Install Dependencies - name: Install Dependencies
run: npm ci run: npm ci --force
- name: Run vitest - name: Run vitest
run: npm run test:frontend run: npm run test:frontend

View File

@ -5,6 +5,118 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.6.18] - 2025-07-19
### Fixed
- 🚑 **Users Not Loading in Groups**: Resolved an issue where user list was not displaying within user groups, restoring full visibility and management of group memberships for teams and admins.
## [0.6.17] - 2025-07-19
### Added
- 📂 **Dedicated Folder View with Chat List**: Clicking a folder now reveals a brand-new landing page showcasing a list of all chats within that folder, making navigation simpler and giving teams immediate visibility into project-specific conversations.
- 🆕 **Streamlined Folder Creation Modal**: Creating a new folder is now a seamless, unified experience with a dedicated modal that visually and functionally matches the edit folder flow, making workspace organization more intuitive and error-free for all users.
- 🗃️ **Direct File Uploads to Folder Knowledge**: You can now upload files straight to a folder's knowledge—empowering you to enrich project spaces by adding resources and documents directly, without the need to pre-create knowledge bases beforehand.
- 🔎 **Chat Preview in Search**: When searching chats, instantly preview results in context without having to open them—making discovery, auditing, and recall dramatically quicker, especially in large, active teams.
- 🖼️ **Image Upload and Inline Insertion in Notes**: Notes now support inserting images directly among your text, letting you create rich, visually structured documentation, brainstorms, or reports in a more natural and engaging way—no more images just as attachments.
- 📱 **Enhanced Note Selection Editing and Q&A**: Select any portion of your notes to either edit just the highlighted part or ask focused questions about that content—streamlining workflows, boosting productivity, and making reviews or AI-powered enhancements more targeted.
- 📝 **Copy Notes as Rich Text**: Copy entire notes—including all formatting, images, and structure—directly as rich text for seamless pasting into emails, reports, or other tools, maintaining clarity and consistency outside the WebUI.
- ⚡ **Fade-In Streaming Text Experience**: Live-generated responses now elegantly fade in as the AI streams them, creating a more natural and visually engaging reading experience; easily toggled off in Interface settings if you prefer static displays.
- 🔄 **Settings for Follow-Up Prompts**: Fine-tune your follow-up prompt experience—with new controls, you can choose to keep them visible or have them inserted directly into the message input instead of auto-submitting, giving you more flexibility and control over your workflow.
- 🔗 **Prompt Variable Documentation Quick Link**: Access documentation for prompt variables in one click from the prompt editor modal—shortening the learning curve and making advanced prompt-building more accessible.
- 📈 **Active and Total User Metrics for Telemetry**: Gain valuable insights into usage patterns and platform engagement with new metrics tracking active and total users—enhancing auditability and planning for large organizations.
- 🏷️ **Traceability with Log Trace and Span IDs**: Each log entry now carries detailed trace and span IDs, making it much easier for admins to pinpoint and resolve issues across distributed systems or in complex troubleshooting.
- 👥 **User Group Add/Remove Endpoints**: Effortlessly add or remove users from groups with new, improved endpoints—giving admins and team leads faster, clearer control over collaboration and permissions.
- ⚙️ **Note Settings and Controls Streamlined**: The main “Settings” for notes are now simply called “Controls”, and note files now reside in a dedicated controls section, decluttering navigation and making it easier to find and configure note-related options.
- 🚀 **Faster Admin User Page Loads**: The user list endpoint for admins has been optimized to exclude heavy profile images, speeding up load times for large teams and reducing waiting during administrative tasks.
- 📡 **Chat ID Header Forwarding**: Ollama and OpenAI router requests now include the chat ID in request headers, enabling better request correlation and debugging capabilities across AI model integrations.
- 🧠 **Enhanced Reasoning Tag Processing**: Improved and expanded reasoning tag parsing to handle various tag formats more robustly, including standard XML-style tags and custom delimiters, ensuring better AI reasoning transparency and debugging capabilities.
- 🔐 **OAuth Token Endpoint Authentication Method**: Added configurable OAuth token endpoint authentication method support, providing enhanced flexibility and security options for enterprise OAuth integrations and identity provider compatibility.
- 🛡️ **Redis Sentinel High Availability Support**: Comprehensive Redis Sentinel failover implementation with automatic master discovery, intelligent retry logic for connection failures, and seamless operation during master node outages—eliminating single points of failure and ensuring continuous service availability in production deployments.
- 🌐 **Localization & Internationalization Improvements**: Refined and expanded translations for Simplified Chinese, Traditional Chinese, French, German, Korean, and Polish, ensuring a more fluent and native experience for global users across all supported languages.
### Fixed
- 🏷️ **Hybrid Search Functionality Restored**: Hybrid search now works seamlessly again—enabling more accurate, relevant, and comprehensive knowledge discovery across all RAG-powered workflows.
- 🚦 **Note Chat - Edit Button Disabled During AI Generation**: The edit button when chatting with a note is now disabled while the AI is responding—preventing accidental edits and ensuring workflow clarity during chat sessions.
- 🧹 **Cleaner Database Credentials**: Database connection no longer duplicates @ in credentials, preventing potential connection issues and ensuring smoother, more reliable integrations.
- 🧑‍💻 **File Deletion Now Removes Related Vector Data**: When files are deleted from storage, they are now purged from the vector database as well, ensuring clean data management and preventing clutter or stale search results.
- 📁 **Files Modal Translation Issues Fixed**: All modal dialog strings—including “Using Entire Document” and “Using Focused Retrieval”—are now fully translated for a more consistent and localized UI experience.
- 🚫 **Drag-and-Drop File Upload Disabled for Unsupported Models**: File upload by drag-and-drop is disabled when using models that do not support attachments—removing confusion and preventing workflow interruptions.
- 🔑 **Ollama Tool Calls Now Reliable**: Fixed issues with Ollama-based tool calls, ensuring uninterrupted AI augmentation and tool use for every chat.
- 📄 **MIME Type Help String Correction**: Cleaned up mimetype help text by removing extraneous characters, providing clearer guidance for file upload configurations.
- 📝 **Note Editor Permission Fix**: Removed unnecessary admin-only restriction from note chat functionality, allowing all authorized users to access note editing features as intended.
- 📋 **Chat Sources Handling Improved**: Fixed sources handling logic to prevent duplicate source assignments in chat messages, ensuring cleaner and more accurate source attribution during conversations.
- 😀 **Emoji Generation Error Handling**: Improved error handling in audio router and fixed metadata structure for emoji generation tasks, preventing crashes and ensuring more reliable emoji generation functionality.
- 🔒 **Folder System Prompt Permission Enforcement**: System prompt fields in folder edit modal are now properly hidden for users without system prompt permissions, ensuring consistent security policy enforcement across all folder management interfaces.
- 🌐 **WebSocket Redis Lock Timeout Type Conversion**: Fixed proper integer type conversion for WebSocket Redis lock timeout configuration with robust error handling, preventing potential configuration errors and ensuring stable WebSocket connections.
- 📦 **PostHog Dependency Added**: Added PostHog 5.4.0 library to resolve ChromaDB compatibility issues, ensuring stable vector database operations and preventing library version conflicts during deployment.
### Changed
- 👀 **Tiptap Editor Upgraded to v3**: The underlying rich text editor has been updated for future-proofing, though some supporting libraries remain on v2 for compatibility. For now, please install dependencies using 'npm install --force' to avoid installation errors.
- 🚫 **Removed Redundant or Unused Strings and Elements**: Miscellaneous unused, duplicate, or obsolete code and translations have been cleaned up to maintain a streamlined and high-performance experience.
## [0.6.16] - 2025-07-14
### Added
- 🗂️ **Folders as Projects**: Organize your workflow with folder-based projects—set folder-level system prompts and associate custom knowledge, bringing seamless, context-rich management to teams and users handling multiple initiatives or clients.
- 📁 **Instant Folder-Based Chat Creation**: Start a new chat directly from any folder; just click and your new conversation is automatically embedded in the right project context—no more manual dragging or setup, saving time and eliminating mistakes.
- 🧩 **Prompt Variables with Automatic Input Modal**: Prompts containing variables now display a clean, auto-generated input modal that **autofocuses on the first field** for instant value entry—just select the prompt and fill in exactly what's needed, reducing friction and guesswork.
- 🔡 **Variable Input Typing in Prompts**: Define input types for prompt variables (e.g., text, textarea, number, select, color, date, map and more), giving everyone a clearer and more precise prompt-building experience for advanced automation or workflows.
- 🚀 **Base Model List Caching**: Cache your base model list to speed up model selection and reduce repeated API calls; toggle this in Admin Settings > Connections for responsive model management even in large or multi-provider setups.
- ⏱️ **Configurable Model List Cache TTL**: Take control over model list caching with the new MODEL_LIST_CACHE_TTL environment variable. Set a custom cache duration in seconds to balance performance and freshness, reducing API requests in stable environments or ensuring rapid updates when models change frequently.
- 🔖 **Reference Notes as Knowledge or in Chats**: Use any note as knowledge for a model or folder, or reference it directly from chat—integrate living documentation into your Retrieval Augmented Generation workflows or discussions, bridging knowledge and action.
- 📝 **Chat Directly with Notes (Experimental)**: Ask questions about any note, and directly edit or update notes from within a chat—unlock direct AI-powered brainstorming, summarization, and cleanup, like having your own collaborative AI canvas.
- 🤝 **Collaborative Notes with Multi-User Editing**: Share notes with others and collaborate live—multiple users can edit a note in real-time, boosting cooperative knowledge building and workflow documentation.
- 🛡️ **Collaborative Note Permissions**: Control who can view or edit each note with robust sharing permissions, ensuring privacy or collaboration per your organizational needs.
- 🔗 **Copy Link to Notes**: Quickly copy and share direct links to notes for easier knowledge transfer within your team or external collaborators.
- 📋 **Task List Support in Notes**: Add, organize, and manage checklists or tasks inside your notes—plan projects, track to-dos, and keep everything actionable in a single space.
- 🧠 **AI-Generated Note Titles**: Instantly generate relevant and concise titles for your notes using AI—keep your knowledge library organized without tedious manual editing.
- 🔄 **Full Undo/Redo Support in Notes**: Effortlessly undo or redo your latest note changes—never fear mistakes or accidental edits while collaborating or writing.
- 📝 **Enhanced Note Word/Character Counter**: Always know the size of your notes with built-in counters, making it easier to adhere to length guidelines for shared or published content.
- 🖊️ **Floating & Bubble Formatting Menus in Note Editor**: Access text formatting tools through both a floating menu and an intuitive bubble menu directly in the note editor—making rich text editing faster, more discoverable, and easier than ever.
- ✍️ **Rich Text Prompt Insertion**: A new setting allows prompts to be inserted directly into the chat box as fully-formatted rich text, preserving Markdown elements like headings, lists, and bold text for a more intuitive and visually consistent editing experience.
- 🌐 **Configurable Database URL**: WebUI now supports more flexible database configuration via new environment variables—making deployment and scaling simpler across various infrastructure setups.
- 🎛️ **Completely Frontend-Handled File Upload in Temporary Chats**: When using temporary chats, file extraction now occurs fully in your browser with zero files sent to the backend, further strengthening privacy and giving you instant feedback.
- 🔄 **Enhanced Banner and Chat Command Visibility**: Banner handling and command feedback in chat are now clearer and more contextually visible, making alerts, suggestions, and automation easier to spot and interact with for all users.
- 📱 **Mobile Experience Polished**: The "new chat" button is back in mobile, plus core navigation and input controls have been smoothed out for better usability on phones and tablets.
- 📄 **OpenDocument Text (.odt) Support**: Seamlessly upload and process .odt files from open-source office suites like LibreOffice and OpenOffice, expanding your ability to build knowledge from a wider range of document formats.
- 📑 **Enhanced Markdown Document Splitting**: Improve knowledge retrieval from Markdown files with a new header-aware splitting strategy. This method intelligently chunks documents based on their header structure, preserving the original context and hierarchy for more accurate and relevant RAG results.
- 📚 **Full Context Mode for Knowledge Bases**: When adding a knowledge base to a folder or custom model, you can now toggle full context mode for the entire knowledge base. This bypasses the usual chunking and retrieval process, making it perfect for leaner knowledge bases.
- 🕰️ **Configurable OAuth Timeout**: Enhance login reliability by setting a custom timeout (OAUTH_TIMEOUT) for all OAuth providers (Google, Microsoft, GitHub, OIDC), preventing authentication failures on slow or restricted networks.
- 🎨 **Accessibility & High-Contrast Theme Enhancements**: Major accessibility overhaul with significant updates to the high-contrast theme. Improved focus visibility, ARIA labels, and semantic HTML ensure core components like the chat interface and model selector are fully compliant and readable for visually impaired users.
- ↕️ **Resizable System Prompt Fields**: Conveniently resize system prompt input fields to comfortably view and edit lengthy or complex instructions, improving the user experience for advanced model configuration.
- 🔧 **Granular Update Check Control**: Gain finer control over outbound connections with the new ENABLE_VERSION_UPDATE_CHECK flag. This allows administrators to disable version update checks independently of the full OFFLINE_MODE, perfect for environments with restricted internet access that still need to download embedding models.
- 🗃️ **Configurable Qdrant Collection Prefix**: Enhance scalability by setting a custom QDRANT_COLLECTION_PREFIX. This allows multiple Open WebUI instances to share a single Qdrant cluster safely, ensuring complete data isolation between separate deployments without conflicts.
- ⚙️ **Improved Default Database Performance**: Enhanced out-of-the-box performance by setting smarter database connection pooling defaults, reducing API response times for users on non-SQLite databases without requiring manual configuration.
- 🔧 **Configurable Redis Key Prefix**: Added support for the REDIS_KEY_PREFIX environment variable, allowing multiple Open WebUI instances to share a Redis cluster with isolated key namespaces for improved multi-tenancy.
- ➡️ **Forward User Context to Reranker**: For advanced RAG integrations, user information (ID, name, email, role) can now be forwarded as HTTP headers to external reranking services, enabling personalized results or per-user access control.
- ⚙️ **PGVector Connection Pooling**: Enhance performance and stability for PGVector-based RAG by enabling and configuring the database connection pool. New environment variables allow fine-tuning of pool size, timeout, and overflow settings to handle high-concurrency workloads efficiently.
- ⚙️ **General Backend Refactoring**: Extensive refactoring delivers a faster, more reliable, and robust backend experience—improving chat speed, model management, and day-to-day reliability.
- 🌍 **Expanded & Improved Translations**: Enjoy a more accessible and intuitive experience thanks to comprehensive updates and enhancements for Chinese (Simplified and Traditional), German, French, Catalan, Irish, and Spanish translations throughout the interface.
### Fixed
- 🛠️ **Rich Text Input Stability and Performance**: Multiple improvements ensure faster, cleaner text editing and rendering with reduced glitches—especially supporting links, color picking, checkbox controls, and code blocks in notes and chats.
- 📷 **Seamless iPhone Image Uploads**: Effortlessly upload photos from iPhones and other devices using HEIC format—images are now correctly recognized and processed, eliminating compatibility issues.
- 🔄 **Audio MIME Type Registration**: Issues with audio file content types have been resolved, guaranteeing smoother, error-free uploads and playback for transcription or note attachments.
- 🖍️ **Input Commands Now Always Visible**: Input commands (like prompts or knowledge) dynamically adjust their height on small screens, ensuring nothing is cut off and every tool remains easily accessible.
- 🛑 **Tool Result Rendering**: Fixed display problems with tool results, providing fast, clear feedback when using external or internal tools.
- 🗂️ **Table Alignment in Markdown**: Markdown tables are now rendered and aligned as expected, keeping reports and documentation readable.
- 🖼️ **Thread Image Handling**: Fixed an issue where messages containing only images in threads weren't displayed correctly.
- 🗝️ **Note Access Control Security**: Tightened access control logic for notes to guarantee that shared or collaborative notes respect all user permissions and privacy safeguards.
- 🧾 **Ollama API Compatibility**: Fixed model parameter naming in the API to ensure uninterrupted compatibility for all Ollama endpoints.
- 🛠️ **Detection for 'text/html' Files**: Files loaded with docling/tika are now reliably detected as the correct type, improving knowledge ingestion and document parsing.
- 🔐 **OAuth Login Stability**: Resolved a critical OAuth bug that caused login failures on subsequent attempts after logging out. The user session is now completely cleared on logout, ensuring reliable and secure authentication across all supported providers (Google, Microsoft, GitHub, OIDC).
- 🚪 **OAuth Logout and Redirect Reliability**: The OAuth logout process has been made more robust. Logout requests now correctly use proxy environment variables, ensuring they succeed in corporate networks. Additionally, the custom WEBUI_AUTH_SIGNOUT_REDIRECT_URL is now properly respected for all OAuth/OIDC configurations, ensuring a seamless sign-out experience.
- 📜 **Banner Newline Rendering**: Banners now correctly render newline characters, ensuring that multi-line announcements and messages are displayed with their intended formatting.
- **Consistent Model Description Rendering**: Model descriptions now render Markdown correctly in the main chat interface, matching the formatting seen in the model selection dropdown for a consistent user experience.
- 🔄 **Offline Mode Update Check Display**: Corrected a UI bug where the "Checking for Updates..." message would display indefinitely when the application was set to offline mode.
- 🛠️ **Tool Result Encoding**: Fixed a bug where tool calls returning non-ASCII characters would fail, ensuring robust handling of international text and special characters in tool outputs.
## [0.6.15] - 2025-06-16 ## [0.6.15] - 2025-06-16
### Added ### Added

View File

@ -30,7 +30,7 @@ WORKDIR /app
RUN apk add --no-cache git RUN apk add --no-cache git
COPY package.json package-lock.json ./ COPY package.json package-lock.json ./
RUN npm ci RUN npm ci --force
COPY . . COPY . .
ENV APP_BUILD_HASH=${BUILD_HASH} ENV APP_BUILD_HASH=${BUILD_HASH}

53
LICENSE_HISTORY Normal file
View File

@ -0,0 +1,53 @@
All code and materials created before commit `60d84a3aae9802339705826e9095e272e3c83623` are subject to the following copyright and license:
Copyright (c) 2023-2025 Timothy Jaeryang Baek
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
All code and materials created before commit `a76068d69cd59568b920dfab85dc573dbbb8f131` are subject to the following copyright and license:
MIT License
Copyright (c) 2023 Timothy Jaeryang Baek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,2 +1,3 @@
export CORS_ALLOW_ORIGIN="http://localhost:5173"
PORT="${PORT:-8080}" PORT="${PORT:-8080}"
uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload

View File

@ -13,12 +13,15 @@ from urllib.parse import urlparse
import requests import requests
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func from sqlalchemy import JSON, Column, DateTime, Integer, func
from authlib.integrations.starlette_client import OAuth
from open_webui.env import ( from open_webui.env import (
DATA_DIR, DATA_DIR,
DATABASE_URL, DATABASE_URL,
ENV, ENV,
REDIS_URL, REDIS_URL,
REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT, REDIS_SENTINEL_PORT,
FRONTEND_BUILD_DIR, FRONTEND_BUILD_DIR,
@ -211,11 +214,16 @@ class PersistentConfig(Generic[T]):
class AppConfig: class AppConfig:
_state: dict[str, PersistentConfig] _state: dict[str, PersistentConfig]
_redis: Optional[redis.Redis] = None _redis: Optional[redis.Redis] = None
_redis_key_prefix: str
def __init__( def __init__(
self, redis_url: Optional[str] = None, redis_sentinels: Optional[list] = [] self,
redis_url: Optional[str] = None,
redis_sentinels: Optional[list] = [],
redis_key_prefix: str = "open-webui",
): ):
super().__setattr__("_state", {}) super().__setattr__("_state", {})
super().__setattr__("_redis_key_prefix", redis_key_prefix)
if redis_url: if redis_url:
super().__setattr__( super().__setattr__(
"_redis", "_redis",
@ -230,7 +238,7 @@ class AppConfig:
self._state[key].save() self._state[key].save()
if self._redis: if self._redis:
redis_key = f"open-webui:config:{key}" redis_key = f"{self._redis_key_prefix}:config:{key}"
self._redis.set(redis_key, json.dumps(self._state[key].value)) self._redis.set(redis_key, json.dumps(self._state[key].value))
def __getattr__(self, key): def __getattr__(self, key):
@ -239,7 +247,7 @@ class AppConfig:
# If Redis is available, check for an updated value # If Redis is available, check for an updated value
if self._redis: if self._redis:
redis_key = f"open-webui:config:{key}" redis_key = f"{self._redis_key_prefix}:config:{key}"
redis_value = self._redis.get(redis_key) redis_value = self._redis.get(redis_key)
if redis_value is not None: if redis_value is not None:
@ -431,6 +439,18 @@ OAUTH_SCOPES = PersistentConfig(
os.environ.get("OAUTH_SCOPES", "openid email profile"), os.environ.get("OAUTH_SCOPES", "openid email profile"),
) )
OAUTH_TIMEOUT = PersistentConfig(
"OAUTH_TIMEOUT",
"oauth.oidc.oauth_timeout",
os.environ.get("OAUTH_TIMEOUT", ""),
)
OAUTH_TOKEN_ENDPOINT_AUTH_METHOD = PersistentConfig(
"OAUTH_TOKEN_ENDPOINT_AUTH_METHOD",
"oauth.oidc.token_endpoint_auth_method",
os.environ.get("OAUTH_TOKEN_ENDPOINT_AUTH_METHOD", None),
)
OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig( OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig(
"OAUTH_CODE_CHALLENGE_METHOD", "OAUTH_CODE_CHALLENGE_METHOD",
"oauth.oidc.code_challenge_method", "oauth.oidc.code_challenge_method",
@ -534,13 +554,20 @@ def load_oauth_providers():
OAUTH_PROVIDERS.clear() OAUTH_PROVIDERS.clear()
if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value: if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
def google_oauth_register(client): def google_oauth_register(client: OAuth):
client.register( client.register(
name="google", name="google",
client_id=GOOGLE_CLIENT_ID.value, client_id=GOOGLE_CLIENT_ID.value,
client_secret=GOOGLE_CLIENT_SECRET.value, client_secret=GOOGLE_CLIENT_SECRET.value,
server_metadata_url="https://accounts.google.com/.well-known/openid-configuration", server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
client_kwargs={"scope": GOOGLE_OAUTH_SCOPE.value}, client_kwargs={
"scope": GOOGLE_OAUTH_SCOPE.value,
**(
{"timeout": int(OAUTH_TIMEOUT.value)}
if OAUTH_TIMEOUT.value
else {}
),
},
redirect_uri=GOOGLE_REDIRECT_URI.value, redirect_uri=GOOGLE_REDIRECT_URI.value,
) )
@ -555,7 +582,7 @@ def load_oauth_providers():
and MICROSOFT_CLIENT_TENANT_ID.value and MICROSOFT_CLIENT_TENANT_ID.value
): ):
def microsoft_oauth_register(client): def microsoft_oauth_register(client: OAuth):
client.register( client.register(
name="microsoft", name="microsoft",
client_id=MICROSOFT_CLIENT_ID.value, client_id=MICROSOFT_CLIENT_ID.value,
@ -563,6 +590,11 @@ def load_oauth_providers():
server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}", server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
client_kwargs={ client_kwargs={
"scope": MICROSOFT_OAUTH_SCOPE.value, "scope": MICROSOFT_OAUTH_SCOPE.value,
**(
{"timeout": int(OAUTH_TIMEOUT.value)}
if OAUTH_TIMEOUT.value
else {}
),
}, },
redirect_uri=MICROSOFT_REDIRECT_URI.value, redirect_uri=MICROSOFT_REDIRECT_URI.value,
) )
@ -575,7 +607,7 @@ def load_oauth_providers():
if GITHUB_CLIENT_ID.value and GITHUB_CLIENT_SECRET.value: if GITHUB_CLIENT_ID.value and GITHUB_CLIENT_SECRET.value:
def github_oauth_register(client): def github_oauth_register(client: OAuth):
client.register( client.register(
name="github", name="github",
client_id=GITHUB_CLIENT_ID.value, client_id=GITHUB_CLIENT_ID.value,
@ -584,7 +616,14 @@ def load_oauth_providers():
authorize_url="https://github.com/login/oauth/authorize", authorize_url="https://github.com/login/oauth/authorize",
api_base_url="https://api.github.com", api_base_url="https://api.github.com",
userinfo_endpoint="https://api.github.com/user", userinfo_endpoint="https://api.github.com/user",
client_kwargs={"scope": GITHUB_CLIENT_SCOPE.value}, client_kwargs={
"scope": GITHUB_CLIENT_SCOPE.value,
**(
{"timeout": int(OAUTH_TIMEOUT.value)}
if OAUTH_TIMEOUT.value
else {}
),
},
redirect_uri=GITHUB_CLIENT_REDIRECT_URI.value, redirect_uri=GITHUB_CLIENT_REDIRECT_URI.value,
) )
@ -600,9 +639,19 @@ def load_oauth_providers():
and OPENID_PROVIDER_URL.value and OPENID_PROVIDER_URL.value
): ):
def oidc_oauth_register(client): def oidc_oauth_register(client: OAuth):
client_kwargs = { client_kwargs = {
"scope": OAUTH_SCOPES.value, "scope": OAUTH_SCOPES.value,
**(
{
"token_endpoint_auth_method": OAUTH_TOKEN_ENDPOINT_AUTH_METHOD.value
}
if OAUTH_TOKEN_ENDPOINT_AUTH_METHOD.value
else {}
),
**(
{"timeout": int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {}
),
} }
if ( if (
@ -640,6 +689,17 @@ load_oauth_providers()
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve() STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()
try:
if STATIC_DIR.exists():
for item in STATIC_DIR.iterdir():
if item.is_file() or item.is_symlink():
try:
item.unlink()
except Exception as e:
pass
except Exception as e:
pass
for file_path in (FRONTEND_BUILD_DIR / "static").glob("**/*"): for file_path in (FRONTEND_BUILD_DIR / "static").glob("**/*"):
if file_path.is_file(): if file_path.is_file():
target_path = STATIC_DIR / file_path.relative_to( target_path = STATIC_DIR / file_path.relative_to(
@ -719,12 +779,6 @@ if CUSTOM_NAME:
pass pass
####################################
# LICENSE_KEY
####################################
LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
#################################### ####################################
# STORAGE PROVIDER # STORAGE PROVIDER
#################################### ####################################
@ -895,6 +949,18 @@ except Exception:
pass pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1" OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
# MODELS
####################################
ENABLE_BASE_MODELS_CACHE = PersistentConfig(
"ENABLE_BASE_MODELS_CACHE",
"models.base_models_cache",
os.environ.get("ENABLE_BASE_MODELS_CACHE", "False").lower() == "true",
)
#################################### ####################################
# TOOL_SERVERS # TOOL_SERVERS
#################################### ####################################
@ -1077,10 +1143,18 @@ USER_PERMISSIONS_CHAT_CONTROLS = (
os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true" os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true"
) )
USER_PERMISSIONS_CHAT_VALVES = (
os.environ.get("USER_PERMISSIONS_CHAT_VALVES", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_SYSTEM_PROMPT = ( USER_PERMISSIONS_CHAT_SYSTEM_PROMPT = (
os.environ.get("USER_PERMISSIONS_CHAT_SYSTEM_PROMPT", "True").lower() == "true" os.environ.get("USER_PERMISSIONS_CHAT_SYSTEM_PROMPT", "True").lower() == "true"
) )
USER_PERMISSIONS_CHAT_PARAMS = (
os.environ.get("USER_PERMISSIONS_CHAT_PARAMS", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_FILE_UPLOAD = ( USER_PERMISSIONS_CHAT_FILE_UPLOAD = (
os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true" os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true"
) )
@ -1166,7 +1240,9 @@ DEFAULT_USER_PERMISSIONS = {
}, },
"chat": { "chat": {
"controls": USER_PERMISSIONS_CHAT_CONTROLS, "controls": USER_PERMISSIONS_CHAT_CONTROLS,
"valves": USER_PERMISSIONS_CHAT_VALVES,
"system_prompt": USER_PERMISSIONS_CHAT_SYSTEM_PROMPT, "system_prompt": USER_PERMISSIONS_CHAT_SYSTEM_PROMPT,
"params": USER_PERMISSIONS_CHAT_PARAMS,
"file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD, "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
"delete": USER_PERMISSIONS_CHAT_DELETE, "delete": USER_PERMISSIONS_CHAT_DELETE,
"edit": USER_PERMISSIONS_CHAT_EDIT, "edit": USER_PERMISSIONS_CHAT_EDIT,
@ -1794,11 +1870,12 @@ MILVUS_IVF_FLAT_NLIST = int(os.environ.get("MILVUS_IVF_FLAT_NLIST", "128"))
QDRANT_URI = os.environ.get("QDRANT_URI", None) QDRANT_URI = os.environ.get("QDRANT_URI", None)
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None) QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true" QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "False").lower() == "true" QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true"
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334")) QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
ENABLE_QDRANT_MULTITENANCY_MODE = ( ENABLE_QDRANT_MULTITENANCY_MODE = (
os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "false").lower() == "true" os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true"
) )
QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui")
# OpenSearch # OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200") OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
@ -1837,6 +1914,45 @@ if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
"PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key." "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
) )
PGVECTOR_POOL_SIZE = os.environ.get("PGVECTOR_POOL_SIZE", None)
if PGVECTOR_POOL_SIZE != None:
try:
PGVECTOR_POOL_SIZE = int(PGVECTOR_POOL_SIZE)
except Exception:
PGVECTOR_POOL_SIZE = None
PGVECTOR_POOL_MAX_OVERFLOW = os.environ.get("PGVECTOR_POOL_MAX_OVERFLOW", 0)
if PGVECTOR_POOL_MAX_OVERFLOW == "":
PGVECTOR_POOL_MAX_OVERFLOW = 0
else:
try:
PGVECTOR_POOL_MAX_OVERFLOW = int(PGVECTOR_POOL_MAX_OVERFLOW)
except Exception:
PGVECTOR_POOL_MAX_OVERFLOW = 0
PGVECTOR_POOL_TIMEOUT = os.environ.get("PGVECTOR_POOL_TIMEOUT", 30)
if PGVECTOR_POOL_TIMEOUT == "":
PGVECTOR_POOL_TIMEOUT = 30
else:
try:
PGVECTOR_POOL_TIMEOUT = int(PGVECTOR_POOL_TIMEOUT)
except Exception:
PGVECTOR_POOL_TIMEOUT = 30
PGVECTOR_POOL_RECYCLE = os.environ.get("PGVECTOR_POOL_RECYCLE", 3600)
if PGVECTOR_POOL_RECYCLE == "":
PGVECTOR_POOL_RECYCLE = 3600
else:
try:
PGVECTOR_POOL_RECYCLE = int(PGVECTOR_POOL_RECYCLE)
except Exception:
PGVECTOR_POOL_RECYCLE = 3600
# Pinecone # Pinecone
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)

View File

@ -7,6 +7,7 @@ import sys
import shutil import shutil
from uuid import uuid4 from uuid import uuid4
from pathlib import Path from pathlib import Path
from cryptography.hazmat.primitives import serialization
import markdown import markdown
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
@ -199,6 +200,7 @@ CHANGELOG = changelog_json
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
#################################### ####################################
# ENABLE_FORWARD_USER_INFO_HEADERS # ENABLE_FORWARD_USER_INFO_HEADERS
#################################### ####################################
@ -266,21 +268,40 @@ else:
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
DATABASE_TYPE = os.environ.get("DATABASE_TYPE")
DATABASE_USER = os.environ.get("DATABASE_USER")
DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD")
DATABASE_CRED = ""
if DATABASE_USER:
DATABASE_CRED += f"{DATABASE_USER}"
if DATABASE_PASSWORD:
DATABASE_CRED += f":{DATABASE_PASSWORD}"
DB_VARS = {
"db_type": DATABASE_TYPE,
"db_cred": DATABASE_CRED,
"db_host": os.environ.get("DATABASE_HOST"),
"db_port": os.environ.get("DATABASE_PORT"),
"db_name": os.environ.get("DATABASE_NAME"),
}
if all(DB_VARS.values()):
DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}"
# Replace the postgres:// with postgresql:// # Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL: if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None) DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None)
DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0) DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None)
if DATABASE_POOL_SIZE == "": if DATABASE_POOL_SIZE != None:
DATABASE_POOL_SIZE = 0
else:
try: try:
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE) DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
except Exception: except Exception:
DATABASE_POOL_SIZE = 0 DATABASE_POOL_SIZE = None
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0) DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
@ -325,9 +346,19 @@ ENABLE_REALTIME_CHAT_SAVE = (
#################################### ####################################
REDIS_URL = os.environ.get("REDIS_URL", "") REDIS_URL = os.environ.get("REDIS_URL", "")
REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "") REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379") REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
# Maximum number of retries for Redis operations when using Sentinel fail-over
REDIS_SENTINEL_MAX_RETRY_COUNT = os.environ.get("REDIS_SENTINEL_MAX_RETRY_COUNT", "2")
try:
REDIS_SENTINEL_MAX_RETRY_COUNT = int(REDIS_SENTINEL_MAX_RETRY_COUNT)
if REDIS_SENTINEL_MAX_RETRY_COUNT < 1:
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
except ValueError:
REDIS_SENTINEL_MAX_RETRY_COUNT = 2
#################################### ####################################
# UVICORN WORKERS # UVICORN WORKERS
#################################### ####################################
@ -347,6 +378,10 @@ except ValueError:
#################################### ####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
ENABLE_SIGNUP_PASSWORD_CONFIRMATION = (
os.environ.get("ENABLE_SIGNUP_PASSWORD_CONFIRMATION", "False").lower() == "true"
)
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
) )
@ -396,14 +431,71 @@ WEBUI_AUTH_COOKIE_SECURE = (
if WEBUI_AUTH and WEBUI_SECRET_KEY == "": if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
ENABLE_COMPRESSION_MIDDLEWARE = (
os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true"
)
####################################
# LICENSE_KEY
####################################
LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
LICENSE_BLOB = None
LICENSE_BLOB_PATH = os.environ.get("LICENSE_BLOB_PATH", DATA_DIR / "l.data")
if LICENSE_BLOB_PATH and os.path.exists(LICENSE_BLOB_PATH):
with open(LICENSE_BLOB_PATH, "rb") as f:
LICENSE_BLOB = f.read()
LICENSE_PUBLIC_KEY = os.environ.get("LICENSE_PUBLIC_KEY", "")
pk = None
if LICENSE_PUBLIC_KEY:
pk = serialization.load_pem_public_key(
f"""
-----BEGIN PUBLIC KEY-----
{LICENSE_PUBLIC_KEY}
-----END PUBLIC KEY-----
""".encode(
"utf-8"
)
)
####################################
# MODELS
####################################
MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
if MODELS_CACHE_TTL == "":
MODELS_CACHE_TTL = None
else:
try:
MODELS_CACHE_TTL = int(MODELS_CACHE_TTL)
except Exception:
MODELS_CACHE_TTL = 1
####################################
# WEBSOCKET SUPPORT
####################################
ENABLE_WEBSOCKET_SUPPORT = ( ENABLE_WEBSOCKET_SUPPORT = (
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
) )
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60")
try:
WEBSOCKET_REDIS_LOCK_TIMEOUT = int(websocket_redis_lock_timeout)
except ValueError:
WEBSOCKET_REDIS_LOCK_TIMEOUT = 60
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
@ -506,11 +598,14 @@ else:
# OFFLINE_MODE # OFFLINE_MODE
#################################### ####################################
ENABLE_VERSION_UPDATE_CHECK = (
os.environ.get("ENABLE_VERSION_UPDATE_CHECK", "true").lower() == "true"
)
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE: if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1" os.environ["HF_HUB_OFFLINE"] = "1"
ENABLE_VERSION_UPDATE_CHECK = False
#################################### ####################################
# AUDIT LOGGING # AUDIT LOGGING
@ -519,6 +614,14 @@ if OFFLINE_MODE:
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log" AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file # Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB") AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# Comma separated list of logger names to use for audit logging
# Default is "uvicorn.access" which is the access log for Uvicorn
# You can add more logger names to this list if you want to capture more logs
AUDIT_UVICORN_LOGGER_NAMES = os.getenv(
"AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access"
).split(",")
# METADATA | REQUEST | REQUEST_RESPONSE # METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper() AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try: try:
@ -543,6 +646,9 @@ ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() ==
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get( OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317" "OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
) )
OTEL_EXPORTER_OTLP_INSECURE = (
os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true"
)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui") OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get( OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", "" "OTEL_RESOURCE_ATTRIBUTES", ""
@ -550,6 +656,14 @@ OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
OTEL_TRACES_SAMPLER = os.environ.get( OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on" "OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower() ).lower()
OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "")
OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "")
OTEL_OTLP_SPAN_EXPORTER = os.environ.get(
"OTEL_OTLP_SPAN_EXPORTER", "grpc"
).lower() # grpc or http
#################################### ####################################
# TOOLS/FUNCTIONS PIP OPTIONS # TOOLS/FUNCTIONS PIP OPTIONS

View File

@ -62,6 +62,9 @@ def handle_peewee_migration(DATABASE_URL):
except Exception as e: except Exception as e:
log.error(f"Failed to initialize the database connection: {e}") log.error(f"Failed to initialize the database connection: {e}")
log.warning(
"Hint: If your database password contains special characters, you may need to URL-encode it."
)
raise raise
finally: finally:
# Properly closing the database connection # Properly closing the database connection
@ -81,6 +84,7 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
) )
else: else:
if isinstance(DATABASE_POOL_SIZE, int):
if DATABASE_POOL_SIZE > 0: if DATABASE_POOL_SIZE > 0:
engine = create_engine( engine = create_engine(
SQLALCHEMY_DATABASE_URL, SQLALCHEMY_DATABASE_URL,
@ -95,6 +99,8 @@ else:
engine = create_engine( engine = create_engine(
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
) )
else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker( SessionLocal = sessionmaker(

View File

@ -36,7 +36,6 @@ from fastapi import (
applications, applications,
BackgroundTasks, BackgroundTasks,
) )
from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse from starlette.responses import Response, StreamingResponse
from starlette.datastructures import Headers
from open_webui.utils import logger from open_webui.utils import logger
@ -89,6 +89,7 @@ from open_webui.routers import (
from open_webui.routers.retrieval import ( from open_webui.routers.retrieval import (
get_embedding_function, get_embedding_function,
get_reranking_function,
get_ef, get_ef,
get_rf, get_rf,
) )
@ -101,7 +102,6 @@ from open_webui.models.users import UserModel, Users
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.config import ( from open_webui.config import (
LICENSE_KEY,
# Ollama # Ollama
ENABLE_OLLAMA_API, ENABLE_OLLAMA_API,
OLLAMA_BASE_URLS, OLLAMA_BASE_URLS,
@ -116,6 +116,8 @@ from open_webui.config import (
OPENAI_API_CONFIGS, OPENAI_API_CONFIGS,
# Direct Connections # Direct Connections
ENABLE_DIRECT_CONNECTIONS, ENABLE_DIRECT_CONNECTIONS,
# Model list
ENABLE_BASE_MODELS_CACHE,
# Thread pool size for FastAPI/AnyIO # Thread pool size for FastAPI/AnyIO
THREAD_POOL_SIZE, THREAD_POOL_SIZE,
# Tool Server Configs # Tool Server Configs
@ -392,10 +394,12 @@ from open_webui.config import (
reset_config, reset_config,
) )
from open_webui.env import ( from open_webui.env import (
LICENSE_KEY,
AUDIT_EXCLUDED_PATHS, AUDIT_EXCLUDED_PATHS,
AUDIT_LOG_LEVEL, AUDIT_LOG_LEVEL,
CHANGELOG, CHANGELOG,
REDIS_URL, REDIS_URL,
REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT, REDIS_SENTINEL_PORT,
GLOBAL_LOG_LEVEL, GLOBAL_LOG_LEVEL,
@ -408,13 +412,15 @@ from open_webui.env import (
WEBUI_SECRET_KEY, WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE, WEBUI_SESSION_COOKIE_SECURE,
ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL, WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
ENABLE_COMPRESSION_MIDDLEWARE,
ENABLE_WEBSOCKET_SUPPORT, ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL, BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START, RESET_CONFIG_ON_START,
OFFLINE_MODE, ENABLE_VERSION_UPDATE_CHECK,
ENABLE_OTEL, ENABLE_OTEL,
EXTERNAL_PWA_MANIFEST_URL, EXTERNAL_PWA_MANIFEST_URL,
AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_SESSION_SSL,
@ -449,7 +455,7 @@ from open_webui.utils.redis import get_redis_connection
from open_webui.tasks import ( from open_webui.tasks import (
redis_task_command_listener, redis_task_command_listener,
list_task_ids_by_chat_id, list_task_ids_by_item_id,
stop_task, stop_task,
list_tasks, list_tasks,
) # Import from tasks.py ) # Import from tasks.py
@ -533,6 +539,27 @@ async def lifespan(app: FastAPI):
asyncio.create_task(periodic_usage_pool_cleanup()) asyncio.create_task(periodic_usage_pool_cleanup())
if app.state.config.ENABLE_BASE_MODELS_CACHE:
await get_all_models(
Request(
# Creating a mock request object to pass to get_all_models
{
"type": "http",
"asgi.version": "3.0",
"asgi.spec_version": "2.0",
"method": "GET",
"path": "/internal",
"query_string": b"",
"headers": Headers({}).raw,
"client": ("127.0.0.1", 12345),
"server": ("127.0.0.1", 80),
"scheme": "http",
"app": app,
}
),
None,
)
yield yield
if hasattr(app.state, "redis_task_command_listener"): if hasattr(app.state, "redis_task_command_listener"):
@ -553,6 +580,7 @@ app.state.instance_id = None
app.state.config = AppConfig( app.state.config = AppConfig(
redis_url=REDIS_URL, redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT), redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
redis_key_prefix=REDIS_KEY_PREFIX,
) )
app.state.redis = None app.state.redis = None
@ -615,6 +643,15 @@ app.state.TOOL_SERVERS = []
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
########################################
#
# MODELS
#
########################################
app.state.config.ENABLE_BASE_MODELS_CACHE = ENABLE_BASE_MODELS_CACHE
app.state.BASE_MODELS = []
######################################## ########################################
# #
# WEBUI # WEBUI
@ -843,6 +880,7 @@ app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
app.state.EMBEDDING_FUNCTION = None app.state.EMBEDDING_FUNCTION = None
app.state.RERANKING_FUNCTION = None
app.state.ef = None app.state.ef = None
app.state.rf = None app.state.rf = None
@ -871,8 +909,8 @@ except Exception as e:
app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.ef, embedding_function=app.state.ef,
( url=(
app.state.config.RAG_OPENAI_API_BASE_URL app.state.config.RAG_OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai" if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else ( else (
@ -881,7 +919,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
else app.state.config.RAG_AZURE_OPENAI_BASE_URL else app.state.config.RAG_AZURE_OPENAI_BASE_URL
) )
), ),
( key=(
app.state.config.RAG_OPENAI_API_KEY app.state.config.RAG_OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai" if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else ( else (
@ -890,7 +928,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
else app.state.config.RAG_AZURE_OPENAI_API_KEY else app.state.config.RAG_AZURE_OPENAI_API_KEY
) )
), ),
app.state.config.RAG_EMBEDDING_BATCH_SIZE, embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE,
azure_api_version=( azure_api_version=(
app.state.config.RAG_AZURE_OPENAI_API_VERSION app.state.config.RAG_AZURE_OPENAI_API_VERSION
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
@ -898,6 +936,12 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
), ),
) )
app.state.RERANKING_FUNCTION = get_reranking_function(
app.state.config.RAG_RERANKING_ENGINE,
app.state.config.RAG_RERANKING_MODEL,
reranking_function=app.state.rf,
)
######################################## ########################################
# #
# CODE EXECUTION # CODE EXECUTION
@ -1072,7 +1116,9 @@ class RedirectMiddleware(BaseHTTPMiddleware):
# Add the middleware to the app # Add the middleware to the app
if ENABLE_COMPRESSION_MIDDLEWARE:
app.add_middleware(CompressMiddleware) app.add_middleware(CompressMiddleware)
app.add_middleware(RedirectMiddleware) app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware) app.add_middleware(SecurityHeadersMiddleware)
@ -1188,7 +1234,9 @@ if audit_level != AuditLevel.NONE:
@app.get("/api/models") @app.get("/api/models")
async def get_models(request: Request, user=Depends(get_verified_user)): async def get_models(
request: Request, refresh: bool = False, user=Depends(get_verified_user)
):
def get_filtered_models(models, user): def get_filtered_models(models, user):
filtered_models = [] filtered_models = []
for model in models: for model in models:
@ -1212,7 +1260,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models return filtered_models
all_models = await get_all_models(request, user=user) all_models = await get_all_models(request, refresh=refresh, user=user)
models = [] models = []
for model in all_models: for model in all_models:
@ -1249,7 +1297,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
models = get_filtered_models(models, user) models = get_filtered_models(models, user)
log.debug( log.debug(
f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}" f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}"
) )
return {"data": models} return {"data": models}
@ -1357,7 +1405,6 @@ async def chat_completion(
form_data, metadata, events = await process_chat_payload( form_data, metadata, events = await process_chat_payload(
request, form_data, user, metadata, model request, form_data, user, metadata, model
) )
except Exception as e: except Exception as e:
log.debug(f"Error processing chat payload: {e}") log.debug(f"Error processing chat payload: {e}")
if metadata.get("chat_id") and metadata.get("message_id"): if metadata.get("chat_id") and metadata.get("message_id"):
@ -1377,6 +1424,14 @@ async def chat_completion(
try: try:
response = await chat_completion_handler(request, form_data, user) response = await chat_completion_handler(request, form_data, user)
if metadata.get("chat_id") and metadata.get("message_id"):
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"model": model_id,
},
)
return await process_chat_response( return await process_chat_response(
request, response, form_data, user, metadata, model, events, tasks request, response, form_data, user, metadata, model, events, tasks
@ -1447,7 +1502,7 @@ async def stop_task_endpoint(
request: Request, task_id: str, user=Depends(get_verified_user) request: Request, task_id: str, user=Depends(get_verified_user)
): ):
try: try:
result = await stop_task(request, task_id) result = await stop_task(request.app.state.redis, task_id)
return result return result
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@ -1455,7 +1510,7 @@ async def stop_task_endpoint(
@app.get("/api/tasks") @app.get("/api/tasks")
async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)): async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
return {"tasks": await list_tasks(request)} return {"tasks": await list_tasks(request.app.state.redis)}
@app.get("/api/tasks/chat/{chat_id}") @app.get("/api/tasks/chat/{chat_id}")
@ -1466,9 +1521,9 @@ async def list_tasks_by_chat_id_endpoint(
if chat is None or chat.user_id != user.id: if chat is None or chat.user_id != user.id:
return {"task_ids": []} return {"task_ids": []}
task_ids = await list_task_ids_by_chat_id(request, chat_id) task_ids = await list_task_ids_by_item_id(request.app.state.redis, chat_id)
print(f"Task IDs for chat {chat_id}: {task_ids}") log.debug(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids} return {"task_ids": task_ids}
@ -1516,11 +1571,13 @@ async def get_app_config(request: Request):
"features": { "features": {
"auth": WEBUI_AUTH, "auth": WEBUI_AUTH,
"auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION,
"enable_ldap": app.state.config.ENABLE_LDAP, "enable_ldap": app.state.config.ENABLE_LDAP,
"enable_api_key": app.state.config.ENABLE_API_KEY, "enable_api_key": app.state.config.ENABLE_API_KEY,
"enable_signup": app.state.config.ENABLE_SIGNUP, "enable_signup": app.state.config.ENABLE_SIGNUP,
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM, "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT, "enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
"enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK,
**( **(
{ {
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS, "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
@ -1593,8 +1650,23 @@ async def get_app_config(request: Request):
else {} else {}
), ),
} }
if user is not None if user is not None and (user.role in ["admin", "user"])
else {
**(
{
"metadata": {
"login_footer": app.state.LICENSE_METADATA.get(
"login_footer", ""
),
"auth_logo_position": app.state.LICENSE_METADATA.get(
"auth_logo_position", ""
),
}
}
if app.state.LICENSE_METADATA
else {} else {}
)
}
), ),
} }
@ -1626,9 +1698,9 @@ async def get_app_version():
@app.get("/api/version/updates") @app.get("/api/version/updates")
async def get_app_latest_release_version(user=Depends(get_verified_user)): async def get_app_latest_release_version(user=Depends(get_verified_user)):
if OFFLINE_MODE: if not ENABLE_VERSION_UPDATE_CHECK:
log.debug( log.debug(
f"Offline mode is enabled, returning current version as latest version" f"Version update check is disabled, returning current version as latest version"
) )
return {"current": VERSION, "latest": VERSION} return {"current": VERSION, "latest": VERSION}
try: try:
@ -1709,7 +1781,6 @@ async def get_manifest_json():
"start_url": "/", "start_url": "/",
"display": "standalone", "display": "standalone",
"background_color": "#343541", "background_color": "#343541",
"orientation": "any",
"icons": [ "icons": [
{ {
"src": "/static/logo.png", "src": "/static/logo.png",

View File

@ -0,0 +1,23 @@
"""Update folder table data
Revision ID: d31026856c01
Revises: 9f0c9cd09105
Create Date: 2025-07-13 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "d31026856c01"
down_revision = "9f0c9cd09105"
branch_labels = None
depends_on = None
def upgrade():
op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True))
def downgrade():
op.drop_column("folder", "data")

View File

@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
from sqlalchemy.sql.expression import bindparam
#################### ####################
# Chat DB Schema # Chat DB Schema
@ -66,12 +67,14 @@ class ChatModel(BaseModel):
class ChatForm(BaseModel): class ChatForm(BaseModel):
chat: dict chat: dict
folder_id: Optional[str] = None
class ChatImportForm(ChatForm): class ChatImportForm(ChatForm):
meta: Optional[dict] = {} meta: Optional[dict] = {}
pinned: Optional[bool] = False pinned: Optional[bool] = False
folder_id: Optional[str] = None created_at: Optional[int] = None
updated_at: Optional[int] = None
class ChatTitleMessagesForm(BaseModel): class ChatTitleMessagesForm(BaseModel):
@ -118,6 +121,7 @@ class ChatTable:
else "New Chat" else "New Chat"
), ),
"chat": form_data.chat, "chat": form_data.chat,
"folder_id": form_data.folder_id,
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
@ -147,8 +151,16 @@ class ChatTable:
"meta": form_data.meta, "meta": form_data.meta,
"pinned": form_data.pinned, "pinned": form_data.pinned,
"folder_id": form_data.folder_id, "folder_id": form_data.folder_id,
"created_at": int(time.time()), "created_at": (
"updated_at": int(time.time()), form_data.created_at
if form_data.created_at
else int(time.time())
),
"updated_at": (
form_data.updated_at
if form_data.updated_at
else int(time.time())
),
} }
) )
@ -232,6 +244,10 @@ class ChatTable:
if chat is None: if chat is None:
return None return None
# Sanitize message content for null characters before upserting
if isinstance(message.get("content"), str):
message["content"] = message["content"].replace("\x00", "")
chat = chat.chat chat = chat.chat
history = chat.get("history", {}) history = chat.get("history", {})
@ -580,7 +596,7 @@ class ChatTable:
""" """
Filters chats based on a search query using Python, allowing pagination using skip and limit. Filters chats based on a search query using Python, allowing pagination using skip and limit.
""" """
search_text = search_text.lower().strip() search_text = search_text.replace("\u0000", "").lower().strip()
if not search_text: if not search_text:
return self.get_chat_list_by_user_id( return self.get_chat_list_by_user_id(
@ -614,21 +630,18 @@ class ChatTable:
dialect_name = db.bind.dialect.name dialect_name = db.bind.dialect.name
if dialect_name == "sqlite": if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching # SQLite case: using JSON1 extension for JSON searching
sqlite_content_sql = (
"EXISTS ("
" SELECT 1 "
" FROM json_each(Chat.chat, '$.messages') AS message "
" WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'"
")"
)
sqlite_content_clause = text(sqlite_content_sql)
query = query.filter( query = query.filter(
( or_(
Chat.title.ilike( Chat.title.ilike(bindparam("title_key")), sqlite_content_clause
f"%{search_text}%" ).params(title_key=f"%{search_text}%", content_key=search_text)
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_each(Chat.chat, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
) )
# Check if there are any tags to filter, it should have all the tags # Check if there are any tags to filter, it should have all the tags
@ -663,21 +676,19 @@ class ChatTable:
elif dialect_name == "postgresql": elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search # PostgreSQL relies on proper JSON query for search
postgres_content_sql = (
"EXISTS ("
" SELECT 1 "
" FROM json_array_elements(Chat.chat->'messages') AS message "
" WHERE LOWER(message->>'content') LIKE '%' || :content_key || '%'"
")"
)
postgres_content_clause = text(postgres_content_sql)
query = query.filter( query = query.filter(
( or_(
Chat.title.ilike( Chat.title.ilike(bindparam("title_key")),
f"%{search_text}%" postgres_content_clause,
) # Case-insensitive search in title ).params(title_key=f"%{search_text}%", content_key=search_text)
| text(
"""
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
) )
# Check if there are any tags to filter, it should have all the tags # Check if there are any tags to filter, it should have all the tags

View File

@ -29,6 +29,7 @@ class Folder(Base):
name = Column(Text) name = Column(Text)
items = Column(JSON, nullable=True) items = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
data = Column(JSON, nullable=True)
is_expanded = Column(Boolean, default=False) is_expanded = Column(Boolean, default=False)
created_at = Column(BigInteger) created_at = Column(BigInteger)
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
@ -41,6 +42,7 @@ class FolderModel(BaseModel):
name: str name: str
items: Optional[dict] = None items: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
data: Optional[dict] = None
is_expanded: bool = False is_expanded: bool = False
created_at: int created_at: int
updated_at: int updated_at: int
@ -55,12 +57,13 @@ class FolderModel(BaseModel):
class FolderForm(BaseModel): class FolderForm(BaseModel):
name: str name: str
data: Optional[dict] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class FolderTable: class FolderTable:
def insert_new_folder( def insert_new_folder(
self, user_id: str, name: str, parent_id: Optional[str] = None self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
with get_db() as db: with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
@ -68,7 +71,7 @@ class FolderTable:
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"name": name, **(form_data.model_dump(exclude_unset=True) or {}),
"parent_id": parent_id, "parent_id": parent_id,
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
@ -187,8 +190,8 @@ class FolderTable:
log.error(f"update_folder: {e}") log.error(f"update_folder: {e}")
return return
def update_folder_name_by_id_and_user_id( def update_folder_by_id_and_user_id(
self, id: str, user_id: str, name: str self, id: str, user_id: str, form_data: FolderForm
) -> Optional[FolderModel]: ) -> Optional[FolderModel]:
try: try:
with get_db() as db: with get_db() as db:
@ -197,16 +200,28 @@ class FolderTable:
if not folder: if not folder:
return None return None
form_data = form_data.model_dump(exclude_unset=True)
existing_folder = ( existing_folder = (
db.query(Folder) db.query(Folder)
.filter_by(name=name, parent_id=folder.parent_id, user_id=user_id) .filter_by(
name=form_data.get("name"),
parent_id=folder.parent_id,
user_id=user_id,
)
.first() .first()
) )
if existing_folder: if existing_folder and existing_folder.id != id:
return None return None
folder.name = name folder.name = form_data.get("name", folder.name)
if "data" in form_data:
folder.data = {
**(folder.data or {}),
**form_data["data"],
}
folder.updated_at = int(time.time()) folder.updated_at = int(time.time())
db.commit() db.commit()

View File

@ -83,10 +83,14 @@ class GroupForm(BaseModel):
permissions: Optional[dict] = None permissions: Optional[dict] = None
class GroupUpdateForm(GroupForm): class UserIdsForm(BaseModel):
user_ids: Optional[list[str]] = None user_ids: Optional[list[str]] = None
class GroupUpdateForm(GroupForm, UserIdsForm):
pass
class GroupTable: class GroupTable:
def insert_new_group( def insert_new_group(
self, user_id: str, form_data: GroupForm self, user_id: str, form_data: GroupForm
@ -275,5 +279,53 @@ class GroupTable:
log.exception(e) log.exception(e)
return False return False
def add_users_to_group(
self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
try:
with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
if not group:
return None
if not group.user_ids:
group.user_ids = []
for user_id in user_ids:
if user_id not in group.user_ids:
group.user_ids.append(user_id)
group.updated_at = int(time.time())
db.commit()
db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)
return None
def remove_users_from_group(
self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
try:
with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
if not group:
return None
if not group.user_ids:
return GroupModel.model_validate(group)
for user_id in user_ids:
if user_id in group.user_ids:
group.user_ids.remove(user_id)
group.updated_at = int(time.time())
db.commit()
db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)
return None
Groups = GroupTable() Groups = GroupTable()

View File

@ -71,9 +71,13 @@ class MemoriesTable:
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db: with get_db() as db:
try: try:
db.query(Memory).filter_by(id=id, user_id=user_id).update( memory = db.get(Memory, id)
{"content": content, "updated_at": int(time.time())} if not memory or memory.user_id != user_id:
) return None
memory.content = content
memory.updated_at = int(time.time())
db.commit() db.commit()
return self.get_memory_by_id(id) return self.get_memory_by_id(id)
except Exception: except Exception:
@ -127,7 +131,12 @@ class MemoriesTable:
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:
db.query(Memory).filter_by(id=id, user_id=user_id).delete() memory = db.get(Memory, id)
if not memory or memory.user_id != user_id:
return None
# Delete the memory
db.delete(memory)
db.commit() db.commit()
return True return True

View File

@ -269,5 +269,49 @@ class ModelsTable:
except Exception: except Exception:
return False return False
def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
try:
with get_db() as db:
# Get existing models
existing_models = db.query(Model).all()
existing_ids = {model.id for model in existing_models}
# Prepare a set of new model IDs
new_model_ids = {model.id for model in models}
# Update or insert models
for model in models:
if model.id in existing_ids:
db.query(Model).filter_by(id=model.id).update(
{
**model.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
else:
new_model = Model(
**{
**model.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
db.add(new_model)
# Remove models that are no longer present
for model in existing_models:
if model.id not in new_model_ids:
db.delete(model)
db.commit()
return [
ModelModel.model_validate(model) for model in db.query(Model).all()
]
except Exception as e:
log.exception(f"Error syncing models for user {user_id}: {e}")
return []
Models = ModelsTable() Models = ModelsTable()

View File

@ -62,6 +62,13 @@ class NoteForm(BaseModel):
access_control: Optional[dict] = None access_control: Optional[dict] = None
class NoteUpdateForm(BaseModel):
title: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
class NoteUserResponse(NoteModel): class NoteUserResponse(NoteModel):
user: Optional[UserResponse] = None user: Optional[UserResponse] = None
@ -110,16 +117,26 @@ class NoteTable:
note = db.query(Note).filter(Note.id == id).first() note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None return NoteModel.model_validate(note) if note else None
def update_note_by_id(self, id: str, form_data: NoteForm) -> Optional[NoteModel]: def update_note_by_id(
self, id: str, form_data: NoteUpdateForm
) -> Optional[NoteModel]:
with get_db() as db: with get_db() as db:
note = db.query(Note).filter(Note.id == id).first() note = db.query(Note).filter(Note.id == id).first()
if not note: if not note:
return None return None
note.title = form_data.title form_data = form_data.model_dump(exclude_unset=True)
note.data = form_data.data
note.meta = form_data.meta if "title" in form_data:
note.access_control = form_data.access_control note.title = form_data["title"]
if "data" in form_data:
note.data = {**note.data, **form_data["data"]}
if "meta" in form_data:
note.meta = {**note.meta, **form_data["meta"]}
if "access_control" in form_data:
note.access_control = form_data["access_control"]
note.updated_at = int(time.time_ns()) note.updated_at = int(time.time_ns())
db.commit() db.commit()

View File

@ -74,6 +74,18 @@ class UserListResponse(BaseModel):
total: int total: int
class UserInfoResponse(BaseModel):
id: str
name: str
email: str
role: str
class UserInfoListResponse(BaseModel):
users: list[UserInfoResponse]
total: int
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str id: str
name: str name: str

View File

@ -14,7 +14,7 @@ from langchain_community.document_loaders import (
TextLoader, TextLoader,
UnstructuredEPubLoader, UnstructuredEPubLoader,
UnstructuredExcelLoader, UnstructuredExcelLoader,
UnstructuredMarkdownLoader, UnstructuredODTLoader,
UnstructuredPowerPointLoader, UnstructuredPowerPointLoader,
UnstructuredRSTLoader, UnstructuredRSTLoader,
UnstructuredXMLLoader, UnstructuredXMLLoader,
@ -181,7 +181,7 @@ class DoclingLoader:
if lang.strip() if lang.strip()
] ]
endpoint = f"{self.url}/v1alpha/convert/file" endpoint = f"{self.url}/v1/convert/file"
r = requests.post(endpoint, files=files, data=params) r = requests.post(endpoint, files=files, data=params)
if r.ok: if r.ok:
@ -226,7 +226,10 @@ class Loader:
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool: def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or ( return file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0 file_content_type
and file_content_type.find("text/") >= 0
# Avoid text/html files being detected as text
and not file_content_type.find("html") >= 0
) )
def _get_loader(self, filename: str, file_content_type: str, file_path: str): def _get_loader(self, filename: str, file_content_type: str, file_path: str):
@ -389,6 +392,8 @@ class Loader:
loader = UnstructuredPowerPointLoader(file_path) loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg": elif file_ext == "msg":
loader = OutlookMessageLoader(file_path) loader = OutlookMessageLoader(file_path)
elif file_ext == "odt":
loader = UnstructuredODTLoader(file_path)
elif self._is_text_file(file_ext, file_content_type): elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
else: else:

View File

@ -507,6 +507,7 @@ class MistralLoader:
timeout=timeout, timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}, headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
raise_for_status=False, # We handle status codes manually raise_for_status=False, # We handle status codes manually
trust_env=True,
) as session: ) as session:
yield session yield session

View File

@ -1,8 +1,10 @@
import logging import logging
import requests import requests
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from urllib.parse import quote
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker from open_webui.retrieval.models.base_reranker import BaseReranker
@ -21,7 +23,9 @@ class ExternalReranker(BaseReranker):
self.url = url self.url = url
self.model = model self.model = model
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]: def predict(
self, sentences: List[Tuple[str, str]], user=None
) -> Optional[List[float]]:
query = sentences[0][0] query = sentences[0][0]
docs = [i[1] for i in sentences] docs = [i[1] for i in sentences]
@ -41,6 +45,16 @@ class ExternalReranker(BaseReranker):
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
**(
{
"X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
}, },
json=payload, json=payload,
) )

View File

@ -7,6 +7,7 @@ import hashlib
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import time import time
from urllib.parse import quote
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever from langchain_community.retrievers import BM25Retriever
@ -17,8 +18,11 @@ from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
from open_webui.models.files import Files from open_webui.models.files import Files
from open_webui.models.knowledge import Knowledges
from open_webui.models.notes import Notes
from open_webui.retrieval.vector.main import GetResult from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.access_control import has_access
from open_webui.env import ( from open_webui.env import (
@ -441,9 +445,20 @@ def get_embedding_function(
raise ValueError(f"Unknown embedding engine: {embedding_engine}") raise ValueError(f"Unknown embedding engine: {embedding_engine}")
def get_sources_from_files( def get_reranking_function(reranking_engine, reranking_model, reranking_function):
if reranking_function is None:
return None
if reranking_engine == "external":
return lambda sentences, user=None: reranking_function.predict(
sentences, user=user
)
else:
return lambda sentences, user=None: reranking_function.predict(sentences)
def get_sources_from_items(
request, request,
files, items,
queries, queries,
embedding_function, embedding_function,
k, k,
@ -453,36 +468,117 @@ def get_sources_from_files(
hybrid_bm25_weight, hybrid_bm25_weight,
hybrid_search, hybrid_search,
full_context=False, full_context=False,
user: Optional[UserModel] = None,
): ):
log.debug( log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}" f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
) )
extracted_collections = [] extracted_collections = []
relevant_contexts = [] query_results = []
for file in files: for item in items:
query_result = None
collection_names = []
context = None if item.get("type") == "text":
if file.get("docs"): # Raw Text
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL # Used during temporary chat file uploads
context = {
"documents": [[doc.get("content") for doc in file.get("docs")]], if item.get("file"):
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]], # if item has file data, use it
query_result = {
"documents": [
[item.get("file", {}).get("data", {}).get("content")]
],
"metadatas": [
[item.get("file", {}).get("data", {}).get("meta", {})]
],
} }
elif file.get("context") == "full": else:
# Manual Full Mode Toggle # Fallback to item content
context = { query_result = {
"documents": [[file.get("file").get("data", {}).get("content")]], "documents": [[item.get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]], "metadatas": [
[{"file_id": item.get("id"), "name": item.get("name")}]
],
} }
elif (
file.get("type") != "web_search" elif item.get("type") == "note":
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL # Note Attached
note = Notes.get_note_by_id(item.get("id"))
if note and (
user.role == "admin"
or note.user_id == user.id
or has_access(user.id, "read", note.access_control)
): ):
# BYPASS_EMBEDDING_AND_RETRIEVAL # User has access to the note
if file.get("type") == "collection": query_result = {
file_ids = file.get("data", {}).get("file_ids", []) "documents": [[note.data.get("content", {}).get("md", "")]],
"metadatas": [[{"file_id": note.id, "name": note.title}]],
}
elif item.get("type") == "file":
if (
item.get("context") == "full"
or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
):
if item.get("file", {}).get("data", {}).get("content", ""):
# Manual Full Mode Toggle
# Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content")
query_result = {
"documents": [
[item.get("file", {}).get("data", {}).get("content", "")]
],
"metadatas": [
[
{
"file_id": item.get("id"),
"name": item.get("name"),
**item.get("file")
.get("data", {})
.get("metadata", {}),
}
]
],
}
elif item.get("id"):
file_object = Files.get_file_by_id(item.get("id"))
if file_object:
query_result = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
{
"file_id": item.get("id"),
"name": file_object.filename,
"source": file_object.filename,
}
]
],
}
else:
# Fallback to collection names
if item.get("legacy"):
collection_names.append(f"{item['id']}")
else:
collection_names.append(f"file-{item['id']}")
elif item.get("type") == "collection":
if (
item.get("context") == "full"
or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
):
# Manual Full Mode Toggle for Collection
knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))
if knowledge_base and (
user.role == "admin"
or has_access(user.id, "read", knowledge_base.access_control)
):
file_ids = knowledge_base.data.get("file_ids", [])
documents = [] documents = []
metadatas = [] metadatas = []
@ -499,68 +595,46 @@ def get_sources_from_files(
} }
) )
context = { query_result = {
"documents": [documents], "documents": [documents],
"metadatas": [metadatas], "metadatas": [metadatas],
} }
else:
# Fallback to collection names
if item.get("legacy"):
collection_names = item.get("collection_names", [])
else:
collection_names.append(item["id"])
elif file.get("id"): elif item.get("docs"):
file_object = Files.get_file_by_id(file.get("id")) # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
if file_object: query_result = {
context = { "documents": [[doc.get("content") for doc in item.get("docs")]],
"documents": [[file_object.data.get("content", "")]], "metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
"metadatas": [
[
{
"file_id": file.get("id"),
"name": file_object.filename,
"source": file_object.filename,
} }
] elif item.get("collection_name"):
], # Direct Collection Name
} collection_names.append(item["collection_name"])
elif file.get("file").get("data"): elif item.get("collection_names"):
context = { # Collection Names List
"documents": [[file.get("file").get("data", {}).get("content")]], collection_names.extend(item["collection_names"])
"metadatas": [
[file.get("file").get("data", {}).get("metadata", {})]
],
}
else:
collection_names = []
if file.get("type") == "collection":
if file.get("legacy"):
collection_names = file.get("collection_names", [])
else:
collection_names.append(file["id"])
elif file.get("collection_name"):
collection_names.append(file["collection_name"])
elif file.get("id"):
if file.get("legacy"):
collection_names.append(f"{file['id']}")
else:
collection_names.append(f"file-{file['id']}")
# If query_result is None
# Fallback to collection names and vector search the collections
if query_result is None and collection_names:
collection_names = set(collection_names).difference(extracted_collections) collection_names = set(collection_names).difference(extracted_collections)
if not collection_names: if not collection_names:
log.debug(f"skipping {file} as it has already been extracted") log.debug(f"skipping {item} as it has already been extracted")
continue continue
try:
if full_context: if full_context:
try: query_result = get_all_items_from_collections(collection_names)
context = get_all_items_from_collections(collection_names)
except Exception as e:
log.exception(e)
else:
try:
context = None
if file.get("type") == "text":
context = file["content"]
else: else:
query_result = None # Initialize to None
if hybrid_search: if hybrid_search:
try: try:
context = query_collection_with_hybrid_search( query_result = query_collection_with_hybrid_search(
collection_names=collection_names, collection_names=collection_names,
queries=queries, queries=queries,
embedding_function=embedding_function, embedding_function=embedding_function,
@ -572,12 +646,12 @@ def get_sources_from_files(
) )
except Exception as e: except Exception as e:
log.debug( log.debug(
"Error when using hybrid search, using" "Error when using hybrid search, using non hybrid search as fallback."
" non hybrid search as fallback."
) )
if (not hybrid_search) or (context is None): # fallback to non-hybrid search
context = query_collection( if not hybrid_search and query_result is None:
query_result = query_collection(
collection_names=collection_names, collection_names=collection_names,
queries=queries, queries=queries,
embedding_function=embedding_function, embedding_function=embedding_function,
@ -588,24 +662,23 @@ def get_sources_from_files(
extracted_collections.extend(collection_names) extracted_collections.extend(collection_names)
if context: if query_result:
if "data" in file: if "data" in item:
del file["data"] del item["data"]
query_results.append({**query_result, "file": item})
relevant_contexts.append({**context, "file": file})
sources = [] sources = []
for context in relevant_contexts: for query_result in query_results:
try: try:
if "documents" in context: if "documents" in query_result:
if "metadatas" in context: if "metadatas" in query_result:
source = { source = {
"source": context["file"], "source": query_result["file"],
"document": context["documents"][0], "document": query_result["documents"][0],
"metadata": context["metadatas"][0], "metadata": query_result["metadatas"][0],
} }
if "distances" in context and context["distances"]: if "distances" in query_result and query_result["distances"]:
source["distances"] = context["distances"][0] source["distances"] = query_result["distances"][0]
sources.append(source) sources.append(source)
except Exception as e: except Exception as e:
@ -678,7 +751,7 @@ def generate_openai_batch_embeddings(
"Authorization": f"Bearer {key}", "Authorization": f"Bearer {key}",
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -727,7 +800,7 @@ def generate_azure_openai_batch_embeddings(
"api-key": key, "api-key": key,
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -777,7 +850,7 @@ def generate_ollama_batch_embeddings(
"Authorization": f"Bearer {key}", "Authorization": f"Bearer {key}",
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -874,7 +947,7 @@ class RerankCompressor(BaseDocumentCompressor):
reranking = self.reranking_function is not None reranking = self.reranking_function is not None
if reranking: if reranking:
scores = self.reranking_function.predict( scores = self.reranking_function(
[(query, doc.page_content) for doc in documents] [(query, doc.page_content) for doc in documents]
) )
else: else:

View File

@ -11,6 +11,8 @@ from open_webui.retrieval.vector.main import (
SearchResult, SearchResult,
GetResult, GetResult,
) )
from open_webui.retrieval.vector.utils import stringify_metadata
from open_webui.config import ( from open_webui.config import (
CHROMA_DATA_PATH, CHROMA_DATA_PATH,
CHROMA_HTTP_HOST, CHROMA_HTTP_HOST,
@ -144,7 +146,7 @@ class ChromaClient(VectorDBBase):
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
documents = [item["text"] for item in items] documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items] embeddings = [item["vector"] for item in items]
metadatas = [item["metadata"] for item in items] metadatas = [stringify_metadata(item["metadata"]) for item in items]
for batch in create_batches( for batch in create_batches(
api=self.client, api=self.client,
@ -164,7 +166,7 @@ class ChromaClient(VectorDBBase):
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
documents = [item["text"] for item in items] documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items] embeddings = [item["vector"] for item in items]
metadatas = [item["metadata"] for item in items] metadatas = [stringify_metadata(item["metadata"]) for item in items]
collection.upsert( collection.upsert(
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas

View File

@ -3,6 +3,8 @@ from pymilvus import FieldSchema, DataType
import json import json
import logging import logging
from typing import Optional from typing import Optional
from open_webui.retrieval.vector.utils import stringify_metadata
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
VectorDBBase, VectorDBBase,
VectorItem, VectorItem,
@ -311,7 +313,7 @@ class MilvusClient(VectorDBBase):
"id": item["id"], "id": item["id"],
"vector": item["vector"], "vector": item["vector"],
"data": {"text": item["text"]}, "data": {"text": item["text"]},
"metadata": item["metadata"], "metadata": stringify_metadata(item["metadata"]),
} }
for item in items for item in items
], ],
@ -347,7 +349,7 @@ class MilvusClient(VectorDBBase):
"id": item["id"], "id": item["id"],
"vector": item["vector"], "vector": item["vector"],
"data": {"text": item["text"]}, "data": {"text": item["text"]},
"metadata": item["metadata"], "metadata": stringify_metadata(item["metadata"]),
} }
for item in items for item in items
], ],

View File

@ -157,10 +157,10 @@ class OpenSearchClient(VectorDBBase):
for field, value in filter.items(): for field, value in filter.items():
query_body["query"]["bool"]["filter"].append( query_body["query"]["bool"]["filter"].append(
{"match": {"metadata." + str(field): value}} {"term": {"metadata." + str(field) + ".keyword": value}}
) )
size = limit if limit else 10 size = limit if limit else 10000
try: try:
result = self.client.search( result = self.client.search(
@ -206,6 +206,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch for item in batch
] ]
bulk(self.client, actions) bulk(self.client, actions)
self.client.indices.refresh(self._get_index_name(collection_name))
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists( self._create_index_if_not_exists(
@ -228,6 +229,7 @@ class OpenSearchClient(VectorDBBase):
for item in batch for item in batch
] ]
bulk(self.client, actions) bulk(self.client, actions)
self.client.indices.refresh(self._get_index_name(collection_name))
def delete( def delete(
self, self,
@ -251,11 +253,12 @@ class OpenSearchClient(VectorDBBase):
} }
for field, value in filter.items(): for field, value in filter.items():
query_body["query"]["bool"]["filter"].append( query_body["query"]["bool"]["filter"].append(
{"match": {"metadata." + str(field): value}} {"term": {"metadata." + str(field) + ".keyword": value}}
) )
self.client.delete_by_query( self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body index=self._get_index_name(collection_name), body=query_body
) )
self.client.indices.refresh(self._get_index_name(collection_name))
def reset(self): def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*") indices = self.client.indices.get(index=f"{self.index_prefix}_*")

View File

@ -18,7 +18,7 @@ from sqlalchemy import (
values, values,
) )
from sqlalchemy.sql import true from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array from sqlalchemy.dialects.postgresql import JSONB, array
@ -26,6 +26,8 @@ from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.utils import stringify_metadata
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
VectorDBBase, VectorDBBase,
VectorItem, VectorItem,
@ -37,6 +39,10 @@ from open_webui.config import (
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
PGVECTOR_PGCRYPTO, PGVECTOR_PGCRYPTO,
PGVECTOR_PGCRYPTO_KEY, PGVECTOR_PGCRYPTO_KEY,
PGVECTOR_POOL_SIZE,
PGVECTOR_POOL_MAX_OVERFLOW,
PGVECTOR_POOL_TIMEOUT,
PGVECTOR_POOL_RECYCLE,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -79,10 +85,25 @@ class PgvectorClient(VectorDBBase):
from open_webui.internal.db import Session from open_webui.internal.db import Session
self.session = Session self.session = Session
else:
if isinstance(PGVECTOR_POOL_SIZE, int):
if PGVECTOR_POOL_SIZE > 0:
engine = create_engine(
PGVECTOR_DB_URL,
pool_size=PGVECTOR_POOL_SIZE,
max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
pool_timeout=PGVECTOR_POOL_TIMEOUT,
pool_recycle=PGVECTOR_POOL_RECYCLE,
pool_pre_ping=True,
poolclass=QueuePool,
)
else: else:
engine = create_engine( engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
) )
else:
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
SessionLocal = sessionmaker( SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
) )
@ -216,7 +237,7 @@ class PgvectorClient(VectorDBBase):
vector=vector, vector=vector,
collection_name=collection_name, collection_name=collection_name,
text=item["text"], text=item["text"],
vmetadata=item["metadata"], vmetadata=stringify_metadata(item["metadata"]),
) )
new_items.append(new_chunk) new_items.append(new_chunk)
self.session.bulk_save_objects(new_items) self.session.bulk_save_objects(new_items)
@ -273,7 +294,7 @@ class PgvectorClient(VectorDBBase):
if existing: if existing:
existing.vector = vector existing.vector = vector
existing.text = item["text"] existing.text = item["text"]
existing.vmetadata = item["metadata"] existing.vmetadata = stringify_metadata(item["metadata"])
existing.collection_name = ( existing.collection_name = (
collection_name # Update collection_name if necessary collection_name # Update collection_name if necessary
) )
@ -283,7 +304,7 @@ class PgvectorClient(VectorDBBase):
vector=vector, vector=vector,
collection_name=collection_name, collection_name=collection_name,
text=item["text"], text=item["text"],
vmetadata=item["metadata"], vmetadata=stringify_metadata(item["metadata"]),
) )
self.session.add(new_chunk) self.session.add(new_chunk)
self.session.commit() self.session.commit()

View File

@ -18,6 +18,7 @@ from open_webui.config import (
QDRANT_ON_DISK, QDRANT_ON_DISK,
QDRANT_GRPC_PORT, QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC, QDRANT_PREFER_GRPC,
QDRANT_COLLECTION_PREFIX,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -29,7 +30,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase): class QdrantClient(VectorDBBase):
def __init__(self): def __init__(self):
self.collection_prefix = "open-webui" self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK self.QDRANT_ON_DISK = QDRANT_ON_DISK
@ -86,6 +87,25 @@ class QdrantClient(VectorDBBase):
), ),
) )
# Create payload indexes for efficient filtering
self.client.create_payload_index(
collection_name=collection_name_with_prefix,
field_name="metadata.hash",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=False,
on_disk=self.QDRANT_ON_DISK,
),
)
self.client.create_payload_index(
collection_name=collection_name_with_prefix,
field_name="metadata.file_id",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=False,
on_disk=self.QDRANT_ON_DISK,
),
)
log.info(f"collection {collection_name_with_prefix} successfully created!") log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension): def _create_collection_if_not_exists(self, collection_name, dimension):

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Optional, Tuple from typing import Optional, Tuple, List, Dict, Any
from urllib.parse import urlparse from urllib.parse import urlparse
import grpc import grpc
@ -9,6 +9,7 @@ from open_webui.config import (
QDRANT_ON_DISK, QDRANT_ON_DISK,
QDRANT_PREFER_GRPC, QDRANT_PREFER_GRPC,
QDRANT_URI, QDRANT_URI,
QDRANT_COLLECTION_PREFIX,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import ( from open_webui.retrieval.vector.main import (
@ -23,14 +24,28 @@ from qdrant_client.http.models import PointStruct
from qdrant_client.models import models from qdrant_client.models import models
NO_LIMIT = 999999999 NO_LIMIT = 999999999
TENANT_ID_FIELD = "tenant_id"
DEFAULT_DIMENSION = 384
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
return models.FieldCondition(
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
)
def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
return models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
class QdrantClient(VectorDBBase): class QdrantClient(VectorDBBase):
def __init__(self): def __init__(self):
self.collection_prefix = "open-webui" self.collection_prefix = QDRANT_COLLECTION_PREFIX
self.QDRANT_URI = QDRANT_URI self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK self.QDRANT_ON_DISK = QDRANT_ON_DISK
@ -38,24 +53,26 @@ class QdrantClient(VectorDBBase):
self.GRPC_PORT = QDRANT_GRPC_PORT self.GRPC_PORT = QDRANT_GRPC_PORT
if not self.QDRANT_URI: if not self.QDRANT_URI:
self.client = None raise ValueError(
return "QDRANT_URI is not set. Please configure it in the environment variables."
)
# Unified handling for either scheme # Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI) parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port http_port = parsed.port or 6333 # default REST port
if self.PREFER_GRPC: self.client = (
self.client = Qclient( Qclient(
host=host, host=host,
port=http_port, port=http_port,
grpc_port=self.GRPC_PORT, grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC, prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY, api_key=self.QDRANT_API_KEY,
) )
else: if self.PREFER_GRPC
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
)
# Main collection types for multi-tenancy # Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
@ -65,23 +82,13 @@ class QdrantClient(VectorDBBase):
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based" self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
def _result_to_get_result(self, points) -> GetResult: def _result_to_get_result(self, points) -> GetResult:
ids = [] ids, documents, metadatas = [], [], []
documents = []
metadatas = []
for point in points: for point in points:
payload = point.payload payload = point.payload
ids.append(point.id) ids.append(point.id)
documents.append(payload["text"]) documents.append(payload["text"])
metadatas.append(payload["metadata"]) metadatas.append(payload["metadata"])
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]: def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
""" """
@ -113,96 +120,12 @@ class QdrantClient(VectorDBBase):
else: else:
return self.KNOWLEDGE_COLLECTION, tenant_id return self.KNOWLEDGE_COLLECTION, tenant_id
def _extract_error_message(self, exception): def _create_multi_tenant_collection(
""" self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
Extract error message from either HTTP or gRPC exceptions
Returns:
tuple: (status_code, error_message)
"""
# Check if it's an HTTP exception
if isinstance(exception, UnexpectedResponse):
try:
error_data = exception.structured()
error_msg = error_data.get("status", {}).get("error", "")
return exception.status_code, error_msg
except Exception as inner_e:
log.error(f"Failed to parse HTTP error: {inner_e}")
return exception.status_code, str(exception)
# Check if it's a gRPC exception
elif isinstance(exception, grpc.RpcError):
# Extract status code from gRPC error
status_code = None
if hasattr(exception, "code") and callable(exception.code):
status_code = exception.code().value[0]
# Extract error message
error_msg = str(exception)
if "details =" in error_msg:
# Parse the details line which contains the actual error message
try:
details_line = [
line.strip()
for line in error_msg.split("\n")
if "details =" in line
][0]
error_msg = details_line.split("details =")[1].strip(' "')
except (IndexError, AttributeError):
# Fall back to full message if parsing fails
pass
return status_code, error_msg
# For any other type of exception
return None, str(exception)
def _is_collection_not_found_error(self, exception):
"""
Check if the exception is due to collection not found, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# HTTP error (404)
if (
status_code == 404
and "Collection" in error_msg
and "doesn't exist" in error_msg
):
return True
# gRPC error (NOT_FOUND status)
if (
isinstance(exception, grpc.RpcError)
and exception.code() == grpc.StatusCode.NOT_FOUND
):
return True
return False
def _is_dimension_mismatch_error(self, exception):
"""
Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# Common patterns in both HTTP and gRPC
return (
"Vector dimension error" in error_msg
or "dimensions mismatch" in error_msg
or "invalid vector size" in error_msg
)
def _create_multi_tenant_collection_if_not_exists(
self, mt_collection_name: str, dimension: int = 384
): ):
""" """
Creates a collection with multi-tenancy configuration if it doesn't exist. Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
""" """
try:
# Try to create the collection directly - will fail if it already exists
self.client.create_collection( self.client.create_collection(
collection_name=mt_collection_name, collection_name=mt_collection_name,
vectors_config=models.VectorParams( vectors_config=models.VectorParams(
@ -210,46 +133,34 @@ class QdrantClient(VectorDBBase):
distance=models.Distance.COSINE, distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK, on_disk=self.QDRANT_ON_DISK,
), ),
hnsw_config=models.HnswConfigDiff( )
payload_m=16, # Enable per-tenant indexing log.info(
m=0, f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
on_disk=self.QDRANT_ON_DISK,
),
) )
# Create tenant ID payload index
self.client.create_payload_index( self.client.create_payload_index(
collection_name=mt_collection_name, collection_name=mt_collection_name,
field_name="tenant_id", field_name=TENANT_ID_FIELD,
field_schema=models.KeywordIndexParams( field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD, type=models.KeywordIndexType.KEYWORD,
is_tenant=True, is_tenant=True,
on_disk=self.QDRANT_ON_DISK, on_disk=self.QDRANT_ON_DISK,
), ),
wait=True,
) )
log.info( for field in ("metadata.hash", "metadata.file_id"):
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!" self.client.create_payload_index(
collection_name=mt_collection_name,
field_name=field,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
on_disk=self.QDRANT_ON_DISK,
),
) )
except (UnexpectedResponse, grpc.RpcError) as e:
# Check for the specific error indicating collection already exists
status_code, error_msg = self._extract_error_message(e)
# HTTP status code 409 or gRPC ALREADY_EXISTS def _create_points(
if (isinstance(e, UnexpectedResponse) and status_code == 409) or ( self, items: List[VectorItem], tenant_id: str
isinstance(e, grpc.RpcError) ) -> List[PointStruct]:
and e.code() == grpc.StatusCode.ALREADY_EXISTS
):
if "already exists" in error_msg:
log.debug(f"Collection {mt_collection_name} already exists")
return
# If it's not an already exists error, re-raise
raise e
except Exception as e:
raise e
def _create_points(self, items: list[VectorItem], tenant_id: str):
""" """
Create point structs from vector items with tenant ID. Create point structs from vector items with tenant ID.
""" """
@ -260,56 +171,42 @@ class QdrantClient(VectorDBBase):
payload={ payload={
"text": item["text"], "text": item["text"],
"metadata": item["metadata"], "metadata": item["metadata"],
"tenant_id": tenant_id, TENANT_ID_FIELD: tenant_id,
}, },
) )
for item in items for item in items
] ]
def _ensure_collection(
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
):
"""
Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
"""
if not self.client.collection_exists(collection_name=mt_collection_name):
self._create_multi_tenant_collection(mt_collection_name, dimension)
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
""" """
Check if a logical collection exists by checking for any points with the tenant ID. Check if a logical collection exists by checking for any points with the tenant ID.
""" """
if not self.client: if not self.client:
return False return False
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
# Create tenant filter return False
tenant_filter = models.FieldCondition( tenant_filter = _tenant_filter(tenant_id)
key="tenant_id", match=models.MatchValue(value=tenant_id) count_result = self.client.count(
)
try:
# Try directly querying - most of the time collection should exist
response = self.client.query_points(
collection_name=mt_collection, collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]), count_filter=models.Filter(must=[tenant_filter]),
limit=1,
) )
return count_result.count > 0
# Collection exists with this tenant ID if there are points
return len(response.points) > 0
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(f"Collection {mt_collection} doesn't exist")
return False
else:
# For other API errors, log and return False
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error: {error_msg}")
return False
except Exception as e:
# For any other errors, log and return False
log.debug(f"Error checking collection {mt_collection}: {e}")
return False
def delete( def delete(
self, self,
collection_name: str, collection_name: str,
ids: Optional[list[str]] = None, ids: Optional[List[str]] = None,
filter: Optional[dict] = None, filter: Optional[Dict[str, Any]] = None,
): ):
""" """
Delete vectors by ID or filter from a collection with tenant isolation. Delete vectors by ID or filter from a collection with tenant isolation.
@ -317,189 +214,76 @@ class QdrantClient(VectorDBBase):
if not self.client: if not self.client:
return None return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
return None
# Create tenant filter must_conditions = [_tenant_filter(tenant_id)]
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
must_conditions = [tenant_filter]
should_conditions = [] should_conditions = []
if ids: if ids:
for id_value in ids: should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
should_conditions.append(
models.FieldCondition(
key="metadata.id",
match=models.MatchValue(value=id_value),
),
)
elif filter: elif filter:
for key, value in filter.items(): must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
must_conditions.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
)
try: return self.client.delete(
# Try to delete directly - most of the time collection should exist
update_result = self.client.delete(
collection_name=mt_collection, collection_name=mt_collection,
points_selector=models.FilterSelector( points_selector=models.FilterSelector(
filter=models.Filter(must=must_conditions, should=should_conditions) filter=models.Filter(must=must_conditions, should=should_conditions)
), ),
) )
return update_result
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, nothing to delete"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, re-raise
raise
def search( def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int self, collection_name: str, vectors: List[List[float | int]], limit: int
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
""" """
Search for the nearest neighbor items based on the vectors with tenant isolation. Search for the nearest neighbor items based on the vectors with tenant isolation.
""" """
if not self.client: if not self.client or not vectors:
return None
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
return None return None
# Map to multi-tenant collection and tenant ID tenant_filter = _tenant_filter(tenant_id)
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get the vector dimension from the query vector
dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
try:
# Try the search operation directly - most of the time collection should exist
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
# Ensure vector dimensions match the collection
collection_dim = self.client.get_collection(
mt_collection
).config.params.vectors.size
if collection_dim != dimension:
if collection_dim < dimension:
vectors = [vector[:collection_dim] for vector in vectors]
else:
vectors = [
vector + [0] * (collection_dim - dimension)
for vector in vectors
]
# Search with tenant filter
prefetch_query = models.Prefetch(
filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT,
)
query_response = self.client.query_points( query_response = self.client.query_points(
collection_name=mt_collection, collection_name=mt_collection,
query=vectors[0], query=vectors[0],
prefetch=prefetch_query,
limit=limit, limit=limit,
query_filter=models.Filter(must=[tenant_filter]),
) )
get_result = self._result_to_get_result(query_response.points) get_result = self._result_to_get_result(query_response.points)
return SearchResult( return SearchResult(
ids=get_result.ids, ids=get_result.ids,
documents=get_result.documents, documents=get_result.documents,
metadatas=get_result.metadatas, metadatas=get_result.metadatas,
# qdrant distance is [-1, 1], normalize to [0, 1] distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
distances=[
[(point.score + 1.0) / 2.0 for point in query_response.points]
],
) )
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, search returns None"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during search: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and return None
log.exception(f"Error searching collection '{collection_name}': {e}")
return None
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
):
""" """
Query points with filters and tenant isolation. Query points with filters and tenant isolation.
""" """
if not self.client: if not self.client:
return None return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
# Set default limit if not provided log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
return None
if limit is None: if limit is None:
limit = NO_LIMIT limit = NO_LIMIT
tenant_filter = _tenant_filter(tenant_id)
# Create tenant filter field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
# Create metadata filters
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
# Combine tenant filter with metadata filters
combined_filter = models.Filter(must=[tenant_filter, *field_conditions]) combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
try:
# Try the query directly - most of the time collection should exist
points = self.client.query_points( points = self.client.query_points(
collection_name=mt_collection, collection_name=mt_collection,
query_filter=combined_filter, query_filter=combined_filter,
limit=limit, limit=limit,
) )
return self._result_to_get_result(points.points) return self._result_to_get_result(points.points)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, query returns None"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during query: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and re-raise
log.exception(f"Error querying collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
""" """
@ -507,169 +291,36 @@ class QdrantClient(VectorDBBase):
""" """
if not self.client: if not self.client:
return None return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
# Create tenant filter log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
tenant_filter = models.FieldCondition( return None
key="tenant_id", match=models.MatchValue(value=tenant_id) tenant_filter = _tenant_filter(tenant_id)
)
try:
# Try to get points directly - most of the time collection should exist
points = self.client.query_points( points = self.client.query_points(
collection_name=mt_collection, collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]), query_filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT, limit=NO_LIMIT,
) )
return self._result_to_get_result(points.points) return self._result_to_get_result(points.points)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during get: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and return None
log.exception(f"Error getting collection '{collection_name}': {e}")
return None
def _handle_operation_with_error_retry( def upsert(self, collection_name: str, items: List[VectorItem]):
self, operation_name, mt_collection, points, dimension
):
"""
Private helper to handle common error cases for insert and upsert operations.
Args:
operation_name: 'insert' or 'upsert'
mt_collection: The multi-tenant collection name
points: The vector points to insert/upsert
dimension: The dimension of the vectors
Returns:
The operation result (for upsert) or None (for insert)
"""
try:
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
except (UnexpectedResponse, grpc.RpcError) as e:
# Handle collection not found
if self._is_collection_not_found_error(e):
log.info(
f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
)
# Create collection with correct dimensions from our vectors
self._create_multi_tenant_collection_if_not_exists(
mt_collection_name=mt_collection, dimension=dimension
)
# Try operation again - no need for dimension adjustment since we just created with correct dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
# Handle dimension mismatch
elif self._is_dimension_mismatch_error(e):
# For dimension errors, the collection must exist, so get its configuration
mt_collection_info = self.client.get_collection(mt_collection)
existing_size = mt_collection_info.config.params.vectors.size
log.info(
f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
)
if existing_size < dimension:
# Truncate vectors to fit
log.info(
f"Truncating vectors from {dimension} to {existing_size} dimensions"
)
points = [
PointStruct(
id=point.id,
vector=point.vector[:existing_size],
payload=point.payload,
)
for point in points
]
elif existing_size > dimension:
# Pad vectors with zeros
log.info(
f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
)
points = [
PointStruct(
id=point.id,
vector=point.vector
+ [0] * (existing_size - len(point.vector)),
payload=point.payload,
)
for point in points
]
# Try operation again with adjusted dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
else:
# Not a known error we can handle, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unhandled Qdrant error: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, re-raise
raise
def insert(self, collection_name: str, items: list[VectorItem]):
"""
Insert items with tenant ID.
"""
if not self.client or not items:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"insert", mt_collection, points, dimension
)
def upsert(self, collection_name: str, items: list[VectorItem]):
""" """
Upsert items with tenant ID. Upsert items with tenant ID.
""" """
if not self.client or not items: if not self.client or not items:
return None return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
dimension = len(items[0]["vector"])
# Get dimensions from the actual vectors self._ensure_collection(mt_collection, dimension)
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id) points = self._create_points(items, tenant_id)
self.client.upload_points(mt_collection, points)
return None
# Handle the operation with error retry def insert(self, collection_name: str, items: List[VectorItem]):
return self._handle_operation_with_error_retry( """
"upsert", mt_collection, points, dimension Insert items with tenant ID.
) """
return self.upsert(collection_name, items)
def reset(self): def reset(self):
""" """
@ -677,11 +328,9 @@ class QdrantClient(VectorDBBase):
""" """
if not self.client: if not self.client:
return None return None
for collection in self.client.get_collections().collections:
collection_names = self.client.get_collections().collections if collection.name.startswith(self.collection_prefix):
for collection_name in collection_names: self.client.delete_collection(collection_name=collection.name)
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)
def delete_collection(self, collection_name: str): def delete_collection(self, collection_name: str):
""" """
@ -689,24 +338,13 @@ class QdrantClient(VectorDBBase):
""" """
if not self.client: if not self.client:
return None return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection):
tenant_filter = models.FieldCondition( log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
key="tenant_id", match=models.MatchValue(value=tenant_id) return None
) self.client.delete(
field_conditions = [tenant_filter]
update_result = self.client.delete(
collection_name=mt_collection, collection_name=mt_collection,
points_selector=models.FilterSelector( points_selector=models.FilterSelector(
filter=models.Filter(must=field_conditions) filter=models.Filter(must=[_tenant_filter(tenant_id)])
), ),
) )
if self.client.get_collection(mt_collection).points_count == 0:
self.client.delete_collection(mt_collection)
return update_result

View File

@ -0,0 +1,745 @@
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
GetResult,
SearchResult,
)
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
from open_webui.env import SRC_LOG_LEVELS
from typing import List, Optional, Dict, Any, Union
import logging
import boto3
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class S3VectorClient(VectorDBBase):
"""
AWS S3 Vector integration for Open WebUI Knowledge.
"""
def __init__(self):
    """Set up the S3 Vector client from environment configuration.

    Missing configuration or a failed boto3 client construction leaves
    the instance in a disabled state (``self.client is None``) rather
    than raising, so the rest of the app keeps running.
    """
    self.bucket_name = S3_VECTOR_BUCKET_NAME
    self.region = S3_VECTOR_REGION

    # Warn (don't raise) on missing configuration.
    if not self.bucket_name:
        log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
    if not self.region:
        log.warning("S3_VECTOR_REGION not set - S3Vector will not work")

    self.client = None
    if self.bucket_name and self.region:
        try:
            self.client = boto3.client("s3vectors", region_name=self.region)
            log.info(
                f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
            )
        except Exception as e:
            # Keep the disabled state; callers check `self.client` before use.
            log.error(f"Failed to initialize S3Vector client: {e}")
            self.client = None
def _create_index(
self,
index_name: str,
dimension: int,
data_type: str = "float32",
distance_metric: str = "cosine",
) -> None:
"""
Create a new index in the S3 vector bucket for the given collection if it does not exist.
"""
if self.has_collection(index_name):
log.debug(f"Index '{index_name}' already exists, skipping creation")
return
try:
self.client.create_index(
vectorBucketName=self.bucket_name,
indexName=index_name,
dataType=data_type,
dimension=dimension,
distanceMetric=distance_metric,
)
log.info(
f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
)
except Exception as e:
log.error(f"Error creating S3 index '{index_name}': {e}")
raise
def _filter_metadata(
self, metadata: Dict[str, Any], item_id: str
) -> Dict[str, Any]:
"""
Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
"""
if not isinstance(metadata, dict) or len(metadata) <= 10:
return metadata
# Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
important_keys = [
"text", # The actual document content
"file_id", # File ID
"source", # Document source file
"title", # Document title
"page", # Page number
"total_pages", # Total pages in document
"embedding_config", # Embedding configuration
"created_by", # User who created it
"name", # Document name
"hash", # Content hash
]
filtered_metadata = {}
# First, add important keys if they exist
for key in important_keys:
if key in metadata:
filtered_metadata[key] = metadata[key]
if len(filtered_metadata) >= 10:
break
# If we still have room, add other keys
if len(filtered_metadata) < 10:
for key, value in metadata.items():
if key not in filtered_metadata:
filtered_metadata[key] = value
if len(filtered_metadata) >= 10:
break
log.warning(
f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
)
return filtered_metadata
def has_collection(self, collection_name: str) -> bool:
"""
Check if a vector index (collection) exists in the S3 vector bucket.
"""
try:
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
indexes = response.get("indexes", [])
return any(idx.get("indexName") == collection_name for idx in indexes)
except Exception as e:
log.error(f"Error listing indexes: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
"""
Delete an entire S3 Vector index/collection.
"""
if not self.has_collection(collection_name):
log.warning(
f"Collection '{collection_name}' does not exist, nothing to delete"
)
return
try:
log.info(f"Deleting collection '{collection_name}'")
self.client.delete_index(
vectorBucketName=self.bucket_name, indexName=collection_name
)
log.info(f"Successfully deleted collection '{collection_name}'")
except Exception as e:
log.error(f"Error deleting collection '{collection_name}': {e}")
raise
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Insert vector items into the S3 Vector index. Create index if it does not exist.
"""
if not items:
log.warning("No items to insert")
return
dimension = len(items[0]["vector"])
try:
if not self.has_collection(collection_name):
log.info(f"Index '{collection_name}' does not exist. Creating index.")
self._create_index(
index_name=collection_name,
dimension=dimension,
data_type="float32",
distance_metric="cosine",
)
# Prepare vectors for insertion
vectors = []
for item in items:
# Ensure vector data is in the correct format for S3 Vector API
vector_data = item["vector"]
if isinstance(vector_data, list):
# Convert list to float32 values as required by S3 Vector API
vector_data = [float(x) for x in vector_data]
# Prepare metadata, ensuring the text field is preserved
metadata = item.get("metadata", {}).copy()
# Add the text field to metadata so it's available for retrieval
metadata["text"] = item["text"]
# Filter metadata to comply with S3 Vector API limit of 10 keys
metadata = self._filter_metadata(metadata, item["id"])
vectors.append(
{
"key": item["id"],
"data": {"float32": vector_data},
"metadata": metadata,
}
)
# Insert vectors
self.client.put_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
vectors=vectors,
)
log.info(f"Inserted {len(vectors)} vectors into index '{collection_name}'.")
except Exception as e:
log.error(f"Error inserting vectors: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
Insert or update vector items in the S3 Vector index. Create index if it does not exist.
"""
if not items:
log.warning("No items to upsert")
return
dimension = len(items[0]["vector"])
log.info(f"Upsert dimension: {dimension}")
try:
if not self.has_collection(collection_name):
log.info(
f"Index '{collection_name}' does not exist. Creating index for upsert."
)
self._create_index(
index_name=collection_name,
dimension=dimension,
data_type="float32",
distance_metric="cosine",
)
# Prepare vectors for upsert
vectors = []
for item in items:
# Ensure vector data is in the correct format for S3 Vector API
vector_data = item["vector"]
if isinstance(vector_data, list):
# Convert list to float32 values as required by S3 Vector API
vector_data = [float(x) for x in vector_data]
# Prepare metadata, ensuring the text field is preserved
metadata = item.get("metadata", {}).copy()
# Add the text field to metadata so it's available for retrieval
metadata["text"] = item["text"]
# Filter metadata to comply with S3 Vector API limit of 10 keys
metadata = self._filter_metadata(metadata, item["id"])
vectors.append(
{
"key": item["id"],
"data": {"float32": vector_data},
"metadata": metadata,
}
)
# Upsert vectors (using put_vectors for upsert semantics)
log.info(
f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}"
)
self.client.put_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
vectors=vectors,
)
log.info(f"Upserted {len(vectors)} vectors into index '{collection_name}'.")
except Exception as e:
log.error(f"Error upserting vectors: {e}")
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""
Search for similar vectors in a collection using multiple query vectors.
"""
if not self.has_collection(collection_name):
log.warning(f"Collection '{collection_name}' does not exist")
return None
if not vectors:
log.warning("No query vectors provided")
return None
try:
log.info(
f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
)
# Initialize result lists
all_ids = []
all_documents = []
all_metadatas = []
all_distances = []
# Process each query vector
for i, query_vector in enumerate(vectors):
log.debug(f"Processing query vector {i+1}/{len(vectors)}")
# Prepare the query vector in S3 Vector format
query_vector_dict = {"float32": [float(x) for x in query_vector]}
# Call S3 Vector query API
response = self.client.query_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
topK=limit,
queryVector=query_vector_dict,
returnMetadata=True,
returnDistance=True,
)
# Process results for this query
query_ids = []
query_documents = []
query_metadatas = []
query_distances = []
result_vectors = response.get("vectors", [])
for vector in result_vectors:
vector_id = vector.get("key")
vector_metadata = vector.get("metadata", {})
vector_distance = vector.get("distance", 0.0)
# Extract document text from metadata
document_text = ""
if isinstance(vector_metadata, dict):
# Get the text field first (highest priority)
document_text = vector_metadata.get("text")
if not document_text:
# Fallback to other possible text fields
document_text = (
vector_metadata.get("content")
or vector_metadata.get("document")
or vector_id
)
else:
document_text = vector_id
query_ids.append(vector_id)
query_documents.append(document_text)
query_metadatas.append(vector_metadata)
query_distances.append(vector_distance)
# Add this query's results to the overall results
all_ids.append(query_ids)
all_documents.append(query_documents)
all_metadatas.append(query_metadatas)
all_distances.append(query_distances)
log.info(f"Search completed. Found results for {len(all_ids)} queries")
# Return SearchResult format
return SearchResult(
ids=all_ids if all_ids else None,
documents=all_documents if all_documents else None,
metadatas=all_metadatas if all_metadatas else None,
distances=all_distances if all_distances else None,
)
except Exception as e:
log.error(f"Error searching collection '{collection_name}': {str(e)}")
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
log.warning(f"Collection '{collection_name}' not found")
return None
elif error_code == "ValidationException":
log.error(f"Invalid query vector dimensions or parameters")
return None
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
return None
raise
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""
Query vectors from a collection using metadata filter.
"""
if not self.has_collection(collection_name):
log.warning(f"Collection '{collection_name}' does not exist")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
if not filter:
log.warning("No filter provided, returning all vectors")
return self.get(collection_name)
try:
log.info(f"Querying collection '{collection_name}' with filter: {filter}")
# For S3 Vector, we need to use list_vectors and then filter results
# Since S3 Vector may not support complex server-side filtering,
# we'll retrieve all vectors and filter client-side
# Get all vectors first
all_vectors_result = self.get(collection_name)
if not all_vectors_result or not all_vectors_result.ids:
log.warning("No vectors found in collection")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
# Extract the lists from the result
all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
all_documents = (
all_vectors_result.documents[0] if all_vectors_result.documents else []
)
all_metadatas = (
all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
)
# Apply client-side filtering
filtered_ids = []
filtered_documents = []
filtered_metadatas = []
for i, metadata in enumerate(all_metadatas):
if self._matches_filter(metadata, filter):
if i < len(all_ids):
filtered_ids.append(all_ids[i])
if i < len(all_documents):
filtered_documents.append(all_documents[i])
filtered_metadatas.append(metadata)
# Apply limit if specified
if limit and len(filtered_ids) >= limit:
break
log.info(
f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
)
# Return GetResult format
if filtered_ids:
return GetResult(
ids=[filtered_ids],
documents=[filtered_documents],
metadatas=[filtered_metadatas],
)
else:
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {str(e)}")
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
log.warning(f"Collection '{collection_name}' not found")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
raise
def get(self, collection_name: str) -> Optional[GetResult]:
"""
Retrieve all vectors from a collection.
"""
if not self.has_collection(collection_name):
log.warning(f"Collection '{collection_name}' does not exist")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
try:
log.info(f"Retrieving all vectors from collection '{collection_name}'")
# Initialize result lists
all_ids = []
all_documents = []
all_metadatas = []
# Handle pagination
next_token = None
while True:
# Prepare request parameters
request_params = {
"vectorBucketName": self.bucket_name,
"indexName": collection_name,
"returnData": False, # Don't include vector data (not needed for get)
"returnMetadata": True, # Include metadata
"maxResults": 500, # Use reasonable page size
}
if next_token:
request_params["nextToken"] = next_token
# Call S3 Vector API
response = self.client.list_vectors(**request_params)
# Process vectors in this page
vectors = response.get("vectors", [])
for vector in vectors:
vector_id = vector.get("key")
vector_data = vector.get("data", {})
vector_metadata = vector.get("metadata", {})
# Extract the actual vector array
vector_array = vector_data.get("float32", [])
# For documents, we try to extract text from metadata or use the vector ID
document_text = ""
if isinstance(vector_metadata, dict):
# Get the text field first (highest priority)
document_text = vector_metadata.get("text")
if not document_text:
# Fallback to other possible text fields
document_text = (
vector_metadata.get("content")
or vector_metadata.get("document")
or vector_id
)
# Log the actual content for debugging
log.debug(
f"Document text preview (first 200 chars): {str(document_text)[:200]}"
)
else:
document_text = vector_id
all_ids.append(vector_id)
all_documents.append(document_text)
all_metadatas.append(vector_metadata)
# Check if there are more pages
next_token = response.get("nextToken")
if not next_token:
break
log.info(
f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
)
# Return in GetResult format
# The Open WebUI GetResult expects lists of lists, so we wrap each list
if all_ids:
return GetResult(
ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
)
else:
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
except Exception as e:
log.error(
f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
)
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
log.warning(f"Collection '{collection_name}' not found")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
raise
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
"""
Delete vectors by ID or filter from a collection.
"""
if not self.has_collection(collection_name):
log.warning(
f"Collection '{collection_name}' does not exist, nothing to delete"
)
return
# Check if this is a knowledge collection (not file-specific)
is_knowledge_collection = not collection_name.startswith("file-")
try:
if ids:
# Delete by specific vector IDs/keys
log.info(
f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
)
self.client.delete_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
keys=ids,
)
log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")
elif filter:
# Handle filter-based deletion
log.info(
f"Deleting vectors by filter from collection '{collection_name}': {filter}"
)
# If this is a knowledge collection and we have a file_id filter,
# also clean up the corresponding file-specific collection
if is_knowledge_collection and "file_id" in filter:
file_id = filter["file_id"]
file_collection_name = f"file-{file_id}"
if self.has_collection(file_collection_name):
log.info(
f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
)
self.delete_collection(file_collection_name)
# For the main collection, implement query-then-delete
# First, query to get IDs matching the filter
query_result = self.query(collection_name, filter)
if query_result and query_result.ids and query_result.ids[0]:
matching_ids = query_result.ids[0]
log.info(
f"Found {len(matching_ids)} vectors matching filter, deleting them"
)
# Delete the matching vectors by ID
self.client.delete_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
keys=matching_ids,
)
log.info(
f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
)
else:
log.warning("No vectors found matching the filter criteria")
else:
log.warning("No IDs or filter provided for deletion")
except Exception as e:
log.error(
f"Error deleting vectors from collection '{collection_name}': {e}"
)
raise
def reset(self) -> None:
"""
Reset/clear all vector data. For S3 Vector, this deletes all indexes.
"""
try:
log.warning(
"Reset called - this will delete all vector indexes in the S3 bucket"
)
# List all indexes
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
indexes = response.get("indexes", [])
if not indexes:
log.warning("No indexes found to delete")
return
# Delete all indexes
deleted_count = 0
for index in indexes:
index_name = index.get("indexName")
if index_name:
try:
self.client.delete_index(
vectorBucketName=self.bucket_name, indexName=index_name
)
deleted_count += 1
log.info(f"Deleted index: {index_name}")
except Exception as e:
log.error(f"Error deleting index '{index_name}': {e}")
log.info(f"Reset completed: deleted {deleted_count} indexes")
except Exception as e:
log.error(f"Error during reset: {e}")
raise
def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
"""
Check if metadata matches the given filter conditions.
"""
if not isinstance(metadata, dict) or not isinstance(filter, dict):
return False
# Check each filter condition
for key, expected_value in filter.items():
# Handle special operators
if key.startswith("$"):
if key == "$and":
# All conditions must match
if not isinstance(expected_value, list):
continue
for condition in expected_value:
if not self._matches_filter(metadata, condition):
return False
elif key == "$or":
# At least one condition must match
if not isinstance(expected_value, list):
continue
any_match = False
for condition in expected_value:
if self._matches_filter(metadata, condition):
any_match = True
break
if not any_match:
return False
continue
# Get the actual value from metadata
actual_value = metadata.get(key)
# Handle different types of expected values
if isinstance(expected_value, dict):
# Handle comparison operators
for op, op_value in expected_value.items():
if op == "$eq":
if actual_value != op_value:
return False
elif op == "$ne":
if actual_value == op_value:
return False
elif op == "$in":
if (
not isinstance(op_value, list)
or actual_value not in op_value
):
return False
elif op == "$nin":
if isinstance(op_value, list) and actual_value in op_value:
return False
elif op == "$exists":
if bool(op_value) != (key in metadata):
return False
# Add more operators as needed
else:
# Simple equality check
if actual_value != expected_value:
return False
return True

View File

@ -30,6 +30,10 @@ class Vector:
from open_webui.retrieval.vector.dbs.pinecone import PineconeClient from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
return PineconeClient() return PineconeClient()
case VectorType.S3VECTOR:
from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient
return S3VectorClient()
case VectorType.OPENSEARCH: case VectorType.OPENSEARCH:
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient

View File

@ -0,0 +1,14 @@
from datetime import datetime
from typing import Any


def stringify_metadata(
    metadata: dict[str, Any],
) -> dict[str, Any]:
    """
    Coerce non-scalar metadata values to strings, in place.

    datetime, list, and dict values are replaced with their str()
    representation so the metadata can be stored by vector backends that
    only accept primitive values; other values are left untouched.
    Returns the same (mutated) dict for convenience.
    """
    for key, value in metadata.items():
        # isinstance with a tuple covers all three coercible types at once.
        if isinstance(value, (datetime, list, dict)):
            metadata[key] = str(value)
    return metadata

View File

@ -36,7 +36,9 @@ def search_brave(
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet") link=result["url"],
title=result.get("title"),
snippet=result.get("description"),
) )
for result in results[:count] for result in results[:count]
] ]

View File

@ -2,8 +2,8 @@ import logging
from typing import Optional from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS from ddgs import DDGS
from duckduckgo_search.exceptions import RatelimitException from ddgs.exceptions import RatelimitException
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View File

@ -15,6 +15,7 @@ import aiohttp
import aiofiles import aiofiles
import requests import requests
import mimetypes import mimetypes
from urllib.parse import quote
from fastapi import ( from fastapi import (
Depends, Depends,
@ -327,6 +328,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
log.exception(e) log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload") raise HTTPException(status_code=400, detail="Invalid JSON payload")
r = None
if request.app.state.config.TTS_ENGINE == "openai": if request.app.state.config.TTS_ENGINE == "openai":
payload["model"] = request.app.state.config.TTS_MODEL payload["model"] = request.app.state.config.TTS_MODEL
@ -335,7 +337,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
timeout=timeout, trust_env=True timeout=timeout, trust_env=True
) as session: ) as session:
async with session.post( r = await session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload, json=payload,
headers={ headers={
@ -343,7 +345,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}", "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -353,7 +355,8 @@ async def speech(request: Request, user=Depends(get_verified_user)):
), ),
}, },
ssl=AIOHTTP_CLIENT_SESSION_SSL, ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r: )
r.raise_for_status() r.raise_for_status()
async with aiofiles.open(file_path, "wb") as f: async with aiofiles.open(file_path, "wb") as f:
@ -368,18 +371,22 @@ async def speech(request: Request, user=Depends(get_verified_user)):
log.exception(e) log.exception(e)
detail = None detail = None
try: status_code = 500
if r.status != 200: detail = f"Open WebUI: Server Connection Error"
res = await r.json()
if r is not None:
status_code = r.status
try:
res = await r.json()
if "error" in res: if "error" in res:
detail = f"External: {res['error'].get('message', '')}" detail = f"External: {res['error']}"
except Exception: except Exception:
detail = f"External: {e}" detail = f"External: {e}"
raise HTTPException( raise HTTPException(
status_code=getattr(r, "status", 500) if r else 500, status_code=status_code,
detail=detail if detail else "Open WebUI: Server Connection Error", detail=detail,
) )
elif request.app.state.config.TTS_ENGINE == "elevenlabs": elif request.app.state.config.TTS_ENGINE == "elevenlabs":
@ -554,7 +561,11 @@ def transcription_handler(request, file_path, metadata):
file_path, file_path,
beam_size=5, beam_size=5,
vad_filter=request.app.state.config.WHISPER_VAD_FILTER, vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
language=metadata.get("language") or WHISPER_LANGUAGE, language=(
metadata.get("language", None)
if WHISPER_LANGUAGE == ""
else WHISPER_LANGUAGE
),
) )
log.info( log.info(
"Detected language '%s' with probability %f" "Detected language '%s' with probability %f"
@ -919,14 +930,18 @@ def transcription(
): ):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
supported_content_types = request.app.state.config.STT_SUPPORTED_CONTENT_TYPES or [ stt_supported_content_types = getattr(
"audio/*", request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
"video/webm", )
]
if not any( if not any(
fnmatch(file.content_type, content_type) fnmatch(file.content_type, content_type)
for content_type in supported_content_types for content_type in (
stt_supported_content_types
if stt_supported_content_types
and any(t.strip() for t in stt_supported_content_types)
else ["audio/*", "video/webm"]
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -669,12 +669,13 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.get("/signout") @router.get("/signout")
async def signout(request: Request, response: Response): async def signout(request: Request, response: Response):
response.delete_cookie("token") response.delete_cookie("token")
response.delete_cookie("oui-session")
if ENABLE_OAUTH_SIGNUP.value: if ENABLE_OAUTH_SIGNUP.value:
oauth_id_token = request.cookies.get("oauth_id_token") oauth_id_token = request.cookies.get("oauth_id_token")
if oauth_id_token: if oauth_id_token:
try: try:
async with ClientSession() as session: async with ClientSession(trust_env=True) as session:
async with session.get(OPENID_PROVIDER_URL.value) as resp: async with session.get(OPENID_PROVIDER_URL.value) as resp:
if resp.status == 200: if resp.status == 200:
openid_data = await resp.json() openid_data = await resp.json()
@ -686,7 +687,12 @@ async def signout(request: Request, response: Response):
status_code=200, status_code=200,
content={ content={
"status": True, "status": True,
"redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}", "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}"
+ (
f"&post_logout_redirect_uri={WEBUI_AUTH_SIGNOUT_REDIRECT_URL}"
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL
else ""
),
}, },
headers=response.headers, headers=response.headers,
) )

View File

@ -40,9 +40,13 @@ router = APIRouter()
@router.get("/", response_model=list[ChannelModel]) @router.get("/", response_model=list[ChannelModel])
async def get_channels(user=Depends(get_verified_user)): async def get_channels(user=Depends(get_verified_user)):
return Channels.get_channels_by_user_id(user.id)
@router.get("/list", response_model=list[ChannelModel])
async def get_all_channels(user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
return Channels.get_channels() return Channels.get_channels()
else:
return Channels.get_channels_by_user_id(user.id) return Channels.get_channels_by_user_id(user.id)
@ -430,13 +434,6 @@ async def update_message_by_id(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id) message = Messages.get_message_by_id(message_id)
if not message: if not message:
raise HTTPException( raise HTTPException(
@ -448,6 +445,15 @@ async def update_message_by_id(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
) )
if (
user.role != "admin"
and message.user_id != user.id
and not has_access(user.id, type="read", access_control=channel.access_control)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try: try:
message = Messages.update_message_by_id(message_id, form_data) message = Messages.update_message_by_id(message_id, form_data)
message = Messages.get_message_by_id(message_id) message = Messages.get_message_by_id(message_id)
@ -637,13 +643,6 @@ async def delete_message_by_id(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
) )
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id) message = Messages.get_message_by_id(message_id)
if not message: if not message:
raise HTTPException( raise HTTPException(
@ -655,6 +654,15 @@ async def delete_message_by_id(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
) )
if (
user.role != "admin"
and message.user_id != user.id
and not has_access(user.id, type="read", access_control=channel.access_control)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try: try:
Messages.delete_message_by_id(message_id) Messages.delete_message_by_id(message_id)
await sio.emit( await sio.emit(

View File

@ -39,13 +39,21 @@ router = APIRouter()
async def get_session_user_chat_list( async def get_session_user_chat_list(
user=Depends(get_verified_user), page: Optional[int] = None user=Depends(get_verified_user), page: Optional[int] = None
): ):
try:
if page is not None: if page is not None:
limit = 60 limit = 60
skip = (page - 1) * limit skip = (page - 1) * limit
return Chats.get_chat_title_id_list_by_user_id(user.id, skip=skip, limit=limit) return Chats.get_chat_title_id_list_by_user_id(
user.id, skip=skip, limit=limit
)
else: else:
return Chats.get_chat_title_id_list_by_user_id(user.id) return Chats.get_chat_title_id_list_by_user_id(user.id)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################ ############################
@ -684,8 +692,10 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse]) @router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if not has_permission( if (user.role != "admin") and (
not has_permission(
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
)
): ):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,

View File

@ -7,7 +7,11 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import get_config, save_config from open_webui.config import get_config, save_config
from open_webui.config import BannerModel from open_webui.config import BannerModel
from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data from open_webui.utils.tools import (
get_tool_server_data,
get_tool_servers_data,
get_tool_server_url,
)
router = APIRouter() router = APIRouter()
@ -39,32 +43,39 @@ async def export_config(user=Depends(get_admin_user)):
############################ ############################
# Direct Connections Config # Connections Config
############################ ############################
class DirectConnectionsConfigForm(BaseModel): class ConnectionsConfigForm(BaseModel):
ENABLE_DIRECT_CONNECTIONS: bool ENABLE_DIRECT_CONNECTIONS: bool
ENABLE_BASE_MODELS_CACHE: bool
@router.get("/direct_connections", response_model=DirectConnectionsConfigForm) @router.get("/connections", response_model=ConnectionsConfigForm)
async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)): async def get_connections_config(request: Request, user=Depends(get_admin_user)):
return { return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS, "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
"ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
} }
@router.post("/direct_connections", response_model=DirectConnectionsConfigForm) @router.post("/connections", response_model=ConnectionsConfigForm)
async def set_direct_connections_config( async def set_connections_config(
request: Request, request: Request,
form_data: DirectConnectionsConfigForm, form_data: ConnectionsConfigForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = ( request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
form_data.ENABLE_DIRECT_CONNECTIONS form_data.ENABLE_DIRECT_CONNECTIONS
) )
request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
form_data.ENABLE_BASE_MODELS_CACHE
)
return { return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS, "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
"ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
} }
@ -128,7 +139,7 @@ async def verify_tool_servers_config(
elif form_data.auth_type == "session": elif form_data.auth_type == "session":
token = request.state.token.credentials token = request.state.token.credentials
url = f"{form_data.url}/{form_data.path}" url = get_tool_server_url(form_data.url, form_data.path)
return await get_tool_server_data(token, url) return await get_tool_server_data(token, url)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(

View File

@ -129,6 +129,9 @@ async def create_feedback(
@router.get("/feedback/{id}", response_model=FeedbackModel) @router.get("/feedback/{id}", response_model=FeedbackModel)
async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
if user.role == "admin":
feedback = Feedbacks.get_feedback_by_id(id=id)
else:
feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id)
if not feedback: if not feedback:
@ -143,6 +146,9 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)):
async def update_feedback_by_id( async def update_feedback_by_id(
id: str, form_data: FeedbackForm, user=Depends(get_verified_user) id: str, form_data: FeedbackForm, user=Depends(get_verified_user)
): ):
if user.role == "admin":
feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data)
else:
feedback = Feedbacks.update_feedback_by_id_and_user_id( feedback = Feedbacks.update_feedback_by_id_and_user_id(
id=id, user_id=user.id, form_data=form_data id=id, user_id=user.id, form_data=form_data
) )

View File

@ -21,6 +21,7 @@ from fastapi import (
from fastapi.responses import FileResponse, StreamingResponse from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import Users from open_webui.models.users import Users
from open_webui.models.files import ( from open_webui.models.files import (
@ -155,17 +156,18 @@ def upload_file(
if process: if process:
try: try:
if file.content_type: if file.content_type:
stt_supported_content_types = ( stt_supported_content_types = getattr(
request.app.state.config.STT_SUPPORTED_CONTENT_TYPES request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
or [
"audio/*",
"video/webm",
]
) )
if any( if any(
fnmatch(file.content_type, content_type) fnmatch(file.content_type, content_type)
for content_type in stt_supported_content_types for content_type in (
stt_supported_content_types
if stt_supported_content_types
and any(t.strip() for t in stt_supported_content_types)
else ["audio/*", "video/webm"]
)
): ):
file_path = Storage.get_file(file_path) file_path = Storage.get_file(file_path)
result = transcribe(request, file_path, file_metadata) result = transcribe(request, file_path, file_metadata)
@ -285,6 +287,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
if result: if result:
try: try:
Storage.delete_all_files() Storage.delete_all_files()
VECTOR_DB_CLIENT.reset()
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
log.error("Error deleting files") log.error("Error deleting files")
@ -602,12 +605,12 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
or user.role == "admin" or user.role == "admin"
or has_access_to_file(id, "write", user) or has_access_to_file(id, "write", user)
): ):
# We should add Chroma cleanup here
result = Files.delete_file_by_id(id) result = Files.delete_file_by_id(id)
if result: if result:
try: try:
Storage.delete_file(file.path) Storage.delete_file(file.path)
VECTOR_DB_CLIENT.delete(collection_name=f"file-{id}")
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
log.error("Error deleting files") log.error("Error deleting files")

View File

@ -49,7 +49,7 @@ async def get_folders(user=Depends(get_verified_user)):
**folder.model_dump(), **folder.model_dump(),
"items": { "items": {
"chats": [ "chats": [
{"title": chat.title, "id": chat.id} {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at}
for chat in Chats.get_chats_by_folder_id_and_user_id( for chat in Chats.get_chats_by_folder_id_and_user_id(
folder.id, user.id folder.id, user.id
) )
@ -78,7 +78,7 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
) )
try: try:
folder = Folders.insert_new_folder(user.id, form_data.name) folder = Folders.insert_new_folder(user.id, form_data)
return folder return folder
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -120,16 +120,14 @@ async def update_folder_name_by_id(
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
folder.parent_id, user.id, form_data.name folder.parent_id, user.id, form_data.name
) )
if existing_folder: if existing_folder and existing_folder.id != id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"), detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
) )
try: try:
folder = Folders.update_folder_name_by_id_and_user_id( folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data)
id, user.id, form_data.name
)
return folder return folder
except Exception as e: except Exception as e:

View File

@ -105,7 +105,7 @@ async def load_function_from_url(
) )
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get( async with session.get(
url, headers={"Content-Type": "application/json"} url, headers={"Content-Type": "application/json"}
) as resp: ) as resp:
@ -131,15 +131,29 @@ async def load_function_from_url(
############################ ############################
class SyncFunctionsForm(FunctionForm): class SyncFunctionsForm(BaseModel):
functions: list[FunctionModel] = [] functions: list[FunctionModel] = []
@router.post("/sync", response_model=Optional[FunctionModel]) @router.post("/sync", response_model=list[FunctionModel])
async def sync_functions( async def sync_functions(
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user) request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
): ):
try:
for function in form_data.functions:
function.content = replace_imports(function.content)
function_module, function_type, frontmatter = load_function_module_by_id(
function.id,
content=function.content,
)
return Functions.sync_functions(user.id, form_data.functions) return Functions.sync_functions(user.id, form_data.functions)
except Exception as e:
log.exception(f"Failed to load a function: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################ ############################

View File

@ -9,6 +9,7 @@ from open_webui.models.groups import (
GroupForm, GroupForm,
GroupUpdateForm, GroupUpdateForm,
GroupResponse, GroupResponse,
UserIdsForm,
) )
from open_webui.config import CACHE_DIR from open_webui.config import CACHE_DIR
@ -107,6 +108,56 @@ async def update_group_by_id(
) )
############################
# AddUserToGroupByUserIdAndGroupId
############################
@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse])
async def add_user_to_group(
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
):
try:
if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
group = Groups.add_users_to_group(id, form_data.user_ids)
if group:
return group
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error adding users to group"),
)
except Exception as e:
log.exception(f"Error adding users to group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse])
async def remove_users_from_group(
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
):
try:
group = Groups.remove_users_from_group(id, form_data.user_ids)
if group:
return group
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error removing users from group"),
)
except Exception as e:
log.exception(f"Error removing users from group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################ ############################
# DeleteGroupById # DeleteGroupById
############################ ############################

View File

@ -8,6 +8,7 @@ import re
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from urllib.parse import quote
import requests import requests
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR from open_webui.config import CACHE_DIR
@ -302,8 +303,16 @@ async def update_image_config(
): ):
set_image_model(request, form_data.MODEL) set_image_model(request, form_data.MODEL)
if form_data.IMAGE_SIZE == "auto" and form_data.MODEL != "gpt-image-1":
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.INCORRECT_FORMAT(
" (auto is only allowed with gpt-image-1)."
),
)
pattern = r"^\d+x\d+$" pattern = r"^\d+x\d+$"
if re.match(pattern, form_data.IMAGE_SIZE): if form_data.IMAGE_SIZE == "auto" or re.match(pattern, form_data.IMAGE_SIZE):
request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
else: else:
raise HTTPException( raise HTTPException(
@ -471,7 +480,14 @@ async def image_generations(
form_data: GenerateImageForm, form_data: GenerateImageForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x"))) # if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default
# This is only relevant when the user has set IMAGE_SIZE to 'auto' with an
# image model other than gpt-image-1, which is warned about on settings save
width, height = (
tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
if "x" in request.app.state.config.IMAGE_SIZE
else (512, 512)
)
r = None r = None
try: try:
@ -483,7 +499,7 @@ async def image_generations(
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS: if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name headers["X-OpenWebUI-User-Name"] = quote(user.name, safe=" ")
headers["X-OpenWebUI-User-Id"] = user.id headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role headers["X-OpenWebUI-User-Role"] = user.role

View File

@ -82,6 +82,10 @@ class QueryMemoryForm(BaseModel):
async def query_memory( async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
): ):
memories = Memories.get_memories_by_user_id(user.id)
if not memories:
raise HTTPException(status_code=404, detail="No memories found for user")
results = VECTOR_DB_CLIENT.search( results = VECTOR_DB_CLIENT.search(
collection_name=f"user-memory-{user.id}", collection_name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)], vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],

View File

@ -7,6 +7,8 @@ from open_webui.models.models import (
ModelUserResponse, ModelUserResponse,
Models, Models,
) )
from pydantic import BaseModel
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
@ -78,6 +80,32 @@ async def create_new_model(
) )
############################
# ExportModels
############################
@router.get("/export", response_model=list[ModelModel])
async def export_models(user=Depends(get_admin_user)):
return Models.get_models()
############################
# SyncModels
############################
class SyncModelsForm(BaseModel):
models: list[ModelModel] = []
@router.post("/sync", response_model=list[ModelModel])
async def sync_models(
request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
):
return Models.sync_models(user.id, form_data.models)
########################### ###########################
# GetModelById # GetModelById
########################### ###########################
@ -102,7 +130,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
############################ ############################
# ToggelModelById # ToggleModelById
############################ ############################

View File

@ -6,6 +6,9 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
from pydantic import BaseModel from pydantic import BaseModel
from open_webui.socket.main import sio
from open_webui.models.users import Users, UserResponse from open_webui.models.users import Users, UserResponse
from open_webui.models.notes import Notes, NoteModel, NoteForm, NoteUserResponse from open_webui.models.notes import Notes, NoteModel, NoteForm, NoteUserResponse
@ -51,7 +54,14 @@ async def get_notes(request: Request, user=Depends(get_verified_user)):
return notes return notes
@router.get("/list", response_model=list[NoteUserResponse]) class NoteTitleIdResponse(BaseModel):
id: str
title: str
updated_at: int
created_at: int
@router.get("/list", response_model=list[NoteTitleIdResponse])
async def get_note_list(request: Request, user=Depends(get_verified_user)): async def get_note_list(request: Request, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission( if user.role != "admin" and not has_permission(
@ -63,13 +73,8 @@ async def get_note_list(request: Request, user=Depends(get_verified_user)):
) )
notes = [ notes = [
NoteUserResponse( NoteTitleIdResponse(**note.model_dump())
**{ for note in Notes.get_notes_by_user_id(user.id, "write")
**note.model_dump(),
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
}
)
for note in Notes.get_notes_by_user_id(user.id, "read")
] ]
return notes return notes
@ -168,6 +173,12 @@ async def update_note_by_id(
try: try:
note = Notes.update_note_by_id(id, form_data) note = Notes.update_note_by_id(id, form_data)
await sio.emit(
"note-events",
note.model_dump(),
to=f"note:{note.id}",
)
return note return note
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)

View File

@ -16,6 +16,7 @@ from urllib.parse import urlparse
import aiohttp import aiohttp
from aiocache import cached from aiocache import cached
import requests import requests
from urllib.parse import quote
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
@ -58,6 +59,7 @@ from open_webui.config import (
from open_webui.env import ( from open_webui.env import (
ENV, ENV,
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@ -87,7 +89,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -122,6 +124,7 @@ async def send_post_request(
key: Optional[str] = None, key: Optional[str] = None,
content_type: Optional[str] = None, content_type: Optional[str] = None,
user: UserModel = None, user: UserModel = None,
metadata: Optional[dict] = None,
): ):
r = None r = None
@ -138,10 +141,15 @@ async def send_post_request(
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
**(
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
if metadata and metadata.get("chat_id")
else {}
),
} }
if ENABLE_FORWARD_USER_INFO_HEADERS and user if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {} else {}
@ -182,7 +190,6 @@ async def send_post_request(
) )
else: else:
res = await r.json() res = await r.json()
await cleanup_response(r, session)
return res return res
except HTTPException as e: except HTTPException as e:
@ -194,6 +201,9 @@ async def send_post_request(
status_code=r.status if r else 500, status_code=r.status if r else 500,
detail=detail if e else "Open WebUI: Server Connection Error", detail=detail if e else "Open WebUI: Server Connection Error",
) )
finally:
if not stream:
await cleanup_response(r, session)
def get_api_key(idx, url, configs): def get_api_key(idx, url, configs):
@ -242,7 +252,7 @@ async def verify_connection(
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -329,7 +339,7 @@ def merge_ollama_models_lists(model_lists):
return list(merged_models.values()) return list(merged_models.values())
@cached(ttl=1) @cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel = None): async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()") log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API: if request.app.state.config.ENABLE_OLLAMA_API:
@ -462,7 +472,7 @@ async def get_ollama_tags(
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -634,7 +644,10 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
class ModelNameForm(BaseModel): class ModelNameForm(BaseModel):
name: str model: Optional[str] = None
model_config = ConfigDict(
extra="allow",
)
@router.post("/api/unload") @router.post("/api/unload")
@ -643,10 +656,12 @@ async def unload_model(
form_data: ModelNameForm, form_data: ModelNameForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
model_name = form_data.name form_data = form_data.model_dump(exclude_none=True)
model_name = form_data.get("model", form_data.get("name"))
if not model_name: if not model_name:
raise HTTPException( raise HTTPException(
status_code=400, detail="Missing 'name' of model to unload." status_code=400, detail="Missing name of the model to unload."
) )
# Refresh/load models if needed, get mapping from name to URLs # Refresh/load models if needed, get mapping from name to URLs
@ -709,11 +724,14 @@ async def pull_model(
url_idx: int = 0, url_idx: int = 0,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
form_data = form_data.model_dump(exclude_none=True)
form_data["model"] = form_data.get("model", form_data.get("name"))
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
# Admin should be able to pull models from any source # Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True} payload = {**form_data, "insecure": True}
return await send_post_request( return await send_post_request(
url=f"{url}/api/pull", url=f"{url}/api/pull",
@ -724,7 +742,7 @@ async def pull_model(
class PushModelForm(BaseModel): class PushModelForm(BaseModel):
name: str model: str
insecure: Optional[bool] = None insecure: Optional[bool] = None
stream: Optional[bool] = None stream: Optional[bool] = None
@ -741,12 +759,12 @@ async def push_model(
await get_all_models(request, user=user) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
if form_data.name in models: if form_data.model in models:
url_idx = models[form_data.name]["urls"][0] url_idx = models[form_data.model]["urls"][0]
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
@ -824,7 +842,7 @@ async def copy_model(
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -865,16 +883,21 @@ async def delete_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
form_data = form_data.model_dump(exclude_none=True)
form_data["model"] = form_data.get("model", form_data.get("name"))
model = form_data.get("model")
if url_idx is None: if url_idx is None:
await get_all_models(request, user=user) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
if form_data.name in models: if model in models:
url_idx = models[form_data.name]["urls"][0] url_idx = models[model]["urls"][0]
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
) )
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
@ -884,13 +907,13 @@ async def delete_model(
r = requests.request( r = requests.request(
method="DELETE", method="DELETE",
url=f"{url}/api/delete", url=f"{url}/api/delete",
data=form_data.model_dump_json(exclude_none=True).encode(), data=json.dumps(form_data).encode(),
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -926,16 +949,21 @@ async def delete_model(
async def show_model_info( async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
): ):
form_data = form_data.model_dump(exclude_none=True)
form_data["model"] = form_data.get("model", form_data.get("name"))
await get_all_models(request, user=user) await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS models = request.app.state.OLLAMA_MODELS
if form_data.name not in models: model = form_data.get("model")
if model not in models:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
) )
url_idx = random.choice(models[form_data.name]["urls"]) url_idx = random.choice(models[model]["urls"])
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@ -949,7 +977,7 @@ async def show_model_info(
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -958,7 +986,7 @@ async def show_model_info(
else {} else {}
), ),
}, },
data=form_data.model_dump_json(exclude_none=True).encode(), data=json.dumps(form_data).encode(),
) )
r.raise_for_status() r.raise_for_status()
@ -1036,7 +1064,7 @@ async def embed(
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -1123,7 +1151,7 @@ async def embeddings(
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -1343,6 +1371,7 @@ async def generate_chat_completion(
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
content_type="application/x-ndjson", content_type="application/x-ndjson",
user=user, user=user,
metadata=metadata,
) )
@ -1381,6 +1410,8 @@ async def generate_openai_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
metadata = form_data.pop("metadata", None)
try: try:
form_data = OpenAICompletionForm(**form_data) form_data = OpenAICompletionForm(**form_data)
except Exception as e: except Exception as e:
@ -1446,6 +1477,7 @@ async def generate_openai_completion(
stream=payload.get("stream", False), stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user, user=user,
metadata=metadata,
) )
@ -1527,6 +1559,7 @@ async def generate_openai_chat_completion(
stream=payload.get("stream", False), stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user, user=user,
metadata=metadata,
) )

View File

@ -8,7 +8,7 @@ from typing import Literal, Optional, overload
import aiohttp import aiohttp
from aiocache import cached from aiocache import cached
import requests import requests
from urllib.parse import quote
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -21,6 +21,7 @@ from open_webui.config import (
CACHE_DIR, CACHE_DIR,
) )
from open_webui.env import ( from open_webui.env import (
MODELS_CACHE_TTL,
AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
@ -66,7 +67,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
**({"Authorization": f"Bearer {key}"} if key else {}), **({"Authorization": f"Bearer {key}"} if key else {}),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -225,7 +226,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
), ),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -361,7 +362,9 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
response if isinstance(response, list) else response.get("data", []) response if isinstance(response, list) else response.get("data", [])
): ):
if prefix_id: if prefix_id:
model["id"] = f"{prefix_id}.{model['id']}" model["id"] = (
f"{prefix_id}.{model.get('id', model.get('name', ''))}"
)
if tags: if tags:
model["tags"] = tags model["tags"] = tags
@ -386,7 +389,7 @@ async def get_filtered_models(models, user):
return filtered_models return filtered_models
@cached(ttl=1) @cached(ttl=MODELS_CACHE_TTL)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()") log.info("get_all_models()")
@ -478,7 +481,7 @@ async def get_models(
"Content-Type": "application/json", "Content-Type": "application/json",
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -573,7 +576,7 @@ async def verify_connection(
"Content-Type": "application/json", "Content-Type": "application/json",
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -633,13 +636,7 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail) raise HTTPException(status_code=500, detail=error_detail)
def convert_to_azure_payload( def get_azure_allowed_params(api_version: str) -> set[str]:
url,
payload: dict,
):
model = payload.get("model", "")
# Filter allowed parameters based on Azure OpenAI API
allowed_params = { allowed_params = {
"messages", "messages",
"temperature", "temperature",
@ -669,6 +666,23 @@ def convert_to_azure_payload(
"max_completion_tokens", "max_completion_tokens",
} }
try:
if api_version >= "2024-09-01-preview":
allowed_params.add("stream_options")
except ValueError:
log.debug(
f"Invalid API version {api_version} for Azure OpenAI. Defaulting to allowed parameters."
)
return allowed_params
def convert_to_azure_payload(url, payload: dict, api_version: str):
model = payload.get("model", "")
# Filter allowed parameters based on Azure OpenAI API
allowed_params = get_azure_allowed_params(api_version)
# Special handling for o-series models # Special handling for o-series models
if model.startswith("o") and model.endswith("-mini"): if model.startswith("o") and model.endswith("-mini"):
# Convert max_tokens to max_completion_tokens for o-series models # Convert max_tokens to max_completion_tokens for o-series models
@ -806,10 +820,15 @@ async def generate_chat_completion(
), ),
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
**(
{"X-OpenWebUI-Chat-Id": metadata.get("chat_id")}
if metadata and metadata.get("chat_id")
else {}
),
} }
if ENABLE_FORWARD_USER_INFO_HEADERS if ENABLE_FORWARD_USER_INFO_HEADERS
else {} else {}
@ -817,8 +836,8 @@ async def generate_chat_completion(
} }
if api_config.get("azure", False): if api_config.get("azure", False):
request_url, payload = convert_to_azure_payload(url, payload) api_version = api_config.get("api_version", "2023-03-15-preview")
api_version = api_config.get("api_version", "") or "2023-03-15-preview" request_url, payload = convert_to_azure_payload(url, payload, api_version)
headers["api-key"] = key headers["api-key"] = key
headers["api-version"] = api_version headers["api-version"] = api_version
request_url = f"{request_url}/chat/completions?api-version={api_version}" request_url = f"{request_url}/chat/completions?api-version={api_version}"
@ -881,10 +900,8 @@ async def generate_chat_completion(
detail=detail if detail else "Open WebUI: Server Connection Error", detail=detail if detail else "Open WebUI: Server Connection Error",
) )
finally: finally:
if not streaming and session: if not streaming:
if r: await cleanup_response(r, session)
r.close()
await session.close()
async def embeddings(request: Request, form_data: dict, user): async def embeddings(request: Request, form_data: dict, user):
@ -924,7 +941,7 @@ async def embeddings(request: Request, form_data: dict, user):
"Content-Type": "application/json", "Content-Type": "application/json",
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -963,10 +980,8 @@ async def embeddings(request: Request, form_data: dict, user):
detail=detail if detail else "Open WebUI: Server Connection Error", detail=detail if detail else "Open WebUI: Server Connection Error",
) )
finally: finally:
if not streaming and session: if not streaming:
if r: await cleanup_response(r, session)
r.close()
await session.close()
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
@ -996,7 +1011,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
"Content-Type": "application/json", "Content-Type": "application/json",
**( **(
{ {
"X-OpenWebUI-User-Name": user.name, "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role, "X-OpenWebUI-User-Role": user.role,
@ -1007,16 +1022,15 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
} }
if api_config.get("azure", False): if api_config.get("azure", False):
api_version = api_config.get("api_version", "2023-03-15-preview")
headers["api-key"] = key headers["api-key"] = key
headers["api-version"] = ( headers["api-version"] = api_version
api_config.get("api_version", "") or "2023-03-15-preview"
)
payload = json.loads(body) payload = json.loads(body)
url, payload = convert_to_azure_payload(url, payload) url, payload = convert_to_azure_payload(url, payload, api_version)
body = json.dumps(payload).encode() body = json.dumps(payload).encode()
request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}" request_url = f"{url}/{path}?api-version={api_version}"
else: else:
headers["Authorization"] = f"Bearer {key}" headers["Authorization"] = f"Bearer {key}"
request_url = f"{url}/{path}" request_url = f"{url}/{path}"
@ -1063,7 +1077,5 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
detail=detail if detail else "Open WebUI: Server Connection Error", detail=detail if detail else "Open WebUI: Server Connection Error",
) )
finally: finally:
if not streaming and session: if not streaming:
if r: await cleanup_response(r, session)
r.close()
await session.close()

View File

@ -29,6 +29,7 @@ import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.models.files import FileModel, Files from open_webui.models.files import FileModel, Files
@ -69,6 +70,7 @@ from open_webui.retrieval.web.external import search_external
from open_webui.retrieval.utils import ( from open_webui.retrieval.utils import (
get_embedding_function, get_embedding_function,
get_reranking_function,
get_model_path, get_model_path,
query_collection, query_collection,
query_collection_with_hybrid_search, query_collection_with_hybrid_search,
@ -813,7 +815,11 @@ async def update_rag_config(
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}" f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
) )
try: try:
request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL request.app.state.config.RAG_RERANKING_MODEL = (
form_data.RAG_RERANKING_MODEL
if form_data.RAG_RERANKING_MODEL is not None
else request.app.state.config.RAG_RERANKING_MODEL
)
try: try:
request.app.state.rf = get_rf( request.app.state.rf = get_rf(
@ -823,6 +829,12 @@ async def update_rag_config(
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
True, True,
) )
request.app.state.RERANKING_FUNCTION = get_reranking_function(
request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.rf,
)
except Exception as e: except Exception as e:
log.error(f"Error loading reranking model: {e}") log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
@ -1146,6 +1158,7 @@ def save_docs_to_vector_db(
chunk_overlap=request.app.state.config.CHUNK_OVERLAP, chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
docs = text_splitter.split_documents(docs)
elif request.app.state.config.TEXT_SPLITTER == "token": elif request.app.state.config.TEXT_SPLITTER == "token":
log.info( log.info(
f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}" f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}"
@ -1158,11 +1171,56 @@ def save_docs_to_vector_db(
chunk_overlap=request.app.state.config.CHUNK_OVERLAP, chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
docs = text_splitter.split_documents(docs)
elif request.app.state.config.TEXT_SPLITTER == "markdown_header":
log.info("Using markdown header text splitter")
# Define headers to split on - covering most common markdown header levels
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
("#####", "Header 5"),
("######", "Header 6"),
]
markdown_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on,
strip_headers=False, # Keep headers in content for context
)
md_split_docs = []
for doc in docs:
md_header_splits = markdown_splitter.split_text(doc.page_content)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=request.app.state.config.CHUNK_SIZE,
chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
md_header_splits = text_splitter.split_documents(md_header_splits)
# Convert back to Document objects, preserving original metadata
for split_chunk in md_header_splits:
headings_list = []
# Extract header values in order based on headers_to_split_on
for _, header_meta_key_name in headers_to_split_on:
if header_meta_key_name in split_chunk.metadata:
headings_list.append(
split_chunk.metadata[header_meta_key_name]
)
md_split_docs.append(
Document(
page_content=split_chunk.page_content,
metadata={**doc.metadata, "headings": headings_list},
)
)
docs = md_split_docs
else: else:
raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter")) raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
docs = text_splitter.split_documents(docs)
if len(docs) == 0: if len(docs) == 0:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@ -1171,27 +1229,14 @@ def save_docs_to_vector_db(
{ {
**doc.metadata, **doc.metadata,
**(metadata if metadata else {}), **(metadata if metadata else {}),
"embedding_config": json.dumps( "embedding_config": {
{
"engine": request.app.state.config.RAG_EMBEDDING_ENGINE, "engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
"model": request.app.state.config.RAG_EMBEDDING_MODEL, "model": request.app.state.config.RAG_EMBEDDING_MODEL,
} },
),
} }
for doc in docs for doc in docs
] ]
# ChromaDB does not like datetime formats
# for meta-data so convert them to string.
for metadata in metadatas:
for key, value in metadata.items():
if (
isinstance(value, datetime)
or isinstance(value, list)
or isinstance(value, dict)
):
metadata[key] = str(value)
try: try:
if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
log.info(f"collection {collection_name} already exists") log.info(f"collection {collection_name} already exists")
@ -1747,6 +1792,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
) )
else: else:
raise Exception("No TAVILY_API_KEY found in environment variables") raise Exception("No TAVILY_API_KEY found in environment variables")
elif engine == "exa":
if request.app.state.config.EXA_API_KEY:
return search_exa(
request.app.state.config.EXA_API_KEY,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No EXA_API_KEY found in environment variables")
elif engine == "searchapi": elif engine == "searchapi":
if request.app.state.config.SEARCHAPI_API_KEY: if request.app.state.config.SEARCHAPI_API_KEY:
return search_searchapi( return search_searchapi(
@ -1784,6 +1839,13 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
) )
elif engine == "exa":
return search_exa(
request.app.state.config.EXA_API_KEY,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "perplexity": elif engine == "perplexity":
return search_perplexity( return search_perplexity(
request.app.state.config.PERPLEXITY_API_KEY, request.app.state.config.PERPLEXITY_API_KEY,
@ -1978,7 +2040,15 @@ def query_doc_handler(
query, prefix=prefix, user=user query, prefix=prefix, user=user
), ),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf, reranking_function=(
(
lambda sentences: request.app.state.RERANKING_FUNCTION(
sentences, user=user
)
)
if request.app.state.RERANKING_FUNCTION
else None
),
k_reranker=form_data.k_reranker k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER, or request.app.state.config.TOP_K_RERANKER,
r=( r=(
@ -2035,7 +2105,15 @@ def query_collection_handler(
query, prefix=prefix, user=user query, prefix=prefix, user=user
), ),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf, reranking_function=(
(
lambda sentences: request.app.state.RERANKING_FUNCTION(
sentences, user=user
)
)
if request.app.state.RERANKING_FUNCTION
else None
),
k_reranker=form_data.k_reranker k_reranker=form_data.k_reranker
or request.app.state.config.TOP_K_RERANKER, or request.app.state.config.TOP_K_RERANKER,
r=( r=(

View File

@ -695,11 +695,11 @@ async def generate_emoji(
"max_completion_tokens": 4, "max_completion_tokens": 4,
} }
), ),
"chat_id": form_data.get("chat_id", None),
"metadata": { "metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}), **(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.EMOJI_GENERATION), "task": str(TASKS.EMOJI_GENERATION),
"task_body": form_data, "task_body": form_data,
"chat_id": form_data.get("chat_id", None),
}, },
} }

View File

@ -153,7 +153,7 @@ async def load_tool_from_url(
) )
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get( async with session.get(
url, headers={"Content-Type": "application/json"} url, headers={"Content-Type": "application/json"}
) as resp: ) as resp:

View File

@ -7,6 +7,7 @@ from open_webui.models.chats import Chats
from open_webui.models.users import ( from open_webui.models.users import (
UserModel, UserModel,
UserListResponse, UserListResponse,
UserInfoListResponse,
UserRoleUpdateForm, UserRoleUpdateForm,
Users, Users,
UserSettings, UserSettings,
@ -83,7 +84,7 @@ async def get_users(
return Users.get_users(filter=filter, skip=skip, limit=limit) return Users.get_users(filter=filter, skip=skip, limit=limit)
@router.get("/all", response_model=UserListResponse) @router.get("/all", response_model=UserInfoListResponse)
async def get_all_users( async def get_all_users(
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
@ -133,7 +134,9 @@ class SharingPermissions(BaseModel):
class ChatPermissions(BaseModel): class ChatPermissions(BaseModel):
controls: bool = True controls: bool = True
valves: bool = True
system_prompt: bool = True system_prompt: bool = True
params: bool = True
file_upload: bool = True file_upload: bool = True
delete: bool = True delete: bool = True
edit: bool = True edit: bool = True

View File

@ -1,13 +1,18 @@
import asyncio import asyncio
import random
import socketio import socketio
import logging import logging
import sys import sys
import time import time
from typing import Dict, Set
from redis import asyncio as aioredis from redis import asyncio as aioredis
import pycrdt as Y
from open_webui.models.users import Users, UserNameResponse from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels from open_webui.models.channels import Channels
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.models.notes import Notes, NoteUpdateForm
from open_webui.utils.redis import ( from open_webui.utils.redis import (
get_sentinels_from_env, get_sentinels_from_env,
get_sentinel_url_from_env, get_sentinel_url_from_env,
@ -20,9 +25,14 @@ from open_webui.env import (
WEBSOCKET_REDIS_LOCK_TIMEOUT, WEBSOCKET_REDIS_LOCK_TIMEOUT,
WEBSOCKET_SENTINEL_PORT, WEBSOCKET_SENTINEL_PORT,
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_HOSTS,
REDIS_KEY_PREFIX,
) )
from open_webui.utils.auth import decode_token from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict, RedisLock from open_webui.socket.utils import RedisDict, RedisLock, YdocManager
from open_webui.tasks import create_task, stop_item_tasks
from open_webui.utils.redis import get_redis_connection
from open_webui.utils.access_control import has_access, get_users_with_access
from open_webui.env import ( from open_webui.env import (
GLOBAL_LOG_LEVEL, GLOBAL_LOG_LEVEL,
@ -35,6 +45,8 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["SOCKET"]) log.setLevel(SRC_LOG_LEVELS["SOCKET"])
REDIS = None
if WEBSOCKET_MANAGER == "redis": if WEBSOCKET_MANAGER == "redis":
if WEBSOCKET_SENTINEL_HOSTS: if WEBSOCKET_SENTINEL_HOSTS:
mgr = socketio.AsyncRedisManager( mgr = socketio.AsyncRedisManager(
@ -69,21 +81,29 @@ TIMEOUT_DURATION = 3
if WEBSOCKET_MANAGER == "redis": if WEBSOCKET_MANAGER == "redis":
log.debug("Using Redis to manage websockets.") log.debug("Using Redis to manage websockets.")
REDIS = get_redis_connection(
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=get_sentinels_from_env(
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
),
async_mode=True,
)
redis_sentinels = get_sentinels_from_env( redis_sentinels = get_sentinels_from_env(
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
) )
SESSION_POOL = RedisDict( SESSION_POOL = RedisDict(
"open-webui:session_pool", f"{REDIS_KEY_PREFIX}:session_pool",
redis_url=WEBSOCKET_REDIS_URL, redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels, redis_sentinels=redis_sentinels,
) )
USER_POOL = RedisDict( USER_POOL = RedisDict(
"open-webui:user_pool", f"{REDIS_KEY_PREFIX}:user_pool",
redis_url=WEBSOCKET_REDIS_URL, redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels, redis_sentinels=redis_sentinels,
) )
USAGE_POOL = RedisDict( USAGE_POOL = RedisDict(
"open-webui:usage_pool", f"{REDIS_KEY_PREFIX}:usage_pool",
redis_url=WEBSOCKET_REDIS_URL, redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels, redis_sentinels=redis_sentinels,
) )
@ -101,14 +121,37 @@ else:
SESSION_POOL = {} SESSION_POOL = {}
USER_POOL = {} USER_POOL = {}
USAGE_POOL = {} USAGE_POOL = {}
aquire_func = release_func = renew_func = lambda: True aquire_func = release_func = renew_func = lambda: True
YDOC_MANAGER = YdocManager(
redis=REDIS,
redis_key_prefix=f"{REDIS_KEY_PREFIX}:ydoc:documents",
)
async def periodic_usage_pool_cleanup(): async def periodic_usage_pool_cleanup():
if not aquire_func(): max_retries = 2
log.debug("Usage pool cleanup lock already exists. Not running it.") retry_delay = random.uniform(
WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT
)
for attempt in range(max_retries + 1):
if aquire_func():
break
else:
if attempt < max_retries:
log.debug(
f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..."
)
await asyncio.sleep(retry_delay)
else:
log.warning(
"Failed to acquire cleanup lock after retries. Skipping cleanup."
)
return return
log.debug("Running periodic_usage_pool_cleanup")
log.debug("Running periodic_cleanup")
try: try:
while True: while True:
if not renew_func(): if not renew_func():
@ -169,16 +212,20 @@ def get_user_id_from_session_pool(sid):
return None return None
def get_user_ids_from_room(room): def get_session_ids_from_room(room):
"""Get all session IDs from a specific room."""
active_session_ids = sio.manager.get_participants( active_session_ids = sio.manager.get_participants(
namespace="/", namespace="/",
room=room, room=room,
) )
return [session_id[0] for session_id in active_session_ids]
def get_user_ids_from_room(room):
active_session_ids = get_session_ids_from_room(room)
active_user_ids = list( active_user_ids = list(
set( set([SESSION_POOL.get(session_id)["id"] for session_id in active_session_ids])
[SESSION_POOL.get(session_id[0])["id"] for session_id in active_session_ids]
)
) )
return active_user_ids return active_user_ids
@ -270,6 +317,37 @@ async def join_channel(sid, data):
await sio.enter_room(sid, f"channel:{channel.id}") await sio.enter_room(sid, f"channel:{channel.id}")
@sio.on("join-note")
async def join_note(sid, data):
auth = data["auth"] if "auth" in data else None
if not auth or "token" not in auth:
return
token_data = decode_token(auth["token"])
if token_data is None or "id" not in token_data:
return
user = Users.get_user_by_id(token_data["id"])
if not user:
return
note = Notes.get_note_by_id(data["note_id"])
if not note:
log.error(f"Note {data['note_id']} not found for user {user.id}")
return
if (
user.role != "admin"
and user.id != note.user_id
and not has_access(user.id, type="read", access_control=note.access_control)
):
log.error(f"User {user.id} does not have access to note {data['note_id']}")
return
log.debug(f"Joining note {note.id} for user {user.id}")
await sio.enter_room(sid, f"note:{note.id}")
@sio.on("channel-events") @sio.on("channel-events")
async def channel_events(sid, data): async def channel_events(sid, data):
room = f"channel:{data['channel_id']}" room = f"channel:{data['channel_id']}"
@ -298,6 +376,242 @@ async def channel_events(sid, data):
) )
@sio.on("ydoc:document:join")
async def ydoc_document_join(sid, data):
"""Handle user joining a document"""
user = SESSION_POOL.get(sid)
try:
document_id = data["document_id"]
if document_id.startswith("note:"):
note_id = document_id.split(":")[1]
note = Notes.get_note_by_id(note_id)
if not note:
log.error(f"Note {note_id} not found")
return
if (
user.get("role") != "admin"
and user.get("id") != note.user_id
and not has_access(
user.get("id"), type="read", access_control=note.access_control
)
):
log.error(
f"User {user.get('id')} does not have access to note {note_id}"
)
return
user_id = data.get("user_id", sid)
user_name = data.get("user_name", "Anonymous")
user_color = data.get("user_color", "#000000")
log.info(f"User {user_id} joining document {document_id}")
await YDOC_MANAGER.add_user(document_id=document_id, user_id=sid)
# Join Socket.IO room
await sio.enter_room(sid, f"doc_{document_id}")
active_session_ids = get_session_ids_from_room(f"doc_{document_id}")
# Get the Yjs document state
ydoc = Y.Doc()
updates = await YDOC_MANAGER.get_updates(document_id)
for update in updates:
ydoc.apply_update(bytes(update))
# Encode the entire document state as an update
state_update = ydoc.get_update()
await sio.emit(
"ydoc:document:state",
{
"document_id": document_id,
"state": list(state_update), # Convert bytes to list for JSON
"sessions": active_session_ids,
},
room=sid,
)
# Notify other users about the new user
await sio.emit(
"ydoc:user:joined",
{
"document_id": document_id,
"user_id": user_id,
"user_name": user_name,
"user_color": user_color,
},
room=f"doc_{document_id}",
skip_sid=sid,
)
log.info(f"User {user_id} successfully joined document {document_id}")
except Exception as e:
log.error(f"Error in yjs_document_join: {e}")
await sio.emit("error", {"message": "Failed to join document"}, room=sid)
async def document_save_handler(document_id, data, user):
if document_id.startswith("note:"):
note_id = document_id.split(":")[1]
note = Notes.get_note_by_id(note_id)
if not note:
log.error(f"Note {note_id} not found")
return
if (
user.get("role") != "admin"
and user.get("id") != note.user_id
and not has_access(
user.get("id"), type="read", access_control=note.access_control
)
):
log.error(f"User {user.get('id')} does not have access to note {note_id}")
return
Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
@sio.on("ydoc:document:state")
async def yjs_document_state(sid, data):
"""Send the current state of the Yjs document to the user"""
try:
document_id = data["document_id"]
room = f"doc_{document_id}"
active_session_ids = get_session_ids_from_room(room)
if sid not in active_session_ids:
log.warning(f"Session {sid} not in room {room}. Cannot send state.")
return
if not await YDOC_MANAGER.document_exists(document_id):
log.warning(f"Document {document_id} not found")
return
# Get the Yjs document state
ydoc = Y.Doc()
updates = await YDOC_MANAGER.get_updates(document_id)
for update in updates:
ydoc.apply_update(bytes(update))
# Encode the entire document state as an update
state_update = ydoc.get_update()
await sio.emit(
"ydoc:document:state",
{
"document_id": document_id,
"state": list(state_update), # Convert bytes to list for JSON
"sessions": active_session_ids,
},
room=sid,
)
except Exception as e:
log.error(f"Error in yjs_document_state: {e}")
@sio.on("ydoc:document:update")
async def yjs_document_update(sid, data):
"""Handle Yjs document updates"""
try:
document_id = data["document_id"]
try:
await stop_item_tasks(REDIS, document_id)
except:
pass
user_id = data.get("user_id", sid)
update = data["update"] # List of bytes from frontend
await YDOC_MANAGER.append_to_updates(
document_id=document_id,
update=update, # Convert list of bytes to bytes
)
# Broadcast update to all other users in the document
await sio.emit(
"ydoc:document:update",
{
"document_id": document_id,
"user_id": user_id,
"update": update,
"socket_id": sid, # Add socket_id to match frontend filtering
},
room=f"doc_{document_id}",
skip_sid=sid,
)
async def debounced_save():
await asyncio.sleep(0.5)
await document_save_handler(
document_id, data.get("data", {}), SESSION_POOL.get(sid)
)
if data.get("data"):
await create_task(REDIS, debounced_save(), document_id)
except Exception as e:
log.error(f"Error in yjs_document_update: {e}")
@sio.on("ydoc:document:leave")
async def yjs_document_leave(sid, data):
"""Handle user leaving a document"""
try:
document_id = data["document_id"]
user_id = data.get("user_id", sid)
log.info(f"User {user_id} leaving document {document_id}")
# Remove user from the document
await YDOC_MANAGER.remove_user(document_id=document_id, user_id=sid)
# Leave Socket.IO room
await sio.leave_room(sid, f"doc_{document_id}")
# Notify other users
await sio.emit(
"ydoc:user:left",
{"document_id": document_id, "user_id": user_id},
room=f"doc_{document_id}",
)
if (
await YDOC_MANAGER.document_exists(document_id)
and len(await YDOC_MANAGER.get_users(document_id)) == 0
):
log.info(f"Cleaning up document {document_id} as no users are left")
await YDOC_MANAGER.clear_document(document_id)
except Exception as e:
log.error(f"Error in yjs_document_leave: {e}")
@sio.on("ydoc:awareness:update")
async def yjs_awareness_update(sid, data):
"""Handle awareness updates (cursors, selections, etc.)"""
try:
document_id = data["document_id"]
user_id = data.get("user_id", sid)
update = data["update"]
# Broadcast awareness update to all other users in the document
await sio.emit(
"ydoc:awareness:update",
{"document_id": document_id, "user_id": user_id, "update": update},
room=f"doc_{document_id}",
skip_sid=sid,
)
except Exception as e:
log.error(f"Error in yjs_awareness_update: {e}")
@sio.event @sio.event
async def disconnect(sid): async def disconnect(sid):
if sid in SESSION_POOL: if sid in SESSION_POOL:
@ -309,6 +623,8 @@ async def disconnect(sid):
if len(USER_POOL[user_id]) == 0: if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id] del USER_POOL[user_id]
await YDOC_MANAGER.remove_user_from_all_documents(sid)
else: else:
pass pass
# print(f"Unknown session ID {sid} disconnected") # print(f"Unknown session ID {sid} disconnected")

View File

@ -1,6 +1,9 @@
import json import json
import uuid import uuid
from open_webui.utils.redis import get_redis_connection from open_webui.utils.redis import get_redis_connection
from open_webui.env import REDIS_KEY_PREFIX
from typing import Optional, List, Tuple
import pycrdt as Y
class RedisLock: class RedisLock:
@ -89,3 +92,109 @@ class RedisDict:
if key not in self: if key not in self:
self[key] = default self[key] = default
return self[key] return self[key]
class YdocManager:
    """
    Tracks Yjs collaborative-document state: the ordered log of binary
    updates per document and the set of users attached to each document.

    When a Redis client is supplied, state lives in Redis (shared across
    workers and restarts); otherwise plain in-process dicts are used.
    """

    def __init__(
        self,
        redis=None,
        redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents",
    ):
        # In-memory fallbacks, used only when no Redis client is given.
        self._updates = {}  # document_id -> list of raw update blobs (bytes)
        self._users = {}  # document_id -> set of user ids
        self._redis = redis
        self._redis_key_prefix = redis_key_prefix

    @staticmethod
    def _sanitize_id(document_id: str) -> str:
        # ":" is the Redis key separator; strip it from ids so a document id
        # can never split or collide with our key namespace.
        return document_id.replace(":", "_")

    def _updates_key(self, document_id: str) -> str:
        # Key holding the document's update log (a Redis list).
        return f"{self._redis_key_prefix}:{document_id}:updates"

    def _users_key(self, document_id: str) -> str:
        # Key holding the document's attached users (a Redis set).
        return f"{self._redis_key_prefix}:{document_id}:users"

    async def append_to_updates(self, document_id: str, update: bytes):
        """Append one binary Yjs update to the document's update log."""
        document_id = self._sanitize_id(document_id)
        if self._redis:
            # Store as a JSON list of ints; raw bytes may not round-trip
            # through a decode_responses connection.
            await self._redis.rpush(
                self._updates_key(document_id), json.dumps(list(update))
            )
        else:
            self._updates.setdefault(document_id, []).append(update)

    async def get_updates(self, document_id: str) -> List[bytes]:
        """Return all updates recorded for the document, oldest first."""
        document_id = self._sanitize_id(document_id)
        if self._redis:
            updates = await self._redis.lrange(
                self._updates_key(document_id), 0, -1
            )
            return [bytes(json.loads(update)) for update in updates]
        # Return a copy so callers cannot mutate internal state.
        return list(self._updates.get(document_id, []))

    async def document_exists(self, document_id: str) -> bool:
        """True if the document has at least one recorded update."""
        document_id = self._sanitize_id(document_id)
        if self._redis:
            return await self._redis.exists(self._updates_key(document_id)) > 0
        return document_id in self._updates

    async def get_users(self, document_id: str) -> List[str]:
        """Return the ids of users currently attached to the document."""
        document_id = self._sanitize_id(document_id)
        if self._redis:
            users = await self._redis.smembers(self._users_key(document_id))
            return list(users)
        # Fix: the in-memory store keeps a set (see add_user); convert so
        # both backends honour the declared List[str] return type and the
        # internal set is never leaked to callers.
        return list(self._users.get(document_id, set()))

    async def add_user(self, document_id: str, user_id: str):
        """Attach a user to the document."""
        document_id = self._sanitize_id(document_id)
        if self._redis:
            await self._redis.sadd(self._users_key(document_id), user_id)
        else:
            self._users.setdefault(document_id, set()).add(user_id)

    async def remove_user(self, document_id: str, user_id: str):
        """Detach a user from the document (no-op if not attached)."""
        document_id = self._sanitize_id(document_id)
        if self._redis:
            await self._redis.srem(self._users_key(document_id), user_id)
        else:
            users = self._users.get(document_id)
            if users is not None:
                users.discard(user_id)

    async def remove_user_from_all_documents(self, user_id: str):
        """
        Detach the user from every document; any document left with no
        attached users is cleared entirely (updates included).
        """
        if self._redis:
            # NOTE(review): KEYS is O(n) over the keyspace; fine for modest
            # deployments but consider SCAN if this namespace grows large.
            keys = await self._redis.keys(f"{self._redis_key_prefix}:*")
            for key in keys:
                # Keys come back as bytes when decode_responses is off.
                key = key.decode() if isinstance(key, bytes) else key
                if key.endswith(":users"):
                    await self._redis.srem(key, user_id)
                    # Key layout is "<prefix>:<document_id>:users" and
                    # sanitized ids contain no ":", so [-2] is the id.
                    document_id = key.split(":")[-2]
                    if len(await self.get_users(document_id)) == 0:
                        await self.clear_document(document_id)
        else:
            for document_id in list(self._users.keys()):
                if user_id in self._users[document_id]:
                    self._users[document_id].remove(user_id)
                    if not self._users[document_id]:
                        del self._users[document_id]
                        await self.clear_document(document_id)

    async def clear_document(self, document_id: str):
        """Delete all stored updates and user membership for the document."""
        document_id = self._sanitize_id(document_id)
        if self._redis:
            await self._redis.delete(self._updates_key(document_id))
            await self._redis.delete(self._users_key(document_id))
        else:
            self._updates.pop(document_id, None)
            self._users.pop(document_id, None)

View File

@ -3,23 +3,25 @@ import asyncio
from typing import Dict from typing import Dict
from uuid import uuid4 from uuid import uuid4
import json import json
import logging
from redis.asyncio import Redis from redis.asyncio import Redis
from fastapi import Request from fastapi import Request
from typing import Dict, List, Optional from typing import Dict, List, Optional
from open_webui.env import SRC_LOG_LEVELS, REDIS_KEY_PREFIX
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
# A dictionary to keep track of active tasks # A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {} tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {} item_tasks = {}
REDIS_TASKS_KEY = "open-webui:tasks" REDIS_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks"
REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat" REDIS_ITEM_TASKS_KEY = f"{REDIS_KEY_PREFIX}:tasks:item"
REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands" REDIS_PUBSUB_CHANNEL = f"{REDIS_KEY_PREFIX}:tasks:commands"
def is_redis(request: Request) -> bool:
# Called everywhere a request is available to check Redis
return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
async def redis_task_command_listener(app): async def redis_task_command_listener(app):
@ -38,7 +40,7 @@ async def redis_task_command_listener(app):
if local_task: if local_task:
local_task.cancel() local_task.cancel()
except Exception as e: except Exception as e:
print(f"Error handling distributed task command: {e}") log.exception(f"Error handling distributed task command: {e}")
### ------------------------------ ### ------------------------------
@ -46,21 +48,21 @@ async def redis_task_command_listener(app):
### ------------------------------ ### ------------------------------
async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]): async def redis_save_task(redis: Redis, task_id: str, item_id: Optional[str]):
pipe = redis.pipeline() pipe = redis.pipeline()
pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "") pipe.hset(REDIS_TASKS_KEY, task_id, item_id or "")
if chat_id: if item_id:
pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id) pipe.sadd(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id)
await pipe.execute() await pipe.execute()
async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]): async def redis_cleanup_task(redis: Redis, task_id: str, item_id: Optional[str]):
pipe = redis.pipeline() pipe = redis.pipeline()
pipe.hdel(REDIS_TASKS_KEY, task_id) pipe.hdel(REDIS_TASKS_KEY, task_id)
if chat_id: if item_id:
pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id) pipe.srem(f"{REDIS_ITEM_TASKS_KEY}:{item_id}", task_id)
if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0: if (await pipe.scard(f"{REDIS_ITEM_TASKS_KEY}:{item_id}").execute())[-1] == 0:
pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set pipe.delete(f"{REDIS_ITEM_TASKS_KEY}:{item_id}") # Remove if empty set
await pipe.execute() await pipe.execute()
@ -68,31 +70,31 @@ async def redis_list_tasks(redis: Redis) -> List[str]:
return list(await redis.hkeys(REDIS_TASKS_KEY)) return list(await redis.hkeys(REDIS_TASKS_KEY))
async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]: async def redis_list_item_tasks(redis: Redis, item_id: str) -> List[str]:
return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}")) return list(await redis.smembers(f"{REDIS_ITEM_TASKS_KEY}:{item_id}"))
async def redis_send_command(redis: Redis, command: dict): async def redis_send_command(redis: Redis, command: dict):
await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command)) await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
async def cleanup_task(request, task_id: str, id=None): async def cleanup_task(redis, task_id: str, id=None):
""" """
Remove a completed or canceled task from the global `tasks` dictionary. Remove a completed or canceled task from the global `tasks` dictionary.
""" """
if is_redis(request): if redis:
await redis_cleanup_task(request.app.state.redis, task_id, id) await redis_cleanup_task(redis, task_id, id)
tasks.pop(task_id, None) # Remove the task if it exists tasks.pop(task_id, None) # Remove the task if it exists
# If an ID is provided, remove the task from the chat_tasks dictionary # If an ID is provided, remove the task from the item_tasks dictionary
if id and task_id in chat_tasks.get(id, []): if id and task_id in item_tasks.get(id, []):
chat_tasks[id].remove(task_id) item_tasks[id].remove(task_id)
if not chat_tasks[id]: # If no tasks left for this ID, remove the entry if not item_tasks[id]: # If no tasks left for this ID, remove the entry
chat_tasks.pop(id, None) item_tasks.pop(id, None)
async def create_task(request, coroutine, id=None): async def create_task(redis, coroutine, id=None):
""" """
Create a new asyncio task and add it to the global task dictionary. Create a new asyncio task and add it to the global task dictionary.
""" """
@ -101,48 +103,48 @@ async def create_task(request, coroutine, id=None):
# Add a done callback for cleanup # Add a done callback for cleanup
task.add_done_callback( task.add_done_callback(
lambda t: asyncio.create_task(cleanup_task(request, task_id, id)) lambda t: asyncio.create_task(cleanup_task(redis, task_id, id))
) )
tasks[task_id] = task tasks[task_id] = task
# If an ID is provided, associate the task with that ID # If an ID is provided, associate the task with that ID
if chat_tasks.get(id): if item_tasks.get(id):
chat_tasks[id].append(task_id) item_tasks[id].append(task_id)
else: else:
chat_tasks[id] = [task_id] item_tasks[id] = [task_id]
if is_redis(request): if redis:
await redis_save_task(request.app.state.redis, task_id, id) await redis_save_task(redis, task_id, id)
return task_id, task return task_id, task
async def list_tasks(request): async def list_tasks(redis):
""" """
List all currently active task IDs. List all currently active task IDs.
""" """
if is_redis(request): if redis:
return await redis_list_tasks(request.app.state.redis) return await redis_list_tasks(redis)
return list(tasks.keys()) return list(tasks.keys())
async def list_task_ids_by_chat_id(request, id): async def list_task_ids_by_item_id(redis, id):
""" """
List all tasks associated with a specific ID. List all tasks associated with a specific ID.
""" """
if is_redis(request): if redis:
return await redis_list_chat_tasks(request.app.state.redis, id) return await redis_list_item_tasks(redis, id)
return chat_tasks.get(id, []) return item_tasks.get(id, [])
async def stop_task(request, task_id: str): async def stop_task(redis, task_id: str):
""" """
Cancel a running task and remove it from the global task list. Cancel a running task and remove it from the global task list.
""" """
if is_redis(request): if redis:
# PUBSUB: All instances check if they have this task, and stop if so. # PUBSUB: All instances check if they have this task, and stop if so.
await redis_send_command( await redis_send_command(
request.app.state.redis, redis,
{ {
"action": "stop", "action": "stop",
"task_id": task_id, "task_id": task_id,
@ -151,7 +153,7 @@ async def stop_task(request, task_id: str):
# Optionally check if task_id still in Redis a few moments later for feedback? # Optionally check if task_id still in Redis a few moments later for feedback?
return {"status": True, "message": f"Stop signal sent for {task_id}"} return {"status": True, "message": f"Stop signal sent for {task_id}"}
task = tasks.get(task_id) task = tasks.pop(task_id)
if not task: if not task:
raise ValueError(f"Task with ID {task_id} not found.") raise ValueError(f"Task with ID {task_id} not found.")
@ -160,7 +162,22 @@ async def stop_task(request, task_id: str):
await task # Wait for the task to handle the cancellation await task # Wait for the task to handle the cancellation
except asyncio.CancelledError: except asyncio.CancelledError:
# Task successfully canceled # Task successfully canceled
tasks.pop(task_id, None) # Remove it from the dictionary
return {"status": True, "message": f"Task {task_id} successfully stopped."} return {"status": True, "message": f"Task {task_id} successfully stopped."}
return {"status": False, "message": f"Failed to stop task {task_id}."} return {"status": False, "message": f"Failed to stop task {task_id}."}
async def stop_item_tasks(redis: Redis, item_id: str):
"""
Stop all tasks associated with a specific item ID.
"""
task_ids = await list_task_ids_by_item_id(redis, item_id)
if not task_ids:
return {"status": True, "message": f"No tasks found for item {item_id}."}
for task_id in task_ids:
result = await stop_task(redis, task_id)
if not result["status"]:
return result # Return the first failure
return {"status": True, "message": f"All tasks for item {item_id} stopped."}

View File

@ -0,0 +1,793 @@
import pytest
from unittest.mock import Mock, patch, AsyncMock
import redis
from open_webui.utils.redis import (
SentinelRedisProxy,
parse_redis_service_url,
get_redis_connection,
get_sentinels_from_env,
MAX_RETRY_COUNT,
)
import inspect
class TestSentinelRedisProxy:
    """Test Redis Sentinel failover functionality"""

    def test_parse_redis_service_url_valid(self):
        """Test parsing valid Redis service URL"""
        # Full form: credentials, service (sentinel master name), port, db.
        url = "redis://user:pass@mymaster:6379/0"
        result = parse_redis_service_url(url)

        assert result["username"] == "user"
        assert result["password"] == "pass"
        assert result["service"] == "mymaster"
        assert result["port"] == 6379
        assert result["db"] == 0

    def test_parse_redis_service_url_defaults(self):
        """Test parsing Redis service URL with defaults"""
        # Bare host: credentials default to None, port to 6379, db to 0.
        url = "redis://mymaster"
        result = parse_redis_service_url(url)

        assert result["username"] is None
        assert result["password"] is None
        assert result["service"] == "mymaster"
        assert result["port"] == 6379
        assert result["db"] == 0

    def test_parse_redis_service_url_invalid_scheme(self):
        """Test parsing invalid URL scheme"""
        # Only redis:// URLs are accepted; http:// must raise.
        with pytest.raises(ValueError, match="Invalid Redis URL scheme"):
            parse_redis_service_url("http://invalid")

    def test_get_sentinels_from_env(self):
        """Test parsing sentinel hosts from environment"""
        # Comma-separated hosts all share the single configured port.
        hosts = "sentinel1,sentinel2,sentinel3"
        port = "26379"

        result = get_sentinels_from_env(hosts, port)

        expected = [("sentinel1", 26379), ("sentinel2", 26379), ("sentinel3", 26379)]
        assert result == expected

    def test_get_sentinels_from_env_empty(self):
        """Test empty sentinel hosts"""
        result = get_sentinels_from_env(None, "26379")
        assert result == []

    @patch("redis.sentinel.Sentinel")
    def test_sentinel_redis_proxy_sync_success(self, mock_sentinel_class):
        """Test successful sync operation with SentinelRedisProxy"""
        # The proxy is handed a sentinel instance directly; the class patch
        # only prevents any accidental real Sentinel construction.
        mock_sentinel = Mock()
        mock_master = Mock()
        mock_master.get.return_value = "test_value"
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        # Test attribute access
        get_method = proxy.__getattr__("get")
        result = get_method("test_key")

        assert result == "test_value"
        # Each call should resolve the current master via master_for().
        mock_sentinel.master_for.assert_called_with("mymaster")
        mock_master.get.assert_called_with("test_key")

    @patch("redis.sentinel.Sentinel")
    @pytest.mark.asyncio
    async def test_sentinel_redis_proxy_async_success(self, mock_sentinel_class):
        """Test successful async operation with SentinelRedisProxy"""
        mock_sentinel = Mock()
        mock_master = Mock()
        # AsyncMock so the wrapped command returns an awaitable.
        mock_master.get = AsyncMock(return_value="test_value")
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)

        # Test async attribute access
        get_method = proxy.__getattr__("get")
        result = await get_method("test_key")

        assert result == "test_value"
        mock_sentinel.master_for.assert_called_with("mymaster")
        mock_master.get.assert_called_with("test_key")

    @patch("redis.sentinel.Sentinel")
    def test_sentinel_redis_proxy_failover_retry(self, mock_sentinel_class):
        """Test retry mechanism during failover"""
        mock_sentinel = Mock()
        mock_master = Mock()

        # First call fails, second succeeds
        mock_master.get.side_effect = [
            redis.exceptions.ConnectionError("Master down"),
            "test_value",
        ]
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        get_method = proxy.__getattr__("get")
        result = get_method("test_key")

        # The proxy retried transparently: caller still gets the value.
        assert result == "test_value"
        assert mock_master.get.call_count == 2

    @patch("redis.sentinel.Sentinel")
    def test_sentinel_redis_proxy_max_retries_exceeded(self, mock_sentinel_class):
        """Test failure after max retries exceeded"""
        mock_sentinel = Mock()
        mock_master = Mock()

        # All calls fail
        mock_master.get.side_effect = redis.exceptions.ConnectionError("Master down")
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        get_method = proxy.__getattr__("get")

        # After MAX_RETRY_COUNT attempts the last error propagates.
        with pytest.raises(redis.exceptions.ConnectionError):
            get_method("test_key")

        assert mock_master.get.call_count == MAX_RETRY_COUNT

    @patch("redis.sentinel.Sentinel")
    def test_sentinel_redis_proxy_readonly_error_retry(self, mock_sentinel_class):
        """Test retry on ReadOnlyError"""
        mock_sentinel = Mock()
        mock_master = Mock()

        # First call gets ReadOnlyError (old master), second succeeds (new master)
        mock_master.get.side_effect = [
            redis.exceptions.ReadOnlyError("Read only"),
            "test_value",
        ]
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        get_method = proxy.__getattr__("get")
        result = get_method("test_key")

        assert result == "test_value"
        assert mock_master.get.call_count == 2

    @patch("redis.sentinel.Sentinel")
    def test_sentinel_redis_proxy_factory_methods(self, mock_sentinel_class):
        """Test factory methods are passed through directly"""
        mock_sentinel = Mock()
        mock_master = Mock()
        mock_pipeline = Mock()

        mock_master.pipeline.return_value = mock_pipeline
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        # Factory methods should be passed through without wrapping
        pipeline_method = proxy.__getattr__("pipeline")
        result = pipeline_method()

        assert result == mock_pipeline
        mock_master.pipeline.assert_called_once()

    @patch("redis.sentinel.Sentinel")
    @patch("redis.from_url")
    def test_get_redis_connection_with_sentinel(
        self, mock_from_url, mock_sentinel_class
    ):
        """Test getting Redis connection with Sentinel"""
        # Decorators apply bottom-up: mock_from_url is the inner patch.
        mock_sentinel = Mock()
        mock_sentinel_class.return_value = mock_sentinel

        sentinels = [("sentinel1", 26379), ("sentinel2", 26379)]
        redis_url = "redis://user:pass@mymaster:6379/0"

        result = get_redis_connection(
            redis_url=redis_url, redis_sentinels=sentinels, async_mode=False
        )

        # With sentinels configured we must get a proxy, not a direct client.
        assert isinstance(result, SentinelRedisProxy)
        mock_sentinel_class.assert_called_once()
        mock_from_url.assert_not_called()

    @patch("redis.Redis.from_url")
    def test_get_redis_connection_without_sentinel(self, mock_from_url):
        """Test getting Redis connection without Sentinel"""
        mock_redis = Mock()
        mock_from_url.return_value = mock_redis

        redis_url = "redis://localhost:6379/0"

        result = get_redis_connection(
            redis_url=redis_url, redis_sentinels=None, async_mode=False
        )

        assert result == mock_redis
        mock_from_url.assert_called_once_with(redis_url, decode_responses=True)

    @patch("redis.asyncio.from_url")
    def test_get_redis_connection_without_sentinel_async(self, mock_from_url):
        """Test getting async Redis connection without Sentinel"""
        mock_redis = Mock()
        mock_from_url.return_value = mock_redis

        redis_url = "redis://localhost:6379/0"

        result = get_redis_connection(
            redis_url=redis_url, redis_sentinels=None, async_mode=True
        )

        assert result == mock_redis
        mock_from_url.assert_called_once_with(redis_url, decode_responses=True)
class TestSentinelRedisProxyCommands:
    """Test Redis commands through SentinelRedisProxy"""

    @patch("redis.sentinel.Sentinel")
    def test_hash_commands_sync(self, mock_sentinel_class):
        """Test Redis hash commands in sync mode"""
        # The class patch only guards against real Sentinel construction;
        # the proxy is driven entirely by the Mock sentinel below.
        mock_sentinel = Mock()
        mock_master = Mock()

        # Mock hash command responses
        mock_master.hset.return_value = 1
        mock_master.hget.return_value = "test_value"
        mock_master.hgetall.return_value = {"key1": "value1", "key2": "value2"}
        mock_master.hdel.return_value = 1
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        # Test hset
        hset_method = proxy.__getattr__("hset")
        result = hset_method("test_hash", "field1", "value1")
        assert result == 1
        mock_master.hset.assert_called_with("test_hash", "field1", "value1")

        # Test hget
        hget_method = proxy.__getattr__("hget")
        result = hget_method("test_hash", "field1")
        assert result == "test_value"
        mock_master.hget.assert_called_with("test_hash", "field1")

        # Test hgetall
        hgetall_method = proxy.__getattr__("hgetall")
        result = hgetall_method("test_hash")
        assert result == {"key1": "value1", "key2": "value2"}
        mock_master.hgetall.assert_called_with("test_hash")

        # Test hdel
        hdel_method = proxy.__getattr__("hdel")
        result = hdel_method("test_hash", "field1")
        assert result == 1
        mock_master.hdel.assert_called_with("test_hash", "field1")

    @patch("redis.sentinel.Sentinel")
    @pytest.mark.asyncio
    async def test_hash_commands_async(self, mock_sentinel_class):
        """Test Redis hash commands in async mode"""
        mock_sentinel = Mock()
        mock_master = Mock()

        # Mock async hash command responses
        mock_master.hset = AsyncMock(return_value=1)
        mock_master.hget = AsyncMock(return_value="test_value")
        mock_master.hgetall = AsyncMock(
            return_value={"key1": "value1", "key2": "value2"}
        )
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)

        # Test hset
        hset_method = proxy.__getattr__("hset")
        result = await hset_method("test_hash", "field1", "value1")
        assert result == 1
        mock_master.hset.assert_called_with("test_hash", "field1", "value1")

        # Test hget
        hget_method = proxy.__getattr__("hget")
        result = await hget_method("test_hash", "field1")
        assert result == "test_value"
        mock_master.hget.assert_called_with("test_hash", "field1")

        # Test hgetall
        hgetall_method = proxy.__getattr__("hgetall")
        result = await hgetall_method("test_hash")
        assert result == {"key1": "value1", "key2": "value2"}
        mock_master.hgetall.assert_called_with("test_hash")

    @patch("redis.sentinel.Sentinel")
    def test_string_commands_sync(self, mock_sentinel_class):
        """Test Redis string commands in sync mode"""
        mock_sentinel = Mock()
        mock_master = Mock()

        # Mock string command responses
        mock_master.set.return_value = True
        mock_master.get.return_value = "test_value"
        mock_master.delete.return_value = 1
        mock_master.exists.return_value = True
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        # Test set
        set_method = proxy.__getattr__("set")
        result = set_method("test_key", "test_value")
        assert result is True
        mock_master.set.assert_called_with("test_key", "test_value")

        # Test get
        get_method = proxy.__getattr__("get")
        result = get_method("test_key")
        assert result == "test_value"
        mock_master.get.assert_called_with("test_key")

        # Test delete
        delete_method = proxy.__getattr__("delete")
        result = delete_method("test_key")
        assert result == 1
        mock_master.delete.assert_called_with("test_key")

        # Test exists
        exists_method = proxy.__getattr__("exists")
        result = exists_method("test_key")
        assert result is True
        mock_master.exists.assert_called_with("test_key")

    @patch("redis.sentinel.Sentinel")
    def test_list_commands_sync(self, mock_sentinel_class):
        """Test Redis list commands in sync mode"""
        mock_sentinel = Mock()
        mock_master = Mock()

        # Mock list command responses
        mock_master.lpush.return_value = 1
        mock_master.rpop.return_value = "test_value"
        mock_master.llen.return_value = 5
        mock_master.lrange.return_value = ["item1", "item2", "item3"]
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        # Test lpush
        lpush_method = proxy.__getattr__("lpush")
        result = lpush_method("test_list", "item1")
        assert result == 1
        mock_master.lpush.assert_called_with("test_list", "item1")

        # Test rpop
        rpop_method = proxy.__getattr__("rpop")
        result = rpop_method("test_list")
        assert result == "test_value"
        mock_master.rpop.assert_called_with("test_list")

        # Test llen
        llen_method = proxy.__getattr__("llen")
        result = llen_method("test_list")
        assert result == 5
        mock_master.llen.assert_called_with("test_list")

        # Test lrange
        lrange_method = proxy.__getattr__("lrange")
        result = lrange_method("test_list", 0, -1)
        assert result == ["item1", "item2", "item3"]
        mock_master.lrange.assert_called_with("test_list", 0, -1)

    @patch("redis.sentinel.Sentinel")
    def test_pubsub_commands_sync(self, mock_sentinel_class):
        """Test Redis pubsub commands in sync mode"""
        mock_sentinel = Mock()
        mock_master = Mock()
        mock_pubsub = Mock()

        # Mock pubsub responses
        mock_master.pubsub.return_value = mock_pubsub
        mock_master.publish.return_value = 1
        mock_pubsub.subscribe.return_value = None
        mock_pubsub.get_message.return_value = {"type": "message", "data": "test_data"}
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        # Test pubsub (factory method - should pass through)
        pubsub_method = proxy.__getattr__("pubsub")
        result = pubsub_method()
        assert result == mock_pubsub
        mock_master.pubsub.assert_called_once()

        # Test publish
        publish_method = proxy.__getattr__("publish")
        result = publish_method("test_channel", "test_message")
        assert result == 1
        mock_master.publish.assert_called_with("test_channel", "test_message")

    @patch("redis.sentinel.Sentinel")
    def test_pipeline_commands_sync(self, mock_sentinel_class):
        """Test Redis pipeline commands in sync mode"""
        mock_sentinel = Mock()
        mock_master = Mock()
        mock_pipeline = Mock()

        # Mock pipeline responses
        mock_master.pipeline.return_value = mock_pipeline
        mock_pipeline.set.return_value = mock_pipeline
        mock_pipeline.get.return_value = mock_pipeline
        mock_pipeline.execute.return_value = [True, "test_value"]
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        # Test pipeline (factory method - should pass through)
        pipeline_method = proxy.__getattr__("pipeline")
        result = pipeline_method()
        assert result == mock_pipeline
        mock_master.pipeline.assert_called_once()

    @patch("redis.sentinel.Sentinel")
    def test_commands_with_failover_retry(self, mock_sentinel_class):
        """Test Redis commands with failover retry mechanism"""
        mock_sentinel = Mock()
        mock_master = Mock()

        # First call fails with connection error, second succeeds
        mock_master.hget.side_effect = [
            redis.exceptions.ConnectionError("Connection failed"),
            "recovered_value",
        ]
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        # Test hget with retry
        hget_method = proxy.__getattr__("hget")
        result = hget_method("test_hash", "field1")

        assert result == "recovered_value"
        assert mock_master.hget.call_count == 2

        # Verify both calls were made with same parameters
        # (the retry must replay the original arguments unchanged).
        expected_calls = [(("test_hash", "field1"),), (("test_hash", "field1"),)]
        actual_calls = [call.args for call in mock_master.hget.call_args_list]
        assert actual_calls == expected_calls

    @patch("redis.sentinel.Sentinel")
    def test_commands_with_readonly_error_retry(self, mock_sentinel_class):
        """Test Redis commands with ReadOnlyError retry mechanism"""
        mock_sentinel = Mock()
        mock_master = Mock()

        # First call fails with ReadOnlyError, second succeeds
        mock_master.hset.side_effect = [
            redis.exceptions.ReadOnlyError(
                "READONLY You can't write against a read only replica"
            ),
            1,
        ]
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=False)

        # Test hset with retry
        hset_method = proxy.__getattr__("hset")
        result = hset_method("test_hash", "field1", "value1")

        assert result == 1
        assert mock_master.hset.call_count == 2

        # Verify both calls were made with same parameters
        expected_calls = [
            (("test_hash", "field1", "value1"),),
            (("test_hash", "field1", "value1"),),
        ]
        actual_calls = [call.args for call in mock_master.hset.call_args_list]
        assert actual_calls == expected_calls

    @patch("redis.sentinel.Sentinel")
    @pytest.mark.asyncio
    async def test_async_commands_with_failover_retry(self, mock_sentinel_class):
        """Test async Redis commands with failover retry mechanism"""
        mock_sentinel = Mock()
        mock_master = Mock()

        # First call fails with connection error, second succeeds
        mock_master.hget = AsyncMock(
            side_effect=[
                redis.exceptions.ConnectionError("Connection failed"),
                "recovered_value",
            ]
        )
        mock_sentinel.master_for.return_value = mock_master

        proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)

        # Test async hget with retry
        hget_method = proxy.__getattr__("hget")
        result = await hget_method("test_hash", "field1")

        assert result == "recovered_value"
        assert mock_master.hget.call_count == 2

        # Verify both calls were made with same parameters
        expected_calls = [(("test_hash", "field1"),), (("test_hash", "field1"),)]
        actual_calls = [call.args for call in mock_master.hget.call_args_list]
        assert actual_calls == expected_calls
class TestSentinelRedisProxyFactoryMethods:
"""Test Redis factory methods in async mode - these are special cases that remain sync"""
@patch("redis.sentinel.Sentinel")
@pytest.mark.asyncio
async def test_pubsub_factory_method_async(self, mock_sentinel_class):
"""Test pubsub factory method in async mode - should pass through without wrapping"""
mock_sentinel = Mock()
mock_master = Mock()
mock_pubsub = Mock()
# Mock pubsub factory method
mock_master.pubsub.return_value = mock_pubsub
mock_sentinel.master_for.return_value = mock_master
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
# Test pubsub factory method - should NOT be wrapped as async
pubsub_method = proxy.__getattr__("pubsub")
result = pubsub_method()
assert result == mock_pubsub
mock_master.pubsub.assert_called_once()
# Verify it's not wrapped as async (no await needed)
assert not inspect.iscoroutine(result)
@patch("redis.sentinel.Sentinel")
@pytest.mark.asyncio
async def test_pipeline_factory_method_async(self, mock_sentinel_class):
"""Test pipeline factory method in async mode - should pass through without wrapping"""
mock_sentinel = Mock()
mock_master = Mock()
mock_pipeline = Mock()
# Mock pipeline factory method
mock_master.pipeline.return_value = mock_pipeline
mock_pipeline.set.return_value = mock_pipeline
mock_pipeline.get.return_value = mock_pipeline
mock_pipeline.execute.return_value = [True, "test_value"]
mock_sentinel.master_for.return_value = mock_master
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
# Test pipeline factory method - should NOT be wrapped as async
pipeline_method = proxy.__getattr__("pipeline")
result = pipeline_method()
assert result == mock_pipeline
mock_master.pipeline.assert_called_once()
# Verify it's not wrapped as async (no await needed)
assert not inspect.iscoroutine(result)
# Test pipeline usage (these should also be sync)
pipeline_result = result.set("key", "value").get("key").execute()
assert pipeline_result == [True, "test_value"]
@patch("redis.sentinel.Sentinel")
@pytest.mark.asyncio
async def test_factory_methods_vs_regular_commands_async(self, mock_sentinel_class):
"""Test that factory methods behave differently from regular commands in async mode"""
mock_sentinel = Mock()
mock_master = Mock()
# Mock both factory method and regular command
mock_pubsub = Mock()
mock_master.pubsub.return_value = mock_pubsub
mock_master.get = AsyncMock(return_value="test_value")
mock_sentinel.master_for.return_value = mock_master
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
# Test factory method - should NOT be wrapped
pubsub_method = proxy.__getattr__("pubsub")
pubsub_result = pubsub_method()
# Test regular command - should be wrapped as async
get_method = proxy.__getattr__("get")
get_result = get_method("test_key")
# Factory method returns directly
assert pubsub_result == mock_pubsub
assert not inspect.iscoroutine(pubsub_result)
# Regular command returns coroutine
assert inspect.iscoroutine(get_result)
# Regular command needs await
actual_value = await get_result
assert actual_value == "test_value"
@patch("redis.sentinel.Sentinel")
@pytest.mark.asyncio
async def test_factory_methods_with_failover_async(self, mock_sentinel_class):
"""Test factory methods with failover in async mode"""
mock_sentinel = Mock()
mock_master = Mock()
# First call fails, second succeeds
mock_pubsub = Mock()
mock_master.pubsub.side_effect = [
redis.exceptions.ConnectionError("Connection failed"),
mock_pubsub,
]
mock_sentinel.master_for.return_value = mock_master
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
# Test pubsub factory method with failover
pubsub_method = proxy.__getattr__("pubsub")
result = pubsub_method()
assert result == mock_pubsub
assert mock_master.pubsub.call_count == 2 # Retry happened
# Verify it's still not wrapped as async after retry
assert not inspect.iscoroutine(result)
@patch("redis.sentinel.Sentinel")
@pytest.mark.asyncio
async def test_monitor_factory_method_async(self, mock_sentinel_class):
"""Test monitor factory method in async mode - should pass through without wrapping"""
mock_sentinel = Mock()
mock_master = Mock()
mock_monitor = Mock()
# Mock monitor factory method
mock_master.monitor.return_value = mock_monitor
mock_sentinel.master_for.return_value = mock_master
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
# Test monitor factory method - should NOT be wrapped as async
monitor_method = proxy.__getattr__("monitor")
result = monitor_method()
assert result == mock_monitor
mock_master.monitor.assert_called_once()
# Verify it's not wrapped as async (no await needed)
assert not inspect.iscoroutine(result)
@patch("redis.sentinel.Sentinel")
@pytest.mark.asyncio
async def test_client_factory_method_async(self, mock_sentinel_class):
"""Test client factory method in async mode - should pass through without wrapping"""
mock_sentinel = Mock()
mock_master = Mock()
mock_client = Mock()
# Mock client factory method
mock_master.client.return_value = mock_client
mock_sentinel.master_for.return_value = mock_master
proxy = SentinelRedisProxy(mock_sentinel, "mymaster", async_mode=True)
# Test client factory method - should NOT be wrapped as async
client_method = proxy.__getattr__("client")
result = client_method()
assert result == mock_client
mock_master.client.assert_called_once()
# Verify it's not wrapped as async (no await needed)
assert not inspect.iscoroutine(result)
@patch("redis.sentinel.Sentinel")
@pytest.mark.asyncio
async def test_transaction_factory_method_async(self, mock_sentinel_class):
    """Transaction factory method in async mode must be handed back as-is, never wrapped in a coroutine."""
    sentinel_stub = Mock()
    master_stub = Mock()
    transaction_stub = Mock()

    # Wire up the chain: sentinel -> master -> transaction()
    master_stub.transaction.return_value = transaction_stub
    sentinel_stub.master_for.return_value = master_stub

    proxy = SentinelRedisProxy(sentinel_stub, "mymaster", async_mode=True)

    # Fetching "transaction" through the proxy should expose the factory directly.
    factory = proxy.__getattr__("transaction")
    produced = factory()

    assert produced == transaction_stub
    master_stub.transaction.assert_called_once()
    # The produced object must be usable without awaiting.
    assert not inspect.iscoroutine(produced)
@patch("redis.sentinel.Sentinel")
@pytest.mark.asyncio
async def test_all_factory_methods_async(self, mock_sentinel_class):
    """Comprehensive sweep: every factory method bypasses async wrapping."""
    sentinel_stub = Mock()
    master_stub = Mock()

    # One sentinel object expected back from each factory.
    factory_names = ("pipeline", "pubsub", "monitor", "client", "transaction")
    expected = {name: Mock() for name in factory_names}
    for name in factory_names:
        setattr(master_stub, name, Mock(return_value=expected[name]))
    sentinel_stub.master_for.return_value = master_stub

    proxy = SentinelRedisProxy(sentinel_stub, "mymaster", async_mode=True)

    for name in factory_names:
        produced = proxy.__getattr__(name)()
        assert produced == expected[name]
        # Factory results are plain objects, not awaitables.
        assert not inspect.iscoroutine(produced)
        factory_mock = getattr(master_stub, name)
        factory_mock.assert_called_once()
        # Clean slate before checking the next factory.
        factory_mock.reset_mock()
@patch("redis.sentinel.Sentinel")
@pytest.mark.asyncio
async def test_mixed_factory_and_regular_commands_async(self, mock_sentinel_class):
    """Factory methods stay synchronous while regular commands are awaited, side by side."""
    sentinel_stub = Mock()
    master_stub = Mock()

    # Pipeline factory: chainable set/get, execute() yields the batched results.
    pipeline_stub = Mock()
    master_stub.pipeline.return_value = pipeline_stub
    pipeline_stub.set.return_value = pipeline_stub
    pipeline_stub.get.return_value = pipeline_stub
    pipeline_stub.execute.return_value = [True, "pipeline_value"]

    # Regular command: async on the master client.
    master_stub.get = AsyncMock(return_value="regular_value")
    sentinel_stub.master_for.return_value = master_stub

    proxy = SentinelRedisProxy(sentinel_stub, "mymaster", async_mode=True)

    # Factory path (synchronous, chained).
    batched = (
        proxy.__getattr__("pipeline")().set("key1", "value1").get("key1").execute()
    )

    # Regular command path (awaited).
    fetched = await proxy.__getattr__("get")("key2")

    # Both flavors must work through the same proxy.
    assert batched == [True, "pipeline_value"]
    assert fetched == "regular_value"
    master_stub.pipeline.assert_called_once()
    master_stub.get.assert_called_with("key2")

View File

@ -8,6 +8,12 @@ import requests
import os import os
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.primitives import serialization
import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
import pytz import pytz
from pytz import UTC from pytz import UTC
@ -18,7 +24,11 @@ from opentelemetry import trace
from open_webui.models.users import Users from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ( from open_webui.env import (
OFFLINE_MODE,
LICENSE_BLOB,
pk,
WEBUI_SECRET_KEY, WEBUI_SECRET_KEY,
TRUSTED_SIGNATURE_KEY, TRUSTED_SIGNATURE_KEY,
STATIC_DIR, STATIC_DIR,
@ -74,17 +84,8 @@ def override_static(path: str, content: str):
def get_license_data(app, key): def get_license_data(app, key):
if key: def data_handler(data):
try: for k, v in data.items():
res = requests.post(
"https://api.openwebui.com/api/v1/license/",
json={"key": key, "version": "1"},
timeout=5,
)
if getattr(res, "ok", False):
payload = getattr(res, "json", lambda: {})()
for k, v in payload.items():
if k == "resources": if k == "resources":
for p, c in v.items(): for p, c in v.items():
globals().get("override_static", lambda a, b: None)(p, c) globals().get("override_static", lambda a, b: None)(p, c)
@ -94,13 +95,62 @@ def get_license_data(app, key):
setattr(app.state, "WEBUI_NAME", v) setattr(app.state, "WEBUI_NAME", v)
elif k == "metadata": elif k == "metadata":
setattr(app.state, "LICENSE_METADATA", v) setattr(app.state, "LICENSE_METADATA", v)
def handler(u):
res = requests.post(
f"{u}/api/v1/license/",
json={"key": key, "version": "1"},
timeout=5,
)
if getattr(res, "ok", False):
payload = getattr(res, "json", lambda: {})()
data_handler(payload)
return True return True
else: else:
log.error( log.error(
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}" f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
) )
if key:
us = [
"https://api.openwebui.com",
"https://licenses.api.openwebui.com",
]
try:
for u in us:
if handler(u):
return True
except Exception as ex: except Exception as ex:
log.exception(f"License: Uncaught Exception: {ex}") log.exception(f"License: Uncaught Exception: {ex}")
try:
if LICENSE_BLOB:
nl = 12
kb = hashlib.sha256((key.replace("-", "").upper()).encode()).digest()
def nt(b):
return b[:nl], b[nl:]
lb = base64.b64decode(LICENSE_BLOB)
ln, lt = nt(lb)
aesgcm = AESGCM(kb)
p = json.loads(aesgcm.decrypt(ln, lt, None))
pk.verify(base64.b64decode(p["s"]), p["p"].encode())
pb = base64.b64decode(p["p"])
pn, pt = nt(pb)
data = json.loads(aesgcm.decrypt(pn, pt, None).decode())
if not data.get("exp") and data.get("exp") < datetime.now().date():
return False
data_handler(data)
return True
except Exception as e:
log.error(f"License: {e}")
return False return False

View File

@ -419,7 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
params[key] = value params[key] = value
if "__user__" in sig.parameters: if "__user__" in sig.parameters:
__user__ = (user.model_dump() if isinstance(user, UserModel) else {},) __user__ = user.model_dump() if isinstance(user, UserModel) else {}
try: try:
if hasattr(function_module, "UserValves"): if hasattr(function_module, "UserValves"):

View File

@ -4,12 +4,16 @@ import sys
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from loguru import logger from loguru import logger
from opentelemetry import trace
from open_webui.env import ( from open_webui.env import (
AUDIT_UVICORN_LOGGER_NAMES,
AUDIT_LOG_FILE_ROTATION_SIZE, AUDIT_LOG_FILE_ROTATION_SIZE,
AUDIT_LOG_LEVEL, AUDIT_LOG_LEVEL,
AUDIT_LOGS_FILE_PATH, AUDIT_LOGS_FILE_PATH,
GLOBAL_LOG_LEVEL, GLOBAL_LOG_LEVEL,
ENABLE_OTEL,
) )
@ -58,9 +62,20 @@ class InterceptHandler(logging.Handler):
frame = frame.f_back frame = frame.f_back
depth += 1 depth += 1
logger.opt(depth=depth, exception=record.exc_info).log( logger.opt(depth=depth, exception=record.exc_info).bind(
level, record.getMessage() **self._get_extras()
) ).log(level, record.getMessage())
def _get_extras(self):
if not ENABLE_OTEL:
return {}
extras = {}
context = trace.get_current_span().get_span_context()
if context.is_valid:
extras["trace_id"] = trace.format_trace_id(context.trace_id)
extras["span_id"] = trace.format_span_id(context.span_id)
return extras
def file_format(record: "Record"): def file_format(record: "Record"):
@ -128,11 +143,13 @@ def start_logger():
logging.basicConfig( logging.basicConfig(
handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
) )
for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]: for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name) uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL) uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = [] uvicorn_logger.handlers = []
for uvicorn_logger_name in ["uvicorn.access"]:
for uvicorn_logger_name in AUDIT_UVICORN_LOGGER_NAMES:
uvicorn_logger = logging.getLogger(uvicorn_logger_name) uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL) uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = [InterceptHandler()] uvicorn_logger.handlers = [InterceptHandler()]

View File

@ -23,6 +23,7 @@ from starlette.responses import Response, StreamingResponse
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.models.folders import Folders
from open_webui.models.users import Users from open_webui.models.users import Users
from open_webui.socket.main import ( from open_webui.socket.main import (
get_event_call, get_event_call,
@ -56,7 +57,7 @@ from open_webui.models.users import UserModel
from open_webui.models.functions import Functions from open_webui.models.functions import Functions
from open_webui.models.models import Models from open_webui.models.models import Models
from open_webui.retrieval.utils import get_sources_from_files from open_webui.retrieval.utils import get_sources_from_items
from open_webui.utils.chat import generate_chat_completion from open_webui.utils.chat import generate_chat_completion
@ -82,6 +83,7 @@ from open_webui.utils.filter import (
process_filter_functions, process_filter_functions,
) )
from open_webui.utils.code_interpreter import execute_code_jupyter from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.utils.payload import apply_model_system_prompt_to_body
from open_webui.tasks import create_task from open_webui.tasks import create_task
@ -248,9 +250,7 @@ async def chat_completion_tools_handler(
if tool_id if tool_id
else f"{tool_function_name}" else f"{tool_function_name}"
) )
if tool.get("metadata", {}).get("citation", False) or tool.get(
"direct", False
):
# Citation is enabled for this tool # Citation is enabled for this tool
sources.append( sources.append(
{ {
@ -264,9 +264,9 @@ async def chat_completion_tools_handler(
"parameters": tool_function_params, "parameters": tool_function_params,
} }
], ],
"tool_result": True,
} }
) )
else:
# Citation is not enabled for this tool # Citation is not enabled for this tool
body["messages"] = add_or_update_user_message( body["messages"] = add_or_update_user_message(
f"\nTool `{tool_name}` Output: {tool_result}", f"\nTool `{tool_name}` Output: {tool_result}",
@ -640,25 +640,34 @@ async def chat_completion_files_handler(
queries = [get_last_user_message(body["messages"])] queries = [get_last_user_message(body["messages"])]
try: try:
# Offload get_sources_from_files to a separate thread # Offload get_sources_from_items to a separate thread
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
sources = await loop.run_in_executor( sources = await loop.run_in_executor(
executor, executor,
lambda: get_sources_from_files( lambda: get_sources_from_items(
request=request, request=request,
files=files, items=files,
queries=queries, queries=queries,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, prefix=prefix, user=user query, prefix=prefix, user=user
), ),
k=request.app.state.config.TOP_K, k=request.app.state.config.TOP_K,
reranking_function=request.app.state.rf, reranking_function=(
(
lambda sentences: request.app.state.RERANKING_FUNCTION(
sentences, user=user
)
)
if request.app.state.RERANKING_FUNCTION
else None
),
k_reranker=request.app.state.config.TOP_K_RERANKER, k_reranker=request.app.state.config.TOP_K_RERANKER,
r=request.app.state.config.RELEVANCE_THRESHOLD, r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT, hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT, full_context=request.app.state.config.RAG_FULL_CONTEXT,
user=user,
), ),
) )
except Exception as e: except Exception as e:
@ -718,6 +727,10 @@ def apply_params_to_form_data(form_data, model):
async def process_chat_payload(request, form_data, user, metadata, model): async def process_chat_payload(request, form_data, user, metadata, model):
# Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
# -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
# -> Chat Files
form_data = apply_params_to_form_data(form_data, model) form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}") log.debug(f"form_data: {form_data}")
@ -752,6 +765,29 @@ async def process_chat_payload(request, form_data, user, metadata, model):
events = [] events = []
sources = [] sources = []
# Folder "Project" handling
# Check if the request has chat_id and is inside of a folder
chat_id = metadata.get("chat_id", None)
if chat_id and user:
chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
if chat and chat.folder_id:
folder = Folders.get_folder_by_id_and_user_id(chat.folder_id, user.id)
if folder and folder.data:
if "system_prompt" in folder.data:
form_data = apply_model_system_prompt_to_body(
folder.data["system_prompt"],
form_data,
metadata,
user
)
if "files" in folder.data:
form_data["files"] = [
*folder.data["files"],
*form_data.get("files", []),
]
# Model "Knowledge" handling
user_message = get_last_user_message(form_data["messages"]) user_message = get_last_user_message(form_data["messages"])
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False) model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)
@ -804,7 +840,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
raise e raise e
try: try:
filter_functions = [ filter_functions = [
Functions.get_function_by_id(filter_id) Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids( for filter_id in get_sorted_filter_ids(
@ -912,7 +947,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
request, form_data, extra_params, user, models, tools_dict request, form_data, extra_params, user, models, tools_dict
) )
sources.extend(flags.get("sources", [])) sources.extend(flags.get("sources", []))
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -925,39 +959,43 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# If context is not empty, insert it into the messages # If context is not empty, insert it into the messages
if len(sources) > 0: if len(sources) > 0:
context_string = "" context_string = ""
citation_idx = {} citation_idx_map = {}
for source in sources: for source in sources:
if "document" in source: is_tool_result = source.get("tool_result", False)
for doc_context, doc_meta in zip(
if "document" in source and not is_tool_result:
for document_text, document_metadata in zip(
source["document"], source["metadata"] source["document"], source["metadata"]
): ):
source_name = source.get("source", {}).get("name", None) source_name = source.get("source", {}).get("name", None)
citation_id = ( source_id = (
doc_meta.get("source", None) document_metadata.get("source", None)
or source.get("source", {}).get("id", None) or source.get("source", {}).get("id", None)
or "N/A" or "N/A"
) )
if citation_id not in citation_idx:
citation_idx[citation_id] = len(citation_idx) + 1 if source_id not in citation_idx_map:
citation_idx_map[source_id] = len(citation_idx_map) + 1
context_string += ( context_string += (
f'<source id="{citation_idx[citation_id]}"' f'<source id="{citation_idx_map[source_id]}"'
+ (f' name="{source_name}"' if source_name else "") + (f' name="{source_name}"' if source_name else "")
+ f">{doc_context}</source>\n" + f">{document_text}</source>\n"
) )
context_string = context_string.strip() context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
prompt = get_last_user_message(form_data["messages"])
if prompt is None: if prompt is None:
raise Exception("No user message found") raise Exception("No user message found")
if (
request.app.state.config.RELEVANCE_THRESHOLD == 0 if context_string == "":
and context_string.strip() == "" if request.app.state.config.RELEVANCE_THRESHOLD == 0:
):
log.debug( log.debug(
f"With a 0 relevancy threshold for RAG, the context cannot be empty" f"With a 0 relevancy threshold for RAG, the context cannot be empty"
) )
else:
# Workaround for Ollama 2.0+ system prompt issue # Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message # TODO: replace with add_or_update_system_message
if model.get("owned_by") == "ollama": if model.get("owned_by") == "ollama":
@ -1347,14 +1385,6 @@ async def process_chat_response(
task_id = str(uuid4()) # Create a unique task ID. task_id = str(uuid4()) # Create a unique task ID.
model_id = form_data.get("model", "") model_id = form_data.get("model", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"model": model_id,
},
)
def split_content_and_whitespace(content): def split_content_and_whitespace(content):
content_stripped = content.rstrip() content_stripped = content.rstrip()
original_whitespace = ( original_whitespace = (
@ -1370,7 +1400,7 @@ async def process_chat_response(
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
# Handle as a background task # Handle as a background task
async def post_response_handler(response, events): async def response_handler(response, events):
def serialize_content_blocks(content_blocks, raw=False): def serialize_content_blocks(content_blocks, raw=False):
content = "" content = ""
@ -1405,7 +1435,7 @@ async def process_chat_response(
break break
if tool_result: if tool_result:
tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result))}" files="{html.escape(json.dumps(tool_result_files)) if tool_result_files else ""}">\n<summary>Tool Executed</summary>\n</details>\n' tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result, ensure_ascii=False))}" files="{html.escape(json.dumps(tool_result_files)) if tool_result_files else ""}">\n<summary>Tool Executed</summary>\n</details>\n'
else: else:
tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>' tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>'
@ -1438,12 +1468,12 @@ async def process_chat_response(
if reasoning_duration is not None: if reasoning_duration is not None:
if raw: if raw:
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' content = f'{content}\n{block["start_tag"]}{block["content"]}{block["end_tag"]}\n'
else: else:
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n' content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
else: else:
if raw: if raw:
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' content = f'{content}\n{block["start_tag"]}{block["content"]}{block["end_tag"]}\n'
else: else:
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n' content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
@ -1540,8 +1570,16 @@ async def process_chat_response(
if content_blocks[-1]["type"] == "text": if content_blocks[-1]["type"] == "text":
for start_tag, end_tag in tags: for start_tag, end_tag in tags:
start_tag_pattern = rf"{re.escape(start_tag)}"
if start_tag.startswith("<") and start_tag.endswith(">"):
# Match start tag e.g., <tag> or <tag attr="value"> # Match start tag e.g., <tag> or <tag attr="value">
start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>" # remove both '<' and '>' from start_tag
# Match start tag with attributes
start_tag_pattern = (
rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>"
)
match = re.search(start_tag_pattern, content) match = re.search(start_tag_pattern, content)
if match: if match:
attr_content = ( attr_content = (
@ -1592,8 +1630,13 @@ async def process_chat_response(
elif content_blocks[-1]["type"] == content_type: elif content_blocks[-1]["type"] == content_type:
start_tag = content_blocks[-1]["start_tag"] start_tag = content_blocks[-1]["start_tag"]
end_tag = content_blocks[-1]["end_tag"] end_tag = content_blocks[-1]["end_tag"]
if end_tag.startswith("<") and end_tag.endswith(">"):
# Match end tag e.g., </tag> # Match end tag e.g., </tag>
end_tag_pattern = rf"<{re.escape(end_tag)}>" end_tag_pattern = rf"{re.escape(end_tag)}"
else:
# Handle cases where end_tag is just a tag name
end_tag_pattern = rf"{re.escape(end_tag)}"
# Check if the content has the end tag # Check if the content has the end tag
if re.search(end_tag_pattern, content): if re.search(end_tag_pattern, content):
@ -1665,8 +1708,17 @@ async def process_chat_response(
) )
# Clean processed content # Clean processed content
start_tag_pattern = rf"{re.escape(start_tag)}"
if start_tag.startswith("<") and start_tag.endswith(">"):
# Match start tag e.g., <tag> or <tag attr="value">
# remove both '<' and '>' from start_tag
# Match start tag with attributes
start_tag_pattern = (
rf"<{re.escape(start_tag[1:-1])}(\s.*?)?>"
)
content = re.sub( content = re.sub(
rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>", rf"{start_tag_pattern}(.|\n)*?{re.escape(end_tag)}",
"", "",
content, content,
flags=re.DOTALL, flags=re.DOTALL,
@ -1710,18 +1762,19 @@ async def process_chat_response(
) )
reasoning_tags = [ reasoning_tags = [
("think", "/think"), ("<think>", "</think>"),
("thinking", "/thinking"), ("<thinking>", "</thinking>"),
("reason", "/reason"), ("<reason>", "</reason>"),
("reasoning", "/reasoning"), ("<reasoning>", "</reasoning>"),
("thought", "/thought"), ("<thought>", "</thought>"),
("Thought", "/Thought"), ("<Thought>", "</Thought>"),
("|begin_of_thought|", "|end_of_thought|"), ("<|begin_of_thought|>", "<|end_of_thought|>"),
("◁think▷", "◁/think▷"),
] ]
code_interpreter_tags = [("code_interpreter", "/code_interpreter")] code_interpreter_tags = [("<code_interpreter>", "</code_interpreter>")]
solution_tags = [("|begin_of_solution|", "|end_of_solution|")] solution_tags = [("<|begin_of_solution|>", "<|end_of_solution|>")]
try: try:
for event in events: for event in events:
@ -1741,7 +1794,7 @@ async def process_chat_response(
}, },
) )
async def stream_body_handler(response): async def stream_body_handler(response, form_data):
nonlocal content nonlocal content
nonlocal content_blocks nonlocal content_blocks
@ -1770,7 +1823,7 @@ async def process_chat_response(
filter_functions=filter_functions, filter_functions=filter_functions,
filter_type="stream", filter_type="stream",
form_data=data, form_data=data,
extra_params=extra_params, extra_params={"__body__": form_data, **extra_params},
) )
if data: if data:
@ -2005,7 +2058,7 @@ async def process_chat_response(
if done: if done:
pass pass
else: else:
log.debug("Error: ", e) log.debug(f"Error: {e}")
continue continue
if content_blocks: if content_blocks:
@ -2032,7 +2085,7 @@ async def process_chat_response(
if response.background: if response.background:
await response.background() await response.background()
await stream_body_handler(response) await stream_body_handler(response, form_data)
MAX_TOOL_CALL_RETRIES = 10 MAX_TOOL_CALL_RETRIES = 10
tool_call_retries = 0 tool_call_retries = 0
@ -2148,7 +2201,9 @@ async def process_chat_response(
if isinstance(tool_result, dict) or isinstance( if isinstance(tool_result, dict) or isinstance(
tool_result, list tool_result, list
): ):
tool_result = json.dumps(tool_result, indent=2) tool_result = json.dumps(
tool_result, indent=2, ensure_ascii=False
)
results.append( results.append(
{ {
@ -2181,9 +2236,7 @@ async def process_chat_response(
) )
try: try:
res = await generate_chat_completion( new_form_data = {
request,
{
"model": model_id, "model": model_id,
"stream": True, "stream": True,
"tools": form_data["tools"], "tools": form_data["tools"],
@ -2191,12 +2244,16 @@ async def process_chat_response(
*form_data["messages"], *form_data["messages"],
*convert_content_blocks_to_messages(content_blocks), *convert_content_blocks_to_messages(content_blocks),
], ],
}, }
res = await generate_chat_completion(
request,
new_form_data,
user, user,
) )
if isinstance(res, StreamingResponse): if isinstance(res, StreamingResponse):
await stream_body_handler(res) await stream_body_handler(res, new_form_data)
else: else:
break break
except Exception as e: except Exception as e:
@ -2211,6 +2268,7 @@ async def process_chat_response(
content_blocks[-1]["type"] == "code_interpreter" content_blocks[-1]["type"] == "code_interpreter"
and retries < MAX_RETRIES and retries < MAX_RETRIES
): ):
await event_emitter( await event_emitter(
{ {
"type": "chat:completion", "type": "chat:completion",
@ -2343,9 +2401,7 @@ async def process_chat_response(
) )
try: try:
res = await generate_chat_completion( new_form_data = {
request,
{
"model": model_id, "model": model_id,
"stream": True, "stream": True,
"messages": [ "messages": [
@ -2357,12 +2413,16 @@ async def process_chat_response(
), ),
}, },
], ],
}, }
res = await generate_chat_completion(
request,
new_form_data,
user, user,
) )
if isinstance(res, StreamingResponse): if isinstance(res, StreamingResponse):
await stream_body_handler(res) await stream_body_handler(res, new_form_data)
else: else:
break break
except Exception as e: except Exception as e:
@ -2427,9 +2487,11 @@ async def process_chat_response(
if response.background is not None: if response.background is not None:
await response.background() await response.background()
# background_tasks.add_task(post_response_handler, response, events) # background_tasks.add_task(response_handler, response, events)
task_id, _ = await create_task( task_id, _ = await create_task(
request, post_response_handler(response, events), id=metadata["chat_id"] request.app.state.redis,
response_handler(response, events),
id=metadata["chat_id"],
) )
return {"status": True, "task_id": task_id} return {"status": True, "task_id": task_id}

View File

@ -76,8 +76,19 @@ async def get_all_base_models(request: Request, user: UserModel = None):
return function_models + openai_models + ollama_models return function_models + openai_models + ollama_models
async def get_all_models(request, user: UserModel = None): async def get_all_models(request, refresh: bool = False, user: UserModel = None):
models = await get_all_base_models(request, user=user) if (
request.app.state.MODELS
and request.app.state.BASE_MODELS
and (request.app.state.config.ENABLE_BASE_MODELS_CACHE and not refresh)
):
base_models = request.app.state.BASE_MODELS
else:
base_models = await get_all_base_models(request, user=user)
request.app.state.BASE_MODELS = base_models
# deep copy the base models to avoid modifying the original list
models = [model.copy() for model in base_models]
# If there are no models, return an empty list # If there are no models, return an empty list
if len(models) == 0: if len(models) == 0:
@ -137,6 +148,7 @@ async def get_all_models(request, user: UserModel = None):
custom_models = Models.get_all_models() custom_models = Models.get_all_models()
for custom_model in custom_models: for custom_model in custom_models:
if custom_model.base_model_id is None: if custom_model.base_model_id is None:
# Applied directly to a base model
for model in models: for model in models:
if custom_model.id == model["id"] or ( if custom_model.id == model["id"] or (
model.get("owned_by") == "ollama" model.get("owned_by") == "ollama"

View File

@ -1,6 +1,97 @@
import socketio import inspect
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Optional
import logging
import redis
from open_webui.env import REDIS_SENTINEL_MAX_RETRY_COUNT
log = logging.getLogger(__name__)
_CONNECTION_CACHE = {}
class SentinelRedisProxy:
    """Proxy that routes every Redis command to the current Sentinel master.

    Each call resolves the master freshly via ``master_for`` so a fail-over
    is picked up automatically without recreating the proxy.  Command calls
    that hit a stale or read-only master are retried up to
    ``REDIS_SENTINEL_MAX_RETRY_COUNT`` times.

    Factory methods (``pipeline``, ``pubsub``, ``monitor``, ``client``,
    ``transaction``) return helper objects rather than executing a command,
    so they are passed through untouched and never wrapped in a coroutine.
    """

    # Methods that return helper objects instead of executing a command;
    # these must be handed back synchronously and unwrapped.
    _FACTORY_METHODS = frozenset(
        {"pipeline", "pubsub", "monitor", "client", "transaction"}
    )

    def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
        # sentinel: a redis Sentinel instance; service: the master set name.
        self._sentinel = sentinel
        self._service = service
        self._kw = kw
        self._async_mode = async_mode

    def _master(self):
        # Ask Sentinel for the current master on every call so fail-overs
        # are picked up transparently.
        return self._sentinel.master_for(self._service, **self._kw)

    @staticmethod
    def _handle_retryable(exc, attempt):
        """Log a fail-over retry, or re-raise once attempts are exhausted.

        Must be called from inside an ``except`` block: the bare ``raise``
        re-raises the active exception with its original traceback (the
        original ``raise e from e`` self-cause was a no-op).
        """
        if attempt < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
            log.debug(
                "Redis sentinel fail-over (%s). Retry %s/%s",
                type(exc).__name__,
                attempt + 1,
                REDIS_SENTINEL_MAX_RETRY_COUNT,
            )
            return
        log.error(
            "Redis operation failed after %s retries: %s",
            REDIS_SENTINEL_MAX_RETRY_COUNT,
            exc,
        )
        raise

    def __getattr__(self, item):
        master = self._master()
        orig_attr = getattr(master, item)

        # Plain attributes (non-callables) pass straight through.
        if not callable(orig_attr):
            return orig_attr

        # Factory methods must not be wrapped: callers use the returned
        # helper object directly and synchronously.
        if item in self._FACTORY_METHODS:
            return orig_attr

        if self._async_mode:

            async def _wrapped(*args, **kwargs):
                # Retry on connection/read-only errors caused by fail-over.
                for attempt in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
                    try:
                        method = getattr(self._master(), item)
                        result = method(*args, **kwargs)
                        # Async clients return coroutines; sync results
                        # (e.g. from mocks) are returned as-is.
                        if inspect.iscoroutine(result):
                            return await result
                        return result
                    except (
                        redis.exceptions.ConnectionError,
                        redis.exceptions.ReadOnlyError,
                    ) as e:
                        self._handle_retryable(e, attempt)

        else:

            def _wrapped(*args, **kwargs):
                # Retry on connection/read-only errors caused by fail-over.
                for attempt in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
                    try:
                        method = getattr(self._master(), item)
                        return method(*args, **kwargs)
                    except (
                        redis.exceptions.ConnectionError,
                        redis.exceptions.ReadOnlyError,
                    ) as e:
                        self._handle_retryable(e, attempt)

        return _wrapped
def parse_redis_service_url(redis_url): def parse_redis_service_url(redis_url):
@ -20,6 +111,14 @@ def parse_redis_service_url(redis_url):
def get_redis_connection( def get_redis_connection(
redis_url, redis_sentinels, async_mode=False, decode_responses=True redis_url, redis_sentinels, async_mode=False, decode_responses=True
): ):
cache_key = (redis_url, tuple(redis_sentinels) if redis_sentinels else (), async_mode, decode_responses)
if cache_key in _CONNECTION_CACHE:
return _CONNECTION_CACHE[cache_key]
connection = None
if async_mode: if async_mode:
import redis.asyncio as redis import redis.asyncio as redis
@ -34,11 +133,13 @@ def get_redis_connection(
password=redis_config["password"], password=redis_config["password"],
decode_responses=decode_responses, decode_responses=decode_responses,
) )
return sentinel.master_for(redis_config["service"]) connection = SentinelRedisProxy(
sentinel,
redis_config["service"],
async_mode=async_mode,
)
elif redis_url: elif redis_url:
return redis.from_url(redis_url, decode_responses=decode_responses) connection = redis.from_url(redis_url, decode_responses=decode_responses)
else:
return None
else: else:
import redis import redis
@ -52,11 +153,16 @@ def get_redis_connection(
password=redis_config["password"], password=redis_config["password"],
decode_responses=decode_responses, decode_responses=decode_responses,
) )
return sentinel.master_for(redis_config["service"]) connection = SentinelRedisProxy(
sentinel,
redis_config["service"],
async_mode=async_mode,
)
elif redis_url: elif redis_url:
return redis.Redis.from_url(redis_url, decode_responses=decode_responses) connection = redis.Redis.from_url(redis_url, decode_responses=decode_responses)
else:
return None _CONNECTION_CACHE[cache_key] = connection
return connection
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env): def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):

View File

@ -6,18 +6,17 @@ from open_webui.utils.misc import (
) )
def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict: def convert_ollama_tool_call_to_openai(tool_calls: list) -> list:
openai_tool_calls = [] openai_tool_calls = []
for tool_call in tool_calls: for tool_call in tool_calls:
function = tool_call.get("function", {})
openai_tool_call = { openai_tool_call = {
"index": tool_call.get("index", 0), "index": tool_call.get("index", function.get("index", 0)),
"id": tool_call.get("id", f"call_{str(uuid4())}"), "id": tool_call.get("id", f"call_{str(uuid4())}"),
"type": "function", "type": "function",
"function": { "function": {
"name": tool_call.get("function", {}).get("name", ""), "name": function.get("name", ""),
"arguments": json.dumps( "arguments": json.dumps(function.get("arguments", {})),
tool_call.get("function", {}).get("arguments", {})
),
}, },
} }
openai_tool_calls.append(openai_tool_call) openai_tool_calls.append(openai_tool_call)

View File

@ -1,31 +0,0 @@
import threading
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import BatchSpanProcessor
class LazyBatchSpanProcessor(BatchSpanProcessor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.done = True
with self.condition:
self.condition.notify_all()
self.worker_thread.join()
self.done = False
self.worker_thread = None
def on_end(self, span: ReadableSpan) -> None:
if self.worker_thread is None:
self.worker_thread = threading.Thread(
name=self.__class__.__name__, target=self.worker, daemon=True
)
self.worker_thread.start()
super().on_end(span)
def shutdown(self) -> None:
self.done = True
with self.condition:
self.condition.notify_all()
if self.worker_thread:
self.worker_thread.join()
self.span_exporter.shutdown()

View File

@ -34,6 +34,8 @@ from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from open_webui.env import OTEL_SERVICE_NAME, OTEL_EXPORTER_OTLP_ENDPOINT from open_webui.env import OTEL_SERVICE_NAME, OTEL_EXPORTER_OTLP_ENDPOINT
from open_webui.socket.main import get_active_user_ids
from open_webui.models.users import Users
_EXPORT_INTERVAL_MILLIS = 10_000 # 10 seconds _EXPORT_INTERVAL_MILLIS = 10_000 # 10 seconds
@ -59,6 +61,12 @@ def _build_meter_provider() -> MeterProvider:
instrument_name="http.server.requests", instrument_name="http.server.requests",
attribute_keys=["http.method", "http.route", "http.status_code"], attribute_keys=["http.method", "http.route", "http.status_code"],
), ),
View(
instrument_name="webui.users.total",
),
View(
instrument_name="webui.users.active",
),
] ]
provider = MeterProvider( provider = MeterProvider(
@ -87,6 +95,38 @@ def setup_metrics(app: FastAPI) -> None:
unit="ms", unit="ms",
) )
def observe_active_users(
options: metrics.CallbackOptions,
) -> Sequence[metrics.Observation]:
return [
metrics.Observation(
value=len(get_active_user_ids()),
)
]
def observe_total_registered_users(
options: metrics.CallbackOptions,
) -> Sequence[metrics.Observation]:
return [
metrics.Observation(
value=len(Users.get_users()["users"]),
)
]
meter.create_observable_gauge(
name="webui.users.total",
description="Total number of registered users",
unit="users",
callbacks=[observe_total_registered_users],
)
meter.create_observable_gauge(
name="webui.users.active",
description="Number of currently active users",
unit="users",
callbacks=[observe_active_users],
)
# FastAPI middleware # FastAPI middleware
@app.middleware("http") @app.middleware("http")
async def _metrics_middleware(request: Request, call_next): async def _metrics_middleware(request: Request, call_next):

View File

@ -1,17 +1,26 @@
from fastapi import FastAPI from fastapi import FastAPI
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as HttpOTLPSpanExporter,
)
from opentelemetry.sdk.resources import SERVICE_NAME, Resource from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from sqlalchemy import Engine from sqlalchemy import Engine
from base64 import b64encode
from open_webui.utils.telemetry.exporters import LazyBatchSpanProcessor
from open_webui.utils.telemetry.instrumentors import Instrumentor from open_webui.utils.telemetry.instrumentors import Instrumentor
from open_webui.utils.telemetry.metrics import setup_metrics from open_webui.utils.telemetry.metrics import setup_metrics
from open_webui.env import ( from open_webui.env import (
OTEL_SERVICE_NAME, OTEL_SERVICE_NAME,
OTEL_EXPORTER_OTLP_ENDPOINT, OTEL_EXPORTER_OTLP_ENDPOINT,
OTEL_EXPORTER_OTLP_INSECURE,
ENABLE_OTEL_METRICS, ENABLE_OTEL_METRICS,
OTEL_BASIC_AUTH_USERNAME,
OTEL_BASIC_AUTH_PASSWORD,
OTEL_OTLP_SPAN_EXPORTER,
) )
@ -22,9 +31,27 @@ def setup(app: FastAPI, db_engine: Engine):
resource=Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME}) resource=Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
) )
) )
# Add basic auth header only if both username and password are not empty
headers = []
if OTEL_BASIC_AUTH_USERNAME and OTEL_BASIC_AUTH_PASSWORD:
auth_string = f"{OTEL_BASIC_AUTH_USERNAME}:{OTEL_BASIC_AUTH_PASSWORD}"
auth_header = b64encode(auth_string.encode()).decode()
headers = [("authorization", f"Basic {auth_header}")]
# otlp export # otlp export
exporter = OTLPSpanExporter(endpoint=OTEL_EXPORTER_OTLP_ENDPOINT) if OTEL_OTLP_SPAN_EXPORTER == "http":
trace.get_tracer_provider().add_span_processor(LazyBatchSpanProcessor(exporter)) exporter = HttpOTLPSpanExporter(
endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
headers=headers,
)
else:
exporter = OTLPSpanExporter(
endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
insecure=OTEL_EXPORTER_OTLP_INSECURE,
headers=headers,
)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(exporter))
Instrumentor(app=app, db_engine=db_engine).instrument() Instrumentor(app=app, db_engine=db_engine).instrument()
# set up metrics only if enabled # set up metrics only if enabled

View File

@ -101,9 +101,6 @@ def get_tools(
def make_tool_function(function_name, token, tool_server_data): def make_tool_function(function_name, token, tool_server_data):
async def tool_function(**kwargs): async def tool_function(**kwargs):
print(
f"Executing tool function {function_name} with params: {kwargs}"
)
return await execute_tool_server( return await execute_tool_server(
token=token, token=token,
url=tool_server_data["url"], url=tool_server_data["url"],
@ -492,15 +489,7 @@ async def get_tool_servers_data(
if server.get("config", {}).get("enable"): if server.get("config", {}).get("enable"):
# Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
openapi_path = server.get("path", "openapi.json") openapi_path = server.get("path", "openapi.json")
if "://" in openapi_path: full_url = get_tool_server_url(server.get("url"), openapi_path)
# If it contains "://", it's a full URL
full_url = openapi_path
else:
if not openapi_path.startswith("/"):
# Ensure the path starts with a slash
openapi_path = f"/{openapi_path}"
full_url = f"{server.get('url')}{openapi_path}"
info = server.get("info", {}) info = server.get("info", {})
@ -646,3 +635,16 @@ async def execute_tool_server(
error = str(err) error = str(err)
log.exception(f"API Request Error: {error}") log.exception(f"API Request Error: {error}")
return {"error": error} return {"error": error}
def get_tool_server_url(url: Optional[str], path: str) -> str:
"""
Build the full URL for a tool server, given a base url and a path.
"""
if "://" in path:
# If it contains "://", it's a full URL
return path
if not path.startswith("/"):
# Ensure the path starts with a slash
path = f"/{path}"
return f"{url}{path}"

View File

@ -1,11 +1,12 @@
fastapi==0.115.7 fastapi==0.115.7
uvicorn[standard]==0.34.2 uvicorn[standard]==0.35.0
pydantic==2.10.6 pydantic==2.11.7
python-multipart==0.0.20 python-multipart==0.0.20
python-socketio==5.13.0 python-socketio==5.13.0
python-jose==3.4.0 python-jose==3.4.0
passlib[bcrypt]==1.7.4 passlib[bcrypt]==1.7.4
cryptography
requests==2.32.4 requests==2.32.4
aiohttp==3.11.11 aiohttp==3.11.11
@ -13,6 +14,7 @@ async-timeout
aiocache aiocache
aiofiles aiofiles
starlette-compress==1.6.0 starlette-compress==1.6.0
httpx[socks,http2,zstd,cli,brotli]==0.28.1
sqlalchemy==2.0.38 sqlalchemy==2.0.38
alembic==1.14.0 alembic==1.14.0
@ -30,6 +32,8 @@ boto3==1.35.53
argon2-cffi==23.1.0 argon2-cffi==23.1.0
APScheduler==3.10.4 APScheduler==3.10.4
pycrdt==0.12.25
RestrictedPython==8.0 RestrictedPython==8.0
loguru==0.7.3 loguru==0.7.3
@ -42,13 +46,14 @@ google-genai==1.15.0
google-generativeai==0.8.5 google-generativeai==0.8.5
tiktoken tiktoken
langchain==0.3.24 langchain==0.3.26
langchain-community==0.3.23 langchain-community==0.3.26
fake-useragent==2.1.0 fake-useragent==2.1.0
chromadb==0.6.3 chromadb==0.6.3
posthog==5.4.0
pymilvus==2.5.0 pymilvus==2.5.0
qdrant-client~=1.12.0 qdrant-client==1.14.3
opensearch-py==2.8.0 opensearch-py==2.8.0
playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
elasticsearch==9.0.1 elasticsearch==9.0.1
@ -59,6 +64,7 @@ transformers
sentence-transformers==4.1.0 sentence-transformers==4.1.0
accelerate accelerate
colbert-ai==0.2.21 colbert-ai==0.2.21
pyarrow==20.0.0
einops==0.8.1 einops==0.8.1
@ -100,7 +106,7 @@ youtube-transcript-api==1.1.0
pytube==15.0.0 pytube==15.0.0
pydub pydub
duckduckgo-search==8.0.2 ddgs==9.0.0
## Google Drive ## Google Drive
google-api-python-client google-api-python-client
@ -115,7 +121,7 @@ pytest-docker~=3.1.1
googleapis-common-protos==1.63.2 googleapis-common-protos==1.63.2
google-cloud-storage==2.19.0 google-cloud-storage==2.19.0
azure-identity==1.21.0 azure-identity==1.23.0
azure-storage-blob==12.24.1 azure-storage-blob==12.24.1
@ -129,14 +135,14 @@ firecrawl-py==1.12.0
tencentcloud-sdk-python==3.0.1336 tencentcloud-sdk-python==3.0.1336
## Trace ## Trace
opentelemetry-api==1.32.1 opentelemetry-api==1.36.0
opentelemetry-sdk==1.32.1 opentelemetry-sdk==1.36.0
opentelemetry-exporter-otlp==1.32.1 opentelemetry-exporter-otlp==1.36.0
opentelemetry-instrumentation==0.53b1 opentelemetry-instrumentation==0.57b0
opentelemetry-instrumentation-fastapi==0.53b1 opentelemetry-instrumentation-fastapi==0.57b0
opentelemetry-instrumentation-sqlalchemy==0.53b1 opentelemetry-instrumentation-sqlalchemy==0.57b0
opentelemetry-instrumentation-redis==0.53b1 opentelemetry-instrumentation-redis==0.57b0
opentelemetry-instrumentation-requests==0.53b1 opentelemetry-instrumentation-requests==0.57b0
opentelemetry-instrumentation-logging==0.53b1 opentelemetry-instrumentation-logging==0.57b0
opentelemetry-instrumentation-httpx==0.53b1 opentelemetry-instrumentation-httpx==0.57b0
opentelemetry-instrumentation-aiohttp-client==0.53b1 opentelemetry-instrumentation-aiohttp-client==0.57b0

View File

@ -21,14 +21,14 @@ describe('Settings', () => {
// Click on the model selector // Click on the model selector
cy.get('button[aria-label="Select a model"]').click(); cy.get('button[aria-label="Select a model"]').click();
// Select the first model // Select the first model
cy.get('button[aria-label="model-item"]').first().click(); cy.get('button[aria-roledescription="model-item"]').first().click();
}); });
it('user can perform text chat', () => { it('user can perform text chat', () => {
// Click on the model selector // Click on the model selector
cy.get('button[aria-label="Select a model"]').click(); cy.get('button[aria-label="Select a model"]').click();
// Select the first model // Select the first model
cy.get('button[aria-label="model-item"]').first().click(); cy.get('button[aria-roledescription="model-item"]').first().click();
// Type a message // Type a message
cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', { cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', {
force: true force: true
@ -48,7 +48,7 @@ describe('Settings', () => {
// Click on the model selector // Click on the model selector
cy.get('button[aria-label="Select a model"]').click(); cy.get('button[aria-label="Select a model"]').click();
// Select the first model // Select the first model
cy.get('button[aria-label="model-item"]').first().click(); cy.get('button[aria-roledescription="model-item"]').first().click();
// Type a message // Type a message
cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', { cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', {
force: true force: true
@ -83,7 +83,7 @@ describe('Settings', () => {
// Click on the model selector // Click on the model selector
cy.get('button[aria-label="Select a model"]').click(); cy.get('button[aria-label="Select a model"]').click();
// Select the first model // Select the first model
cy.get('button[aria-label="model-item"]').first().click(); cy.get('button[aria-roledescription="model-item"]').first().click();
// Type a message // Type a message
cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', { cy.get('#chat-input').type('Hi, what can you do? A single sentence only please.', {
force: true force: true

35
docker-compose.otel.yaml Normal file
View File

@ -0,0 +1,35 @@
services:
grafana:
image: grafana/otel-lgtm:latest
container_name: lgtm
ports:
- "3000:3000" # Grafana UI
- "4317:4317" # OTLP/gRPC
- "4318:4318" # OTLP/HTTP
restart: unless-stopped
open-webui:
build:
context: .
dockerfile: Dockerfile
image: ghcr.io/open-webui/open-webui:${WEBUI_DOCKER_TAG-main}
container_name: open-webui
volumes:
- open-webui:/app/backend/data
depends_on:
- grafana
ports:
- ${OPEN_WEBUI_PORT-8088}:8080
environment:
- ENABLE_OTEL=true
- ENABLE_OTEL_METRICS=true
- OTEL_EXPORTER_OTLP_INSECURE=true # Use insecure connection for OTLP, remove in production
- OTEL_EXPORTER_OTLP_ENDPOINT=http://grafana:4317
- OTEL_SERVICE_NAME=open-webui
extra_hosts:
- host.docker.internal:host-gateway
restart: unless-stopped
volumes:
open-webui: {}

View File

@ -17,7 +17,7 @@ class CustomBuildHook(BuildHookInterface):
"NodeJS `npm` is required for building Open Webui but it was not found" "NodeJS `npm` is required for building Open Webui but it was not found"
) )
stderr.write("### npm install\n") stderr.write("### npm install\n")
subprocess.run([npm, "install"], check=True) # noqa: S603 subprocess.run([npm, "install", "--force"], check=True) # noqa: S603
stderr.write("\n### npm run build\n") stderr.write("\n### npm run build\n")
os.environ["APP_BUILD_HASH"] = version os.environ["APP_BUILD_HASH"] = version
subprocess.run([npm, "run", "build"], check=True) # noqa: S603 subprocess.run([npm, "run", "build"], check=True) # noqa: S603

999
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
{ {
"name": "open-webui", "name": "open-webui",
"version": "0.6.15", "version": "0.6.18",
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "npm run pyodide:fetch && vite dev --host", "dev": "npm run pyodide:fetch && vite dev --host",
@ -57,25 +57,32 @@
"@codemirror/lang-python": "^6.1.6", "@codemirror/lang-python": "^6.1.6",
"@codemirror/language-data": "^6.5.1", "@codemirror/language-data": "^6.5.1",
"@codemirror/theme-one-dark": "^6.1.2", "@codemirror/theme-one-dark": "^6.1.2",
"@floating-ui/dom": "^1.7.2",
"@huggingface/transformers": "^3.0.0", "@huggingface/transformers": "^3.0.0",
"@mediapipe/tasks-vision": "^0.10.17", "@mediapipe/tasks-vision": "^0.10.17",
"@pyscript/core": "^0.4.32", "@pyscript/core": "^0.4.32",
"@sveltejs/adapter-node": "^2.0.0", "@sveltejs/adapter-node": "^2.0.0",
"@sveltejs/svelte-virtual-list": "^3.0.1", "@sveltejs/svelte-virtual-list": "^3.0.1",
"@tiptap/core": "^2.11.9", "@tiptap/core": "^3.0.7",
"@tiptap/extension-code-block-lowlight": "^2.11.9", "@tiptap/extension-bubble-menu": "^2.26.1",
"@tiptap/extension-highlight": "^2.10.0", "@tiptap/extension-code-block-lowlight": "^3.0.7",
"@tiptap/extension-placeholder": "^2.10.0", "@tiptap/extension-drag-handle": "^3.0.7",
"@tiptap/extension-table": "^2.12.0", "@tiptap/extension-file-handler": "^3.0.7",
"@tiptap/extension-table-cell": "^2.12.0", "@tiptap/extension-floating-menu": "^2.26.1",
"@tiptap/extension-table-header": "^2.12.0", "@tiptap/extension-highlight": "^3.0.7",
"@tiptap/extension-table-row": "^2.12.0", "@tiptap/extension-image": "^3.0.7",
"@tiptap/extension-typography": "^2.10.0", "@tiptap/extension-link": "^3.0.7",
"@tiptap/pm": "^2.11.7", "@tiptap/extension-list": "^3.0.7",
"@tiptap/starter-kit": "^2.10.0", "@tiptap/extension-table": "^3.0.7",
"@tiptap/extension-typography": "^3.0.7",
"@tiptap/extension-youtube": "^3.0.7",
"@tiptap/extensions": "^3.0.7",
"@tiptap/pm": "^3.0.7",
"@tiptap/starter-kit": "^3.0.7",
"@xyflow/svelte": "^0.1.19", "@xyflow/svelte": "^0.1.19",
"async": "^3.2.5", "async": "^3.2.5",
"bits-ui": "^0.21.15", "bits-ui": "^0.21.15",
"chart.js": "^4.5.0",
"codemirror": "^6.0.1", "codemirror": "^6.0.1",
"codemirror-lang-elixir": "^4.0.0", "codemirror-lang-elixir": "^4.0.0",
"codemirror-lang-hcl": "^0.1.0", "codemirror-lang-hcl": "^0.1.0",
@ -86,9 +93,10 @@
"file-saver": "^2.0.5", "file-saver": "^2.0.5",
"focus-trap": "^7.6.4", "focus-trap": "^7.6.4",
"fuse.js": "^7.0.0", "fuse.js": "^7.0.0",
"heic2any": "^0.0.4",
"highlight.js": "^11.9.0", "highlight.js": "^11.9.0",
"html-entities": "^2.5.3", "html-entities": "^2.5.3",
"html2canvas-pro": "^1.5.8", "html2canvas-pro": "^1.5.11",
"i18next": "^23.10.0", "i18next": "^23.10.0",
"i18next-browser-languagedetector": "^7.2.0", "i18next-browser-languagedetector": "^7.2.0",
"i18next-resources-to-backend": "^1.2.0", "i18next-resources-to-backend": "^1.2.0",
@ -103,6 +111,8 @@
"mermaid": "^11.6.0", "mermaid": "^11.6.0",
"paneforge": "^0.0.6", "paneforge": "^0.0.6",
"panzoom": "^9.4.3", "panzoom": "^9.4.3",
"pdfjs-dist": "^5.3.93",
"prosemirror-collab": "^1.3.1",
"prosemirror-commands": "^1.6.0", "prosemirror-commands": "^1.6.0",
"prosemirror-example-setup": "^1.2.3", "prosemirror-example-setup": "^1.2.3",
"prosemirror-history": "^1.4.1", "prosemirror-history": "^1.4.1",
@ -116,7 +126,7 @@
"prosemirror-view": "^1.34.3", "prosemirror-view": "^1.34.3",
"pyodide": "^0.27.3", "pyodide": "^0.27.3",
"socket.io-client": "^4.2.0", "socket.io-client": "^4.2.0",
"sortablejs": "^1.15.2", "sortablejs": "^1.15.6",
"svelte-sonner": "^0.3.19", "svelte-sonner": "^0.3.19",
"tippy.js": "^6.3.7", "tippy.js": "^6.3.7",
"turndown": "^7.2.0", "turndown": "^7.2.0",
@ -124,7 +134,9 @@
"undici": "^7.3.0", "undici": "^7.3.0",
"uuid": "^9.0.1", "uuid": "^9.0.1",
"vite-plugin-static-copy": "^2.2.0", "vite-plugin-static-copy": "^2.2.0",
"yaml": "^2.7.1" "y-prosemirror": "^1.3.7",
"yaml": "^2.7.1",
"yjs": "^13.6.27"
}, },
"engines": { "engines": {
"node": ">=18.13.0 <=22.x.x", "node": ">=18.13.0 <=22.x.x",

View File

@ -8,7 +8,7 @@ license = { file = "LICENSE" }
dependencies = [ dependencies = [
"fastapi==0.115.7", "fastapi==0.115.7",
"uvicorn[standard]==0.34.2", "uvicorn[standard]==0.34.2",
"pydantic==2.10.6", "pydantic==2.11.7",
"python-multipart==0.0.20", "python-multipart==0.0.20",
"python-socketio==5.13.0", "python-socketio==5.13.0",
"python-jose==3.4.0", "python-jose==3.4.0",
@ -48,7 +48,7 @@ dependencies = [
"fake-useragent==2.1.0", "fake-useragent==2.1.0",
"chromadb==0.6.3", "chromadb==0.6.3",
"pymilvus==2.5.0", "pymilvus==2.5.0",
"qdrant-client~=1.12.0", "qdrant-client==1.14.3",
"opensearch-py==2.8.0", "opensearch-py==2.8.0",
"playwright==1.49.1", "playwright==1.49.1",
"elasticsearch==9.0.1", "elasticsearch==9.0.1",
@ -90,7 +90,6 @@ dependencies = [
"youtube-transcript-api==1.1.0", "youtube-transcript-api==1.1.0",
"pytube==15.0.0", "pytube==15.0.0",
"pydub", "pydub",
"duckduckgo-search==8.0.2",
"ddgs==9.0.0", "ddgs==9.0.0",
"google-api-python-client", "google-api-python-client",
"google-auth-httplib2", "google-auth-httplib2",
@ -115,7 +114,7 @@ requires-python = ">= 3.11, < 3.13.0a1"
dynamic = ["version"] dynamic = ["version"]
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"License :: OSI Approved :: MIT License", "License :: Other/Proprietary License",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
@ -164,3 +163,8 @@ skip = '.git*,*.svg,package-lock.json,i18n,*.lock,*.css,*-bundle.js,locales,exam
check-hidden = true check-hidden = true
# ignore-regex = '' # ignore-regex = ''
ignore-words-list = 'ans' ignore-words-list = 'ans'
[dependency-groups]
dev = [
"pytest-asyncio>=1.0.0",
]

View File

@ -13,7 +13,8 @@ const packages = [
'tiktoken', 'tiktoken',
'seaborn', 'seaborn',
'pytz', 'pytz',
'black' 'black',
'openai'
]; ];
import { loadPyodide } from 'pyodide'; import { loadPyodide } from 'pyodide';
@ -74,8 +75,8 @@ async function downloadPackages() {
console.log('Pyodide version mismatch, removing static/pyodide directory'); console.log('Pyodide version mismatch, removing static/pyodide directory');
await rmdir('static/pyodide', { recursive: true }); await rmdir('static/pyodide', { recursive: true });
} }
} catch (e) { } catch (err) {
console.log('Pyodide package not found, proceeding with download.'); console.log('Pyodide package not found, proceeding with download.', err);
} }
try { try {

View File

@ -40,6 +40,11 @@ code {
width: auto; width: auto;
} }
.editor-selection {
background: rgba(180, 213, 255, 0.5);
border-radius: 2px;
}
.font-secondary { .font-secondary {
font-family: 'InstrumentSerif', sans-serif; font-family: 'InstrumentSerif', sans-serif;
} }
@ -65,19 +70,23 @@ textarea::placeholder {
} }
.input-prose { .input-prose {
@apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; @apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-1 prose-img:my-1 prose-headings:my-2 prose-pre:my-0 prose-table:my-1 prose-blockquote:my-0 prose-ul:my-1 prose-ol:my-1 prose-li:my-0.5 whitespace-pre-line;
} }
.input-prose-sm { .input-prose-sm {
@apply prose dark:prose-invert prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line text-sm; @apply prose dark:prose-invert prose-headings:font-medium prose-h1:text-2xl prose-h2:text-xl prose-h3:text-lg prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-1 prose-img:my-1 prose-headings:my-2 prose-pre:my-0 prose-table:my-1 prose-blockquote:my-0 prose-ul:my-1 prose-ol:my-1 prose-li:my-1 whitespace-pre-line text-sm;
} }
.markdown-prose { .markdown-prose {
@apply prose dark:prose-invert prose-blockquote:border-s-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-s-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; @apply prose dark:prose-invert prose-blockquote:border-s-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-s-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
} }
.markdown-prose-sm {
@apply text-sm prose dark:prose-invert prose-blockquote:border-s-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-s-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-2 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
}
.markdown-prose-xs { .markdown-prose-xs {
@apply text-xs prose dark:prose-invert prose-blockquote:border-s-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-s-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-0 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; @apply text-xs prose dark:prose-invert prose-blockquote:border-s-gray-100 prose-blockquote:dark:border-gray-800 prose-blockquote:border-s-2 prose-blockquote:not-italic prose-blockquote:font-normal prose-headings:font-semibold prose-hr:my-0.5 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line;
} }
.markdown a { .markdown a {
@ -326,6 +335,138 @@ input[type='number'] {
@apply line-clamp-1 absolute; @apply line-clamp-1 absolute;
} }
.tiptap ul[data-type='taskList'] {
list-style: none;
margin-left: 0;
padding: 0;
li {
align-items: start;
display: flex;
> label {
flex: 0 0 auto;
margin-right: 0.5rem;
margin-top: 0.2rem;
user-select: none;
display: flex;
}
> div {
flex: 1 1 auto;
align-items: center;
}
}
/* checked data-checked="true" */
li[data-checked='true'] {
> div {
opacity: 0.5;
text-decoration: line-through;
}
}
input[type='checkbox'] {
cursor: pointer;
}
ul[data-type='taskList'] {
margin: 0;
}
/* Reset nested regular ul elements to default styling */
ul:not([data-type='taskList']) {
list-style: disc;
padding-left: 1rem;
li {
align-items: initial;
display: list-item;
label {
flex: initial;
margin-right: initial;
margin-top: initial;
user-select: initial;
display: initial;
}
div {
flex: initial;
align-items: initial;
}
}
}
}
.input-prose .tiptap ul[data-type='taskList'] {
list-style: none;
margin-left: 0;
padding: 0;
li {
align-items: start;
display: flex;
> label {
flex: 0 0 auto;
margin-right: 0.5rem;
margin-top: 0.4rem;
user-select: none;
display: flex;
}
> div {
flex: 1 1 auto;
align-items: center;
}
}
/* checked data-checked="true" */
li[data-checked='true'] {
> div {
opacity: 0.5;
text-decoration: line-through;
}
}
input[type='checkbox'] {
cursor: pointer;
}
ul[data-type='taskList'] {
margin: 0;
}
/* Reset nested regular ul elements to default styling */
ul:not([data-type='taskList']) {
list-style: disc;
padding-left: 1rem;
li {
align-items: initial;
display: list-item;
label {
flex: initial;
margin-right: initial;
margin-top: initial;
user-select: initial;
display: initial;
}
div {
flex: initial;
align-items: initial;
}
}
}
}
@media (prefers-color-scheme: dark) { @media (prefers-color-scheme: dark) {
.ProseMirror p.is-editor-empty:first-child::before { .ProseMirror p.is-editor-empty:first-child::before {
color: #757575; color: #757575;
@ -339,21 +480,21 @@ input[type='number'] {
pointer-events: none; pointer-events: none;
} }
.tiptap > pre > code { .tiptap pre > code {
border-radius: 0.4rem; border-radius: 0.4rem;
font-size: 0.85rem; font-size: 0.85rem;
padding: 0.25em 0.3em; padding: 0.25em 0.3em;
@apply dark:bg-gray-800 bg-gray-100; @apply dark:bg-gray-800 bg-gray-50;
} }
.tiptap > pre { .tiptap pre {
border-radius: 0.5rem; border-radius: 0.5rem;
font-family: 'JetBrainsMono', monospace; font-family: 'JetBrainsMono', monospace;
margin: 1.5rem 0; margin: 1.5rem 0;
padding: 0.75rem 1rem; padding: 0.75rem 1rem;
@apply dark:bg-gray-800 bg-gray-100; @apply dark:bg-gray-800 bg-gray-50;
} }
.tiptap p code { .tiptap p code {
@ -362,7 +503,7 @@ input[type='number'] {
padding: 3px 8px; padding: 3px 8px;
font-size: 0.8em; font-size: 0.8em;
font-weight: 600; font-weight: 600;
@apply rounded-md dark:bg-gray-800 bg-gray-100 mx-0.5; @apply rounded-md dark:bg-gray-800 bg-gray-50 mx-0.5;
} }
/* Code styling */ /* Code styling */
@ -442,3 +583,36 @@ input[type='number'] {
.tiptap tr { .tiptap tr {
@apply bg-white dark:bg-gray-900 dark:border-gray-850 text-xs; @apply bg-white dark:bg-gray-900 dark:border-gray-850 text-xs;
} }
.tippy-box[data-theme~='transparent'] {
@apply bg-transparent p-0 m-0;
}
/* this is a rough fix for the first cursor position when the first paragraph is empty */
.ProseMirror > .ProseMirror-yjs-cursor:first-child {
margin-top: 16px;
}
/* This gives the remote user caret. The colors are automatically overwritten*/
.ProseMirror-yjs-cursor {
position: relative;
margin-left: -1px;
margin-right: -1px;
border-left: 1px solid black;
border-right: 1px solid black;
border-color: orange;
word-break: normal;
pointer-events: none;
}
/* This renders the username above the caret */
.ProseMirror-yjs-cursor > div {
position: absolute;
top: -1.05em;
left: -1px;
font-size: 13px;
background-color: rgb(250, 129, 0);
user-select: none;
color: white;
padding-left: 2px;
padding-right: 2px;
white-space: nowrap;
}

View File

@ -56,7 +56,6 @@
document.documentElement.classList.add('light'); document.documentElement.classList.add('light');
metaThemeColorTag.setAttribute('content', '#ffffff'); metaThemeColorTag.setAttribute('content', '#ffffff');
} else if (localStorage.theme === 'her') { } else if (localStorage.theme === 'her') {
document.documentElement.classList.add('dark');
document.documentElement.classList.add('her'); document.documentElement.classList.add('her');
metaThemeColorTag.setAttribute('content', '#983724'); metaThemeColorTag.setAttribute('content', '#983724');
} else { } else {
@ -77,28 +76,18 @@
} }
} }
}); });
function setSplashImage() {
const logo = document.getElementById('logo');
const isDarkMode = document.documentElement.classList.contains('dark'); const isDarkMode = document.documentElement.classList.contains('dark');
if (isDarkMode) { const logo = document.createElement('img');
const darkImage = new Image(); logo.id = 'logo';
darkImage.src = '/static/splash-dark.png'; logo.style =
'position: absolute; width: auto; height: 6rem; top: 44%; left: 50%; transform: translateX(-50%); display:block;';
logo.src = isDarkMode ? '/static/splash-dark.png' : '/static/splash.png';
darkImage.onload = () => { document.addEventListener('DOMContentLoaded', function () {
logo.src = '/static/splash-dark.png'; const splash = document.getElementById('splash-screen');
logo.style.filter = ''; // Ensure no inversion is applied if splash-dark.png exists if (splash) splash.prepend(logo);
}; });
darkImage.onerror = () => {
logo.style.filter = 'invert(1)'; // Invert image if splash-dark.png is missing
};
}
}
// Runs after classes are assigned
window.onload = setSplashImage;
})(); })();
</script> </script>
@ -120,19 +109,6 @@
} }
</style> </style>
<img
id="logo"
style="
position: absolute;
width: auto;
height: 6rem;
top: 44%;
left: 50%;
transform: translateX(-50%);
"
src="/static/splash.png"
/>
<div <div
style=" style="
position: absolute; position: absolute;

View File

@ -347,6 +347,8 @@ export const userSignOut = async () => {
if (error) { if (error) {
throw error; throw error;
} }
sessionStorage.clear();
return res; return res;
}; };

View File

@ -1,7 +1,7 @@
import { WEBUI_API_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL } from '$lib/constants';
import { getTimeRange } from '$lib/utils'; import { getTimeRange } from '$lib/utils';
export const createNewChat = async (token: string, chat: object) => { export const createNewChat = async (token: string, chat: object, folderId: string | null) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/new`, { const res = await fetch(`${WEBUI_API_BASE_URL}/chats/new`, {
@ -12,7 +12,8 @@ export const createNewChat = async (token: string, chat: object) => {
authorization: `Bearer ${token}` authorization: `Bearer ${token}`
}, },
body: JSON.stringify({ body: JSON.stringify({
chat: chat chat: chat,
folder_id: folderId ?? null
}) })
}) })
.then(async (res) => { .then(async (res) => {
@ -37,7 +38,9 @@ export const importChat = async (
chat: object, chat: object,
meta: object | null, meta: object | null,
pinned?: boolean, pinned?: boolean,
folderId?: string | null folderId?: string | null,
createdAt: number | null = null,
updatedAt: number | null = null
) => { ) => {
let error = null; let error = null;
@ -52,7 +55,9 @@ export const importChat = async (
chat: chat, chat: chat,
meta: meta ?? {}, meta: meta ?? {},
pinned: pinned, pinned: pinned,
folder_id: folderId folder_id: folderId,
created_at: createdAt ?? null,
updated_at: updatedAt ?? null
}) })
}) })
.then(async (res) => { .then(async (res) => {

View File

@ -58,10 +58,10 @@ export const exportConfig = async (token: string) => {
return res; return res;
}; };
export const getDirectConnectionsConfig = async (token: string) => { export const getConnectionsConfig = async (token: string) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, { const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
method: 'GET', method: 'GET',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -85,10 +85,10 @@ export const getDirectConnectionsConfig = async (token: string) => {
return res; return res;
}; };
export const setDirectConnectionsConfig = async (token: string, config: object) => { export const setConnectionsConfig = async (token: string, config: object) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, { const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',

View File

@ -1,6 +1,11 @@
import { WEBUI_API_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewFolder = async (token: string, name: string) => { type FolderForm = {
name: string;
data?: Record<string, any>;
};
export const createNewFolder = async (token: string, folderForm: FolderForm) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/`, { const res = await fetch(`${WEBUI_API_BASE_URL}/folders/`, {
@ -10,9 +15,7 @@ export const createNewFolder = async (token: string, name: string) => {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
authorization: `Bearer ${token}` authorization: `Bearer ${token}`
}, },
body: JSON.stringify({ body: JSON.stringify(folderForm)
name: name
})
}) })
.then(async (res) => { .then(async (res) => {
if (!res.ok) throw await res.json(); if (!res.ok) throw await res.json();
@ -92,7 +95,7 @@ export const getFolderById = async (token: string, id: string) => {
return res; return res;
}; };
export const updateFolderNameById = async (token: string, id: string, name: string) => { export const updateFolderById = async (token: string, id: string, folderForm: FolderForm) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update`, { const res = await fetch(`${WEBUI_API_BASE_URL}/folders/${id}/update`, {
@ -102,9 +105,7 @@ export const updateFolderNameById = async (token: string, id: string, name: stri
'Content-Type': 'application/json', 'Content-Type': 'application/json',
authorization: `Bearer ${token}` authorization: `Bearer ${token}`
}, },
body: JSON.stringify({ body: JSON.stringify(folderForm)
name: name
})
}) })
.then(async (res) => { .then(async (res) => {
if (!res.ok) throw await res.json(); if (!res.ok) throw await res.json();

View File

@ -8,17 +8,26 @@ import { toast } from 'svelte-sonner';
export const getModels = async ( export const getModels = async (
token: string = '', token: string = '',
connections: object | null = null, connections: object | null = null,
base: boolean = false base: boolean = false,
refresh: boolean = false
) => { ) => {
const searchParams = new URLSearchParams();
if (refresh) {
searchParams.append('refresh', 'true');
}
let error = null; let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/models${base ? '/base' : ''}`, { const res = await fetch(
`${WEBUI_BASE_URL}/api/models${base ? '/base' : ''}?${searchParams.toString()}`,
{
method: 'GET', method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
'Content-Type': 'application/json', 'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` }) ...(token && { authorization: `Bearer ${token}` })
} }
}) }
)
.then(async (res) => { .then(async (res) => {
if (!res.ok) throw await res.json(); if (!res.ok) throw await res.json();
return res.json(); return res.json();
@ -1587,6 +1596,7 @@ export interface ModelConfig {
} }
export interface ModelMeta { export interface ModelMeta {
toolIds: never[];
description?: string; description?: string;
capabilities?: object; capabilities?: object;
profile_image_url?: string; profile_image_url?: string;

View File

@ -39,7 +39,7 @@ export const createNewNote = async (token: string, note: NoteItem) => {
return res; return res;
}; };
export const getNotes = async (token: string = '') => { export const getNotes = async (token: string = '', raw: boolean = false) => {
let error = null; let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/notes/`, { const res = await fetch(`${WEBUI_API_BASE_URL}/notes/`, {
@ -67,6 +67,10 @@ export const getNotes = async (token: string = '') => {
throw error; throw error;
} }
if (raw) {
return res; // Return raw response if requested
}
if (!Array.isArray(res)) { if (!Array.isArray(res)) {
return {}; // or throw new Error("Notes response is not an array") return {}; // or throw new Error("Notes response is not an array")
} }
@ -87,6 +91,37 @@ export const getNotes = async (token: string = '') => {
return grouped; return grouped;
}; };
export const getNoteList = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/notes/list`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.error(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getNoteById = async (token: string, id: string) => { export const getNoteById = async (token: string, id: string) => {
let error = null; let error = null;

View File

@ -366,7 +366,7 @@ export const unloadModel = async (token: string, tagName: string) => {
Authorization: `Bearer ${token}` Authorization: `Bearer ${token}`
}, },
body: JSON.stringify({ body: JSON.stringify({
name: tagName model: tagName
}) })
}).catch((err) => { }).catch((err) => {
error = err; error = err;
@ -419,7 +419,7 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string
Authorization: `Bearer ${token}` Authorization: `Bearer ${token}`
}, },
body: JSON.stringify({ body: JSON.stringify({
name: tagName model: tagName
}) })
} }
) )

View File

@ -403,6 +403,7 @@ export const deleteUserById = async (token: string, userId: string) => {
}; };
type UserUpdateForm = { type UserUpdateForm = {
role: string;
profile_image_url: string; profile_image_url: string;
email: string; email: string;
name: string; name: string;

View File

@ -15,6 +15,8 @@
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import Switch from '$lib/components/common/Switch.svelte'; import Switch from '$lib/components/common/Switch.svelte';
import Tags from './common/Tags.svelte'; import Tags from './common/Tags.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import XMark from '$lib/components/icons/XMark.svelte';
export let onSubmit: Function = () => {}; export let onSubmit: Function = () => {};
export let onDelete: Function = () => {}; export let onDelete: Function = () => {};
@ -33,9 +35,7 @@
let connectionType = 'external'; let connectionType = 'external';
let azure = false; let azure = false;
$: azure = $: azure =
(url.includes('azure.com') || url.includes('cognitive.microsoft.com')) && !direct (url.includes('azure.') || url.includes('cognitive.microsoft.com')) && !direct ? true : false;
? true
: false;
let prefixId = ''; let prefixId = '';
let enable = true; let enable = true;
@ -208,17 +208,7 @@
show = false; show = false;
}} }}
> >
<svg <XMark className={'size-5'} />
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
aria-hidden="true"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button> </button>
</div> </div>
@ -524,29 +514,7 @@
{#if loading} {#if loading}
<div class="ml-2 self-center"> <div class="ml-2 self-center">
<svg <Spinner />
class=" w-4 h-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style><path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/><path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/></svg
>
</div> </div>
{/if} {/if}
</button> </button>

View File

@ -3,6 +3,7 @@
import { getContext, onMount } from 'svelte'; import { getContext, onMount } from 'svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
import { settings } from '$lib/stores';
import Modal from '$lib/components/common/Modal.svelte'; import Modal from '$lib/components/common/Modal.svelte';
import Plus from '$lib/components/icons/Plus.svelte'; import Plus from '$lib/components/icons/Plus.svelte';
import Minus from '$lib/components/icons/Minus.svelte'; import Minus from '$lib/components/icons/Minus.svelte';
@ -14,6 +15,8 @@
import { getToolServerData } from '$lib/apis'; import { getToolServerData } from '$lib/apis';
import { verifyToolServerConnection } from '$lib/apis/configs'; import { verifyToolServerConnection } from '$lib/apis/configs';
import AccessControl from './workspace/common/AccessControl.svelte'; import AccessControl from './workspace/common/AccessControl.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import XMark from '$lib/components/icons/XMark.svelte';
export let onSubmit: Function = () => {}; export let onSubmit: Function = () => {};
export let onDelete: Function = () => {}; export let onDelete: Function = () => {};
@ -153,29 +156,21 @@
<Modal size="sm" bind:show> <Modal size="sm" bind:show>
<div> <div>
<div class=" flex justify-between dark:text-gray-100 px-5 pt-4 pb-2"> <div class=" flex justify-between dark:text-gray-100 px-5 pt-4 pb-2">
<div class=" text-lg font-medium self-center font-primary"> <h1 class=" text-lg font-medium self-center font-primary">
{#if edit} {#if edit}
{$i18n.t('Edit Connection')} {$i18n.t('Edit Connection')}
{:else} {:else}
{$i18n.t('Add Connection')} {$i18n.t('Add Connection')}
{/if} {/if}
</div> </h1>
<button <button
class="self-center" class="self-center"
aria-label={$i18n.t('Close Configure Connection Modal')}
on:click={() => { on:click={() => {
show = false; show = false;
}} }}
> >
<svg <XMark className={'size-5'} />
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button> </button>
</div> </div>
@ -192,12 +187,17 @@
<div class="flex gap-2"> <div class="flex gap-2">
<div class="flex flex-col w-full"> <div class="flex flex-col w-full">
<div class="flex justify-between mb-0.5"> <div class="flex justify-between mb-0.5">
<div class=" text-xs text-gray-500">{$i18n.t('URL')}</div> <label
for="api-base-url"
class={`text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
>{$i18n.t('URL')}</label
>
</div> </div>
<div class="flex flex-1 items-center"> <div class="flex flex-1 items-center">
<input <input
class="w-full flex-1 text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-hidden" id="api-base-url"
class={`w-full flex-1 text-sm bg-transparent ${($settings?.highContrastMode ?? false) ? 'placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700'}`}
type="text" type="text"
bind:value={url} bind:value={url}
placeholder={$i18n.t('API Base URL')} placeholder={$i18n.t('API Base URL')}
@ -214,6 +214,7 @@
on:click={() => { on:click={() => {
verifyHandler(); verifyHandler();
}} }}
aria-label={$i18n.t('Verify Connection')}
type="button" type="button"
> >
<svg <svg
@ -221,6 +222,7 @@
viewBox="0 0 20 20" viewBox="0 0 20 20"
fill="currentColor" fill="currentColor"
class="w-4 h-4" class="w-4 h-4"
aria-hidden="true"
> >
<path <path
fill-rule="evenodd" fill-rule="evenodd"
@ -237,9 +239,13 @@
</div> </div>
<div class="flex-1 flex items-center"> <div class="flex-1 flex items-center">
<label for="url-or-path" class="sr-only"
>{$i18n.t('openapi.json URL or Path')}</label
>
<input <input
class="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-hidden" class={`w-full text-sm bg-transparent ${($settings?.highContrastMode ?? false) ? 'placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700'}`}
type="text" type="text"
id="url-or-path"
bind:value={path} bind:value={path}
placeholder={$i18n.t('openapi.json URL or Path')} placeholder={$i18n.t('openapi.json URL or Path')}
autocomplete="off" autocomplete="off"
@ -249,7 +255,9 @@
</div> </div>
</div> </div>
<div class="text-xs text-gray-500 mt-1"> <div
class={`text-xs mt-1 ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
>
{$i18n.t(`WebUI will make requests to "{{url}}"`, { {$i18n.t(`WebUI will make requests to "{{url}}"`, {
url: path.includes('://') ? path : `${url}${path.startsWith('/') ? '' : '/'}${path}` url: path.includes('://') ? path : `${url}${path.startsWith('/') ? '' : '/'}${path}`
})} })}
@ -257,12 +265,17 @@
<div class="flex gap-2 mt-2"> <div class="flex gap-2 mt-2">
<div class="flex flex-col w-full"> <div class="flex flex-col w-full">
<div class=" text-xs text-gray-500">{$i18n.t('Auth')}</div> <label
for="select-bearer-or-session"
class={`text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
>{$i18n.t('Auth')}</label
>
<div class="flex gap-2"> <div class="flex gap-2">
<div class="flex-shrink-0 self-start"> <div class="flex-shrink-0 self-start">
<select <select
class="w-full text-sm bg-transparent dark:bg-gray-900 placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-hidden pr-5" id="select-bearer-or-session"
class={`w-full text-sm bg-transparent pr-5 ${($settings?.highContrastMode ?? false) ? 'placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700'}`}
bind:value={auth_type} bind:value={auth_type}
> >
<option value="bearer">Bearer</option> <option value="bearer">Bearer</option>
@ -273,13 +286,14 @@
<div class="flex flex-1 items-center"> <div class="flex flex-1 items-center">
{#if auth_type === 'bearer'} {#if auth_type === 'bearer'}
<SensitiveInput <SensitiveInput
className="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-hidden"
bind:value={key} bind:value={key}
placeholder={$i18n.t('API Key')} placeholder={$i18n.t('API Key')}
required={false} required={false}
/> />
{:else if auth_type === 'session'} {:else if auth_type === 'session'}
<div class="text-xs text-gray-500 self-center translate-y-[1px]"> <div
class={`text-xs self-center translate-y-[1px] ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
>
{$i18n.t('Forwards system user session credentials to authenticate')} {$i18n.t('Forwards system user session credentials to authenticate')}
</div> </div>
{/if} {/if}
@ -293,11 +307,16 @@
<div class="flex gap-2"> <div class="flex gap-2">
<div class="flex flex-col w-full"> <div class="flex flex-col w-full">
<div class=" mb-0.5 text-xs text-gray-500">{$i18n.t('Name')}</div> <label
for="enter-name"
class={`mb-0.5 text-xs" ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
>{$i18n.t('Name')}</label
>
<div class="flex-1"> <div class="flex-1">
<input <input
class="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-hidden" id="enter-name"
class={`w-full text-sm bg-transparent ${($settings?.highContrastMode ?? false) ? 'placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700'}`}
type="text" type="text"
bind:value={name} bind:value={name}
placeholder={$i18n.t('Enter name')} placeholder={$i18n.t('Enter name')}
@ -309,11 +328,16 @@
</div> </div>
<div class="flex flex-col w-full mt-2"> <div class="flex flex-col w-full mt-2">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Description')}</div> <label
for="description"
class={`mb-1 text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100 placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700 text-gray-500'}`}
>{$i18n.t('Description')}</label
>
<div class="flex-1"> <div class="flex-1">
<input <input
class="w-full text-sm bg-transparent placeholder:text-gray-300 dark:placeholder:text-gray-700 outline-hidden" id="description"
class={`w-full text-sm bg-transparent ${($settings?.highContrastMode ?? false) ? 'placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700'}`}
type="text" type="text"
bind:value={description} bind:value={description}
placeholder={$i18n.t('Enter description')} placeholder={$i18n.t('Enter description')}
@ -357,29 +381,7 @@
{#if loading} {#if loading}
<div class="ml-2 self-center"> <div class="ml-2 self-center">
<svg <Spinner />
class=" w-4 h-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style><path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/><path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/></svg
>
</div> </div>
{/if} {/if}
</button> </button>

View File

@ -9,6 +9,7 @@
import Modal from './common/Modal.svelte'; import Modal from './common/Modal.svelte';
import { updateUserSettings } from '$lib/apis/users'; import { updateUserSettings } from '$lib/apis/users';
import XMark from '$lib/components/icons/XMark.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -36,18 +37,11 @@
localStorage.version = $config.version; localStorage.version = $config.version;
show = false; show = false;
}} }}
aria-label={$i18n.t('Close')}
> >
<svg <XMark className={'size-5'}>
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<p class="sr-only">{$i18n.t('Close')}</p> <p class="sr-only">{$i18n.t('Close')}</p>
<path </XMark>
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button> </button>
</div> </div>
<div class="flex items-center mt-1"> <div class="flex items-center mt-1">

View File

@ -3,7 +3,9 @@
import { getContext, onMount } from 'svelte'; import { getContext, onMount } from 'svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
import Spinner from '$lib/components/common/Spinner.svelte';
import Modal from '$lib/components/common/Modal.svelte'; import Modal from '$lib/components/common/Modal.svelte';
import XMark from '$lib/components/icons/XMark.svelte';
import { extractFrontmatter } from '$lib/utils'; import { extractFrontmatter } from '$lib/utils';
export let show = false; export let show = false;
@ -69,16 +71,7 @@
show = false; show = false;
}} }}
> >
<svg <XMark className={'size-5'} />
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button> </button>
</div> </div>
@ -120,29 +113,7 @@
{#if loading} {#if loading}
<div class="ml-2 self-center"> <div class="ml-2 self-center">
<svg <Spinner />
class=" w-4 h-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
><style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style><path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/><path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/></svg
>
</div> </div>
{/if} {/if}
</button> </button>

View File

@ -1,4 +1,5 @@
<script lang="ts"> <script lang="ts">
import { WEBUI_BASE_URL } from '$lib/constants';
import { settings, playingNotificationSound, isLastActiveTab } from '$lib/stores'; import { settings, playingNotificationSound, isLastActiveTab } from '$lib/stores';
import DOMPurify from 'dompurify'; import DOMPurify from 'dompurify';
@ -38,7 +39,7 @@
}} }}
> >
<div class="shrink-0 self-top -translate-y-0.5"> <div class="shrink-0 self-top -translate-y-0.5">
<img src={'/static/favicon.png'} alt="favicon" class="size-7 rounded-full" /> <img src="{WEBUI_BASE_URL}/static/favicon.png" alt="favicon" class="size-7 rounded-full" />
</div> </div>
<div> <div>

View File

@ -19,10 +19,10 @@
if (isDarkMode) { if (isDarkMode) {
const darkImage = new Image(); const darkImage = new Image();
darkImage.src = '/static/favicon-dark.png'; darkImage.src = `${WEBUI_BASE_URL}/static/favicon-dark.png`;
darkImage.onload = () => { darkImage.onload = () => {
logo.src = '/static/favicon-dark.png'; logo.src = `${WEBUI_BASE_URL}/static/favicon-dark.png`;
logo.style.filter = ''; // Ensure no inversion is applied if splash-dark.png exists logo.style.filter = ''; // Ensure no inversion is applied if splash-dark.png exists
}; };

View File

@ -13,7 +13,7 @@
import GarbageBin from '$lib/components/icons/GarbageBin.svelte'; import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
import Pencil from '$lib/components/icons/Pencil.svelte'; import Pencil from '$lib/components/icons/Pencil.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import Download from '$lib/components/icons/Download.svelte'; import Download from '$lib/components/icons/ArrowDownTray.svelte';
let show = false; let show = false;
</script> </script>

View File

@ -2,16 +2,41 @@
import Modal from '$lib/components/common/Modal.svelte'; import Modal from '$lib/components/common/Modal.svelte';
import { getContext } from 'svelte'; import { getContext } from 'svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
import XMark from '$lib/components/icons/XMark.svelte';
import { getFeedbackById } from '$lib/apis/evaluations';
import { toast } from 'svelte-sonner';
import Spinner from '$lib/components/common/Spinner.svelte';
export let show = false; export let show = false;
export let selectedFeedback = null; export let selectedFeedback = null;
export let onClose: () => void = () => {}; export let onClose: () => void = () => {};
let loaded = false;
let feedbackData = null;
const close = () => { const close = () => {
show = false; show = false;
onClose(); onClose();
}; };
const init = async () => {
loaded = false;
feedbackData = null;
if (selectedFeedback) {
feedbackData = await getFeedbackById(localStorage.token, selectedFeedback.id).catch((err) => {
return null;
});
console.log('Feedback Data:', selectedFeedback, feedbackData);
}
loaded = true;
};
$: if (show) {
init();
}
</script> </script>
<Modal size="sm" bind:show> <Modal size="sm" bind:show>
@ -22,49 +47,75 @@
{$i18n.t('Feedback Details')} {$i18n.t('Feedback Details')}
</div> </div>
<button class="self-center" on:click={close} aria-label="Close"> <button class="self-center" on:click={close} aria-label="Close">
<svg <XMark className={'size-5'} />
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button> </button>
</div> </div>
<div class="flex flex-col md:flex-row w-full px-5 pb-4 md:space-x-4 dark:text-gray-200"> <div class="flex flex-col md:flex-row w-full px-5 pb-4 md:space-x-4 dark:text-gray-200">
{#if loaded}
<div class="flex flex-col w-full"> <div class="flex flex-col w-full">
{#if feedbackData}
{@const messageId = feedbackData?.meta?.message_id}
{@const messages = feedbackData?.snapshot?.chat?.chat?.history.messages}
{#if messages[messages[messageId]?.parentId]}
<div class="flex flex-col w-full mb-2">
<div class="mb-1 text-xs text-gray-500">{$i18n.t('Prompt')}</div>
<div class="flex-1 text-xs whitespace-pre-line break-words">
<span>{messages[messages[messageId]?.parentId]?.content || '-'}</span>
</div>
</div>
{/if}
{#if messages[messageId]}
<div class="flex flex-col w-full mb-2">
<div class="mb-1 text-xs text-gray-500">{$i18n.t('Response')}</div>
<div
class="flex-1 text-xs whitespace-pre-line break-words max-h-32 overflow-y-auto"
>
<span>{messages[messageId]?.content || '-'}</span>
</div>
</div>
{/if}
{/if}
<div class="flex flex-col w-full mb-2"> <div class="flex flex-col w-full mb-2">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Rating')}</div> <div class=" mb-1 text-xs text-gray-500">{$i18n.t('Rating')}</div>
<div class="flex-1"> <div class="flex-1 text-xs">
<span>{selectedFeedback?.data?.details?.rating ?? '-'}</span> <span>{selectedFeedback?.data?.details?.rating ?? '-'}</span>
</div> </div>
</div> </div>
<div class="flex flex-col w-full mb-2"> <div class="flex flex-col w-full mb-2">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Reason')}</div> <div class=" mb-1 text-xs text-gray-500">{$i18n.t('Reason')}</div>
<div class="flex-1"> <div class="flex-1 text-xs">
<span>{selectedFeedback?.data?.reason || '-'}</span> <span>{selectedFeedback?.data?.reason || '-'}</span>
</div> </div>
</div> </div>
<div class="mb-2"> <div class="flex flex-col w-full mb-2">
<div class=" mb-1 text-xs text-gray-500">{$i18n.t('Comment')}</div>
<div class="flex-1 text-xs">
<span>{selectedFeedback?.data?.comment || '-'}</span>
</div>
</div>
{#if selectedFeedback?.data?.tags && selectedFeedback?.data?.tags.length} {#if selectedFeedback?.data?.tags && selectedFeedback?.data?.tags.length}
<div class="mb-2 -mx-1">
<div class="flex flex-wrap gap-1 mt-1"> <div class="flex flex-wrap gap-1 mt-1">
{#each selectedFeedback?.data?.tags as tag} {#each selectedFeedback?.data?.tags as tag}
<span class="px-2 py-0.5 rounded bg-gray-100 dark:bg-gray-800 text-xs">{tag}</span <span class="px-2 py-0.5 rounded-full bg-gray-100 dark:bg-gray-850 text-[9px]"
>{tag}</span
> >
{/each} {/each}
</div> </div>
{:else}
<span>-</span>
{/if}
</div> </div>
<div class="flex justify-end pt-3"> {/if}
<div class="flex justify-end pt-2">
<button <button
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full" class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="button" type="button"
@ -74,6 +125,11 @@
</button> </button>
</div> </div>
</div> </div>
{:else}
<div class="flex items-center justify-center w-full h-32">
<Spinner className={'size-5'} />
</div>
{/if}
</div> </div>
</div> </div>
{/if} {/if}

View File

@ -23,6 +23,8 @@
import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import ChevronUp from '$lib/components/icons/ChevronUp.svelte';
import ChevronDown from '$lib/components/icons/ChevronDown.svelte'; import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
import { WEBUI_BASE_URL } from '$lib/constants';
import { config } from '$lib/stores';
export let feedbacks = []; export let feedbacks = [];
@ -305,7 +307,7 @@
<tbody class=""> <tbody class="">
{#each paginatedFeedbacks as feedback (feedback.id)} {#each paginatedFeedbacks as feedback (feedback.id)}
<tr <tr
class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 transition" class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-850/50 transition"
on:click={() => openFeedbackModal(feedback)} on:click={() => openFeedbackModal(feedback)}
> >
<td class=" py-0.5 text-right font-semibold"> <td class=" py-0.5 text-right font-semibold">
@ -313,7 +315,7 @@
<Tooltip content={feedback?.user?.name}> <Tooltip content={feedback?.user?.name}>
<div class="shrink-0"> <div class="shrink-0">
<img <img
src={feedback?.user?.profile_image_url ?? '/user.png'} src={feedback?.user?.profile_image_url ?? `${WEBUI_BASE_URL}/user.png`}
alt={feedback?.user?.name} alt={feedback?.user?.name}
class="size-5 rounded-full object-cover shrink-0" class="size-5 rounded-full object-cover shrink-0"
/> />
@ -353,23 +355,26 @@
</div> </div>
</div> </div>
</td> </td>
{#if feedback?.data?.rating}
<td class="px-3 py-1 text-right font-medium text-gray-900 dark:text-white w-max"> <td class="px-3 py-1 text-right font-medium text-gray-900 dark:text-white w-max">
<div class=" flex justify-end"> <div class=" flex justify-end">
{#if feedback.data.rating.toString() === '1'} {#if feedback?.data?.rating.toString() === '1'}
<Badge type="info" content={$i18n.t('Won')} /> <Badge type="info" content={$i18n.t('Won')} />
{:else if feedback.data.rating.toString() === '0'} {:else if feedback?.data?.rating.toString() === '0'}
<Badge type="muted" content={$i18n.t('Draw')} /> <Badge type="muted" content={$i18n.t('Draw')} />
{:else if feedback.data.rating.toString() === '-1'} {:else if feedback?.data?.rating.toString() === '-1'}
<Badge type="error" content={$i18n.t('Lost')} /> <Badge type="error" content={$i18n.t('Lost')} />
{/if} {/if}
</div> </div>
</td> </td>
{/if}
<td class=" px-3 py-1 text-right font-medium"> <td class=" px-3 py-1 text-right font-medium">
{dayjs(feedback.updated_at * 1000).fromNow()} {dayjs(feedback.updated_at * 1000).fromNow()}
</td> </td>
<td class=" px-3 py-1 text-right font-semibold"> <td class=" px-3 py-1 text-right font-semibold" on:click={(e) => e.stopPropagation()}>
<FeedbackMenu <FeedbackMenu
on:delete={(e) => { on:delete={(e) => {
deleteFeedbackHandler(feedback.id); deleteFeedbackHandler(feedback.id);
@ -389,7 +394,7 @@
{/if} {/if}
</div> </div>
{#if feedbacks.length > 0} {#if feedbacks.length > 0 && $config?.features?.enable_community_sharing}
<div class=" flex flex-col justify-end w-full text-right gap-1"> <div class=" flex flex-col justify-end w-full text-right gap-1">
<div class="line-clamp-1 text-gray-500 text-xs"> <div class="line-clamp-1 text-gray-500 text-xs">
{$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')} {$i18n.t('Help us create the best community leaderboard by sharing your feedback history!')}

View File

@ -11,10 +11,11 @@
import Spinner from '$lib/components/common/Spinner.svelte'; import Spinner from '$lib/components/common/Spinner.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte';
import MagnifyingGlass from '$lib/components/icons/MagnifyingGlass.svelte'; import Search from '$lib/components/icons/Search.svelte';
import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import ChevronUp from '$lib/components/icons/ChevronUp.svelte';
import ChevronDown from '$lib/components/icons/ChevronDown.svelte'; import ChevronDown from '$lib/components/icons/ChevronDown.svelte';
import { WEBUI_BASE_URL } from '$lib/constants';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -77,7 +78,7 @@
let showLeaderboardModal = false; let showLeaderboardModal = false;
let selectedModel = null; let selectedModel = null;
const openFeedbackModal = (model) => { const openLeaderboardModelModal = (model) => {
showLeaderboardModal = true; showLeaderboardModal = true;
selectedModel = model; selectedModel = model;
}; };
@ -150,6 +151,8 @@
} }
feedbacks.forEach((feedback) => { feedbacks.forEach((feedback) => {
if (!feedback?.data?.model_id || !feedback?.data?.rating) return;
const modelA = feedback.data.model_id; const modelA = feedback.data.model_id;
const statsA = getOrDefaultStats(modelA); const statsA = getOrDefaultStats(modelA);
let outcome: number; let outcome: number;
@ -350,7 +353,7 @@
<Tooltip content={$i18n.t('Re-rank models by topic similarity')}> <Tooltip content={$i18n.t('Re-rank models by topic similarity')}>
<div class="flex flex-1"> <div class="flex flex-1">
<div class=" self-center ml-1 mr-3"> <div class=" self-center ml-1 mr-3">
<MagnifyingGlass className="size-3" /> <Search className="size-3" />
</div> </div>
<input <input
class=" w-full text-sm pr-4 py-1 rounded-r-xl outline-hidden bg-transparent" class=" w-full text-sm pr-4 py-1 rounded-r-xl outline-hidden bg-transparent"
@ -371,7 +374,7 @@
{#if loadingLeaderboard} {#if loadingLeaderboard}
<div class=" absolute top-0 bottom-0 left-0 right-0 flex"> <div class=" absolute top-0 bottom-0 left-0 right-0 flex">
<div class="m-auto"> <div class="m-auto">
<Spinner /> <Spinner className="size-5" />
</div> </div>
</div> </div>
{/if} {/if}
@ -504,8 +507,8 @@
<tbody class=""> <tbody class="">
{#each sortedModels as model, modelIdx (model.id)} {#each sortedModels as model, modelIdx (model.id)}
<tr <tr
class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs group cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-800 transition" class="bg-white dark:bg-gray-900 dark:border-gray-850 text-xs group cursor-pointer hover:bg-gray-50 dark:hover:bg-gray-850/50 transition"
on:click={() => openFeedbackModal(model)} on:click={() => openLeaderboardModelModal(model)}
> >
<td class="px-3 py-1.5 text-left font-medium text-gray-900 dark:text-white w-fit"> <td class="px-3 py-1.5 text-left font-medium text-gray-900 dark:text-white w-fit">
<div class=" line-clamp-1"> <div class=" line-clamp-1">
@ -516,7 +519,7 @@
<div class="flex items-center gap-2"> <div class="flex items-center gap-2">
<div class="shrink-0"> <div class="shrink-0">
<img <img
src={model?.info?.meta?.profile_image_url ?? '/favicon.png'} src={model?.info?.meta?.profile_image_url ?? `${WEBUI_BASE_URL}/favicon.png`}
alt={model.name} alt={model.name}
class="size-5 rounded-full object-cover shrink-0" class="size-5 rounded-full object-cover shrink-0"
/> />

View File

@ -6,6 +6,7 @@
export let feedbacks = []; export let feedbacks = [];
export let onClose: () => void = () => {}; export let onClose: () => void = () => {};
const i18n = getContext('i18n'); const i18n = getContext('i18n');
import XMark from '$lib/components/icons/XMark.svelte';
const close = () => { const close = () => {
show = false; show = false;
@ -37,25 +38,16 @@
{model.name} {model.name}
</div> </div>
<button class="self-center" on:click={close} aria-label="Close"> <button class="self-center" on:click={close} aria-label="Close">
<svg <XMark className={'size-5'} />
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button> </button>
</div> </div>
<div class="px-5 pb-4 dark:text-gray-200"> <div class="px-5 pb-4 dark:text-gray-200">
<div class="mb-2"> <div class="mb-2">
{#if topTags.length} {#if topTags.length}
<div class="flex flex-wrap gap-1 mt-1"> <div class="flex flex-wrap gap-1 mt-1 -mx-1">
{#each topTags as tagInfo} {#each topTags as tagInfo}
<span class="px-2 py-0.5 rounded bg-gray-100 dark:bg-gray-800 text-xs"> <span class="px-2 py-0.5 rounded-full bg-gray-100 dark:bg-gray-850 text-xs">
{tagInfo.tag} <span class="text-gray-500">({tagInfo.count})</span> {tagInfo.tag} <span class="text-gray-500 font-medium">{tagInfo.count}</span>
</span> </span>
{/each} {/each}
</div> </div>
@ -63,7 +55,7 @@
<span>-</span> <span>-</span>
{/if} {/if}
</div> </div>
<div class="flex justify-end pt-3"> <div class="flex justify-end pt-2">
<button <button
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full" class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
type="button" type="button"

View File

@ -49,6 +49,8 @@
let showConfirm = false; let showConfirm = false;
let query = ''; let query = '';
let selectedType = 'all';
let showManifestModal = false; let showManifestModal = false;
let showValvesModal = false; let showValvesModal = false;
let selectedFunction = null; let selectedFunction = null;
@ -59,9 +61,10 @@
$: filteredItems = $functions $: filteredItems = $functions
.filter( .filter(
(f) => (f) =>
query === '' || (selectedType !== 'all' ? f.type === selectedType : true) &&
(query === '' ||
f.name.toLowerCase().includes(query.toLowerCase()) || f.name.toLowerCase().includes(query.toLowerCase()) ||
f.id.toLowerCase().includes(query.toLowerCase()) f.id.toLowerCase().includes(query.toLowerCase()))
) )
.sort((a, b) => a.type.localeCompare(b.type) || a.name.localeCompare(b.name)); .sort((a, b) => a.type.localeCompare(b.type) || a.name.localeCompare(b.name));
@ -135,7 +138,9 @@
models.set( models.set(
await getModels( await getModels(
localStorage.token, localStorage.token,
$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null),
false,
true
) )
); );
} }
@ -161,7 +166,9 @@
models.set( models.set(
await getModels( await getModels(
localStorage.token, localStorage.token,
$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null),
false,
true
) )
); );
} }
@ -215,8 +222,8 @@
}} }}
/> />
<div class="flex flex-col gap-1 mt-1.5 mb-2"> <div class="flex flex-col mt-1.5 mb-0.5">
<div class="flex justify-between items-center"> <div class="flex justify-between items-center mb-1">
<div class="flex md:self-center text-xl items-center font-medium px-0.5"> <div class="flex md:self-center text-xl items-center font-medium px-0.5">
{$i18n.t('Functions')} {$i18n.t('Functions')}
<div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" /> <div class="flex self-center w-[1px] h-6 mx-2.5 bg-gray-50 dark:bg-gray-850" />
@ -266,12 +273,54 @@
</AddFunctionMenu> </AddFunctionMenu>
</div> </div>
</div> </div>
<div class=" flex w-full">
<div
class="flex gap-1 scrollbar-none overflow-x-auto w-fit text-center text-sm font-medium rounded-full bg-transparent"
>
<button
class="min-w-fit p-1.5 {selectedType === 'all'
? ''
: 'text-gray-300 dark:text-gray-600 hover:text-gray-700 dark:hover:text-white'} transition"
on:click={() => {
selectedType = 'all';
}}>{$i18n.t('All')}</button
>
<button
class="min-w-fit p-1.5 {selectedType === 'pipe'
? ''
: 'text-gray-300 dark:text-gray-600 hover:text-gray-700 dark:hover:text-white'} transition"
on:click={() => {
selectedType = 'pipe';
}}>{$i18n.t('Pipe')}</button
>
<button
class="min-w-fit p-1.5 {selectedType === 'filter'
? ''
: 'text-gray-300 dark:text-gray-600 hover:text-gray-700 dark:hover:text-white'} transition"
on:click={() => {
selectedType = 'filter';
}}>{$i18n.t('Filter')}</button
>
<button
class="min-w-fit p-1.5 {selectedType === 'action'
? ''
: 'text-gray-300 dark:text-gray-600 hover:text-gray-700 dark:hover:text-white'} transition"
on:click={() => {
selectedType = 'action';
}}>{$i18n.t('Action')}</button
>
</div>
</div>
</div> </div>
<div class="mb-5"> <div class="mb-5">
{#each filteredItems as func (func.id)} {#each filteredItems as func (func.id)}
<div <div
class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl" class=" flex space-x-4 cursor-pointer w-full px-2 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
> >
<a <a
class=" flex flex-1 space-x-3.5 cursor-pointer w-full" class=" flex flex-1 space-x-3.5 cursor-pointer w-full"
@ -413,7 +462,9 @@
await getModels( await getModels(
localStorage.token, localStorage.token,
$config?.features?.enable_direct_connections && $config?.features?.enable_direct_connections &&
($settings?.directConnections ?? null) ($settings?.directConnections ?? null),
false,
true
) )
); );
}} }}
@ -518,7 +569,7 @@
<a <a
class=" flex cursor-pointer items-center justify-between hover:bg-gray-50 dark:hover:bg-gray-850 w-full mb-2 px-3.5 py-1.5 rounded-xl transition" class=" flex cursor-pointer items-center justify-between hover:bg-gray-50 dark:hover:bg-gray-850 w-full mb-2 px-3.5 py-1.5 rounded-xl transition"
href="https://openwebui.com/#open-webui-community" href="https://openwebui.com/functions"
target="_blank" target="_blank"
> >
<div class=" self-center"> <div class=" self-center">
@ -559,7 +610,9 @@
models.set( models.set(
await getModels( await getModels(
localStorage.token, localStorage.token,
$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null),
false,
true
) )
); );
}} }}
@ -585,7 +638,9 @@
models.set( models.set(
await getModels( await getModels(
localStorage.token, localStorage.token,
$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null),
false,
true
) )
); );
}; };

View File

@ -12,6 +12,7 @@
} from '$lib/apis/audio'; } from '$lib/apis/audio';
import { config, settings } from '$lib/stores'; import { config, settings } from '$lib/stores';
import Spinner from '$lib/components/common/Spinner.svelte';
import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
import { TTS_RESPONSE_SPLIT } from '$lib/types'; import { TTS_RESPONSE_SPLIT } from '$lib/types';
@ -199,7 +200,9 @@
<input <input
class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden" class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden"
bind:value={STT_SUPPORTED_CONTENT_TYPES} bind:value={STT_SUPPORTED_CONTENT_TYPES}
placeholder={$i18n.t('e.g., audio/wav,audio/mpeg (leave blank for defaults)')} placeholder={$i18n.t(
'e.g., audio/wav,audio/mpeg,video/* (leave blank for defaults)'
)}
/> />
</div> </div>
</div> </div>
@ -373,33 +376,7 @@
> >
{#if STT_WHISPER_MODEL_LOADING} {#if STT_WHISPER_MODEL_LOADING}
<div class="self-center"> <div class="self-center">
<svg <Spinner />
class=" w-4 h-4"
viewBox="0 0 24 24"
fill="currentColor"
xmlns="http://www.w3.org/2000/svg"
>
<style>
.spinner_ajPY {
transform-origin: center;
animation: spinner_AtaB 0.75s infinite linear;
}
@keyframes spinner_AtaB {
100% {
transform: rotate(360deg);
}
}
</style>
<path
d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
opacity=".25"
/>
<path
d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
class="spinner_ajPY"
/>
</svg>
</div> </div>
{:else} {:else}
<svg <svg

View File

@ -7,7 +7,7 @@
import { getOllamaConfig, updateOllamaConfig } from '$lib/apis/ollama'; import { getOllamaConfig, updateOllamaConfig } from '$lib/apis/ollama';
import { getOpenAIConfig, updateOpenAIConfig, getOpenAIModels } from '$lib/apis/openai'; import { getOpenAIConfig, updateOpenAIConfig, getOpenAIModels } from '$lib/apis/openai';
import { getModels as _getModels } from '$lib/apis'; import { getModels as _getModels } from '$lib/apis';
import { getDirectConnectionsConfig, setDirectConnectionsConfig } from '$lib/apis/configs'; import { getConnectionsConfig, setConnectionsConfig } from '$lib/apis/configs';
import { config, models, settings, user } from '$lib/stores'; import { config, models, settings, user } from '$lib/stores';
@ -25,7 +25,9 @@
const getModels = async () => { const getModels = async () => {
const models = await _getModels( const models = await _getModels(
localStorage.token, localStorage.token,
$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null) $config?.features?.enable_direct_connections && ($settings?.directConnections ?? null),
false,
true
); );
return models; return models;
}; };
@ -41,7 +43,7 @@
let ENABLE_OPENAI_API: null | boolean = null; let ENABLE_OPENAI_API: null | boolean = null;
let ENABLE_OLLAMA_API: null | boolean = null; let ENABLE_OLLAMA_API: null | boolean = null;
let directConnectionsConfig = null; let connectionsConfig = null;
let pipelineUrls = {}; let pipelineUrls = {};
let showAddOpenAIConnectionModal = false; let showAddOpenAIConnectionModal = false;
@ -104,15 +106,13 @@
} }
}; };
const updateDirectConnectionsHandler = async () => { const updateConnectionsHandler = async () => {
const res = await setDirectConnectionsConfig(localStorage.token, directConnectionsConfig).catch( const res = await setConnectionsConfig(localStorage.token, connectionsConfig).catch((error) => {
(error) => {
toast.error(`${error}`); toast.error(`${error}`);
} });
);
if (res) { if (res) {
toast.success($i18n.t('Direct Connections settings updated')); toast.success($i18n.t('Connections settings updated'));
await models.set(await getModels()); await models.set(await getModels());
} }
}; };
@ -148,7 +148,7 @@
openaiConfig = await getOpenAIConfig(localStorage.token); openaiConfig = await getOpenAIConfig(localStorage.token);
})(), })(),
(async () => { (async () => {
directConnectionsConfig = await getDirectConnectionsConfig(localStorage.token); connectionsConfig = await getConnectionsConfig(localStorage.token);
})() })()
]); ]);
@ -196,7 +196,6 @@
const submitHandler = async () => { const submitHandler = async () => {
updateOpenAIHandler(); updateOpenAIHandler();
updateOllamaHandler(); updateOllamaHandler();
updateDirectConnectionsHandler();
dispatch('save'); dispatch('save');
}; };
@ -215,9 +214,14 @@
<form class="flex flex-col h-full justify-between text-sm" on:submit|preventDefault={submitHandler}> <form class="flex flex-col h-full justify-between text-sm" on:submit|preventDefault={submitHandler}>
<div class=" overflow-y-scroll scrollbar-hidden h-full"> <div class=" overflow-y-scroll scrollbar-hidden h-full">
{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null && directConnectionsConfig !== null} {#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null && connectionsConfig !== null}
<div class="mb-3.5">
<div class=" mb-2.5 text-base font-medium">{$i18n.t('General')}</div>
<hr class=" border-gray-100 dark:border-gray-850 my-2" />
<div class="my-2"> <div class="my-2">
<div class="mt-2 space-y-2 pr-1.5"> <div class="mt-2 space-y-2">
<div class="flex justify-between items-center text-sm"> <div class="flex justify-between items-center text-sm">
<div class=" font-medium">{$i18n.t('OpenAI API')}</div> <div class=" font-medium">{$i18n.t('OpenAI API')}</div>
@ -234,11 +238,9 @@
</div> </div>
{#if ENABLE_OPENAI_API} {#if ENABLE_OPENAI_API}
<hr class=" border-gray-100 dark:border-gray-850" />
<div class=""> <div class="">
<div class="flex justify-between items-center"> <div class="flex justify-between items-center">
<div class="font-medium">{$i18n.t('Manage OpenAI API Connections')}</div> <div class="font-medium text-xs">{$i18n.t('Manage OpenAI API Connections')}</div>
<Tooltip content={$i18n.t(`Add Connection`)}> <Tooltip content={$i18n.t(`Add Connection`)}>
<button <button
@ -271,7 +273,8 @@
let newConfig = {}; let newConfig = {};
OPENAI_API_BASE_URLS.forEach((url, newIdx) => { OPENAI_API_BASE_URLS.forEach((url, newIdx) => {
newConfig[newIdx] = OPENAI_API_CONFIGS[newIdx < idx ? newIdx : newIdx + 1]; newConfig[newIdx] =
OPENAI_API_CONFIGS[newIdx < idx ? newIdx : newIdx + 1];
}); });
OPENAI_API_CONFIGS = newConfig; OPENAI_API_CONFIGS = newConfig;
updateOpenAIHandler(); updateOpenAIHandler();
@ -284,9 +287,7 @@
</div> </div>
</div> </div>
<hr class=" border-gray-100 dark:border-gray-850" /> <div class=" my-2">
<div class="pr-1.5 my-2">
<div class="flex justify-between items-center text-sm mb-2"> <div class="flex justify-between items-center text-sm mb-2">
<div class=" font-medium">{$i18n.t('Ollama API')}</div> <div class=" font-medium">{$i18n.t('Ollama API')}</div>
@ -301,11 +302,9 @@
</div> </div>
{#if ENABLE_OLLAMA_API} {#if ENABLE_OLLAMA_API}
<hr class=" border-gray-100 dark:border-gray-850 my-2" />
<div class=""> <div class="">
<div class="flex justify-between items-center"> <div class="flex justify-between items-center">
<div class="font-medium">{$i18n.t('Manage Ollama API Connections')}</div> <div class="font-medium text-xs">{$i18n.t('Manage Ollama API Connections')}</div>
<Tooltip content={$i18n.t(`Add Connection`)}> <Tooltip content={$i18n.t(`Add Connection`)}>
<button <button
@ -335,7 +334,8 @@
let newConfig = {}; let newConfig = {};
OLLAMA_BASE_URLS.forEach((url, newIdx) => { OLLAMA_BASE_URLS.forEach((url, newIdx) => {
newConfig[newIdx] = OLLAMA_API_CONFIGS[newIdx < idx ? newIdx : newIdx + 1]; newConfig[newIdx] =
OLLAMA_API_CONFIGS[newIdx < idx ? newIdx : newIdx + 1];
}); });
OLLAMA_API_CONFIGS = newConfig; OLLAMA_API_CONFIGS = newConfig;
}} }}
@ -358,31 +358,53 @@
{/if} {/if}
</div> </div>
<hr class=" border-gray-100 dark:border-gray-850" /> <div class="my-2">
<div class="pr-1.5 my-2">
<div class="flex justify-between items-center text-sm"> <div class="flex justify-between items-center text-sm">
<div class=" font-medium">{$i18n.t('Direct Connections')}</div> <div class=" font-medium">{$i18n.t('Direct Connections')}</div>
<div class="flex items-center"> <div class="flex items-center">
<div class=""> <div class="">
<Switch <Switch
bind:state={directConnectionsConfig.ENABLE_DIRECT_CONNECTIONS} bind:state={connectionsConfig.ENABLE_DIRECT_CONNECTIONS}
on:change={async () => { on:change={async () => {
updateDirectConnectionsHandler(); updateConnectionsHandler();
}} }}
/> />
</div> </div>
</div> </div>
</div> </div>
<div class="mt-1.5"> <div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
<div class="text-xs text-gray-500">
{$i18n.t( {$i18n.t(
'Direct Connections allow users to connect to their own OpenAI compatible API endpoints.' 'Direct Connections allow users to connect to their own OpenAI compatible API endpoints.'
)} )}
</div> </div>
</div> </div>
<hr class=" border-gray-100 dark:border-gray-850 my-2" />
<div class="my-2">
<div class="flex justify-between items-center text-sm">
<div class=" text-xs font-medium">{$i18n.t('Cache Base Model List')}</div>
<div class="flex items-center">
<div class="">
<Switch
bind:state={connectionsConfig.ENABLE_BASE_MODELS_CACHE}
on:change={async () => {
updateConnectionsHandler();
}}
/>
</div>
</div>
</div>
<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
{$i18n.t(
'Base Model List Cache speeds up access by fetching base models only at startup or on settings save—faster, but may not show recent base model changes.'
)}
</div>
</div>
</div> </div>
{:else} {:else}
<div class="flex h-full justify-center"> <div class="flex h-full justify-center">

View File

@ -5,6 +5,7 @@
import Modal from '$lib/components/common/Modal.svelte'; import Modal from '$lib/components/common/Modal.svelte';
import ManageOllama from '../Models/Manage/ManageOllama.svelte'; import ManageOllama from '../Models/Manage/ManageOllama.svelte';
import XMark from '$lib/components/icons/XMark.svelte';
export let show = false; export let show = false;
export let urlIdx: number | null = null; export let urlIdx: number | null = null;
@ -26,16 +27,7 @@
show = false; show = false;
}} }}
> >
<svg <XMark className={'size-5'} />
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button> </button>
</div> </div>

Some files were not shown because too many files have changed in this diff Show More