feat: merge with main

This commit is contained in:
Fabio Polito 2025-03-05 22:04:34 +00:00
commit 9aa407dbd2
372 changed files with 26027 additions and 10944 deletions

View File

@ -1,80 +0,0 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
# Bug Report
## Important Notes
- **Before submitting a bug report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If youre unsure, start a discussion post first. This will help us efficiently focus on improving the project.
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. Were here to help if youre open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, its not that the issue doesnt exist; we need your help!
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
---
## Installation Method
[Describe the method you used to install the project, e.g., git clone, Docker, pip, etc.]
## Environment
- **Open WebUI Version:** [e.g., v0.3.11]
- **Ollama (if applicable):** [e.g., v0.2.0, v0.1.32-rc1]
- **Operating System:** [e.g., Windows 10, macOS Big Sur, Ubuntu 20.04]
- **Browser (if applicable):** [e.g., Chrome 100.0, Firefox 98.0]
**Confirmation:**
- [ ] I have read and followed all the instructions provided in the README.md.
- [ ] I am on the latest version of both Open WebUI and Ollama.
- [ ] I have included the browser console logs.
- [ ] I have included the Docker container logs.
- [ ] I have provided the exact steps to reproduce the bug in the "Steps to Reproduce" section below.
## Expected Behavior:
[Describe what you expected to happen.]
## Actual Behavior:
[Describe what actually happened.]
## Description
**Bug Summary:**
[Provide a brief but clear summary of the bug]
## Reproduction Details
**Steps to Reproduce:**
[Outline the steps to reproduce the bug. Be as detailed as possible.]
## Logs and Screenshots
**Browser Console Logs:**
[Include relevant browser console logs, if applicable]
**Docker Container Logs:**
[Include relevant Docker container logs, if applicable]
**Screenshots/Screen Recordings (if applicable):**
[Attach any relevant screenshots to help illustrate the issue]
## Additional Information
[Include any additional details that may help in understanding and reproducing the issue. This could include specific configurations, error messages, or anything else relevant to the bug.]
## Note
If the bug report is incomplete or does not follow the provided instructions, it may not be addressed. Please ensure that you have followed the steps outlined in the README.md and troubleshooting.md documents, and provide all necessary information for us to reproduce and address the issue. Thank you!

144
.github/ISSUE_TEMPLATE/bug_report.yaml vendored Normal file
View File

@ -0,0 +1,144 @@
name: Bug Report
description: Create a detailed bug report to help us improve Open WebUI.
title: 'issue: '
labels: ['bug', 'triage']
assignees: []
body:
- type: markdown
attributes:
value: |
# Bug Report
## Important Notes
- **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project.
- **Respectful collaboration**: Open WebUI is a volunteer-driven project with a single maintainer and contributors who also have full-time jobs. Please be constructive and respectful in your communication.
- **Contributing**: If you encounter an issue, consider submitting a pull request or forking the project. We prioritize preventing contributor burnout to maintain Open WebUI's quality.
- **Bug Reproducibility**: If a bug cannot be reproduced using a `:main` or `:dev` Docker setup or with `pip install` on Python 3.11, community assistance may be required. In such cases, we will move it to the "[Issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section. Your help is appreciated!
- type: checkboxes
id: issue-check
attributes:
label: Check Existing Issues
description: Confirm that youve checked for existing reports before submitting a new one.
options:
- label: I have searched the existing issues and discussions.
required: true
- type: dropdown
id: installation-method
attributes:
label: Installation Method
description: How did you install Open WebUI?
options:
- Git Clone
- Pip Install
- Docker
- Other
validations:
required: true
- type: input
id: open-webui-version
attributes:
label: Open WebUI Version
description: Specify the version (e.g., v0.3.11)
validations:
required: true
- type: input
id: ollama-version
attributes:
label: Ollama Version (if applicable)
description: Specify the version (e.g., v0.2.0, or v0.1.32-rc1)
validations:
required: false
- type: input
id: operating-system
attributes:
label: Operating System
description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04)
validations:
required: true
- type: input
id: browser
attributes:
label: Browser (if applicable)
description: Specify the browser/version (e.g., Chrome 100.0, Firefox 98.0)
validations:
required: false
- type: checkboxes
id: confirmation
attributes:
label: Confirmation
description: Ensure the following prerequisites have been met.
options:
- label: I have read and followed all instructions in `README.md`.
required: true
- label: I am using the latest version of **both** Open WebUI and Ollama.
required: true
- label: I have checked the browser console logs.
required: true
- label: I have checked the Docker container logs.
required: true
- label: I have listed steps to reproduce the bug in detail.
required: true
- type: textarea
id: expected-behavior
attributes:
label: Expected Behavior
description: Describe what should have happened.
validations:
required: true
- type: textarea
id: actual-behavior
attributes:
label: Actual Behavior
description: Describe what actually happened.
validations:
required: true
- type: textarea
id: reproduction-steps
attributes:
label: Steps to Reproduce
description: Provide step-by-step instructions to reproduce the issue.
placeholder: |
1. Go to '...'
2. Click on '...'
3. Scroll down to '...'
4. See the error message '...'
validations:
required: true
- type: textarea
id: logs-screenshots
attributes:
label: Logs & Screenshots
description: Include relevant logs, errors, or screenshots to help diagnose the issue.
placeholder: 'Attach logs from the browser console, Docker logs, or error messages.'
validations:
required: true
- type: textarea
id: additional-info
attributes:
label: Additional Information
description: Provide any extra details that may assist in understanding the issue.
validations:
required: false
- type: markdown
attributes:
value: |
## Note
If the bug report is incomplete or does not follow instructions, it may not be addressed. Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue.
Thank you for contributing to Open WebUI!

View File

@ -1,35 +0,0 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
# Feature Request
## Important Notes
- **Before submitting a report**: Please check the Issues or Discussions section to see if a similar issue or feature request has already been posted. It's likely we're already tracking it! If youre unsure, start a discussion post first. This will help us efficiently focus on improving the project.
- **Collaborate respectfully**: We value a constructive attitude, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. Were here to help if youre open to learning and communicating positively. Remember, Open WebUI is a volunteer-driven project managed by a single maintainer and supported by contributors who also have full-time jobs. We appreciate your time and ask that you respect ours.
- **Contributing**: If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
- **Bug reproducibility**: If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a pip install with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "issues" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, its not that the issue doesnt exist; we need your help!
Note: Please remove the notes above when submitting your post. Thank you for your understanding and support!
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

View File

@ -0,0 +1,64 @@
name: Feature Request
description: Suggest an idea for this project
title: 'feat: '
labels: ['triage']
body:
- type: markdown
attributes:
value: |
## Important Notes
### Before submitting
Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted.
It's likely we're already tracking it! If youre unsure, start a discussion post first.
This will help us efficiently focus on improving the project.
### Collaborate respectfully
We value a **constructive attitude**, so please be mindful of your communication. If negativity is part of your approach, our capacity to engage may be limited. We're here to help if you're **open to learning** and **communicating positively**.
Remember:
- Open WebUI is a **volunteer-driven project**
- It's managed by a **single maintainer**
- It's supported by contributors who also have **full-time jobs**
We appreciate your time and ask that you **respect ours**.
### Contributing
If you encounter an issue, we highly encourage you to submit a pull request or fork the project. We actively work to prevent contributor burnout to maintain the quality and continuity of Open WebUI.
### Bug reproducibility
If a bug cannot be reproduced with a `:main` or `:dev` Docker setup, or a `pip install` with Python 3.11, it may require additional help from the community. In such cases, we will move it to the "[issues](https://github.com/open-webui/open-webui/discussions/categories/issues)" Discussions section due to our limited resources. We encourage the community to assist with these issues. Remember, its not that the issue doesnt exist; we need your help!
- type: checkboxes
id: existing-issue
attributes:
label: Check Existing Issues
description: Please confirm that you've checked for existing similar requests
options:
- label: I have searched the existing issues and discussions.
required: true
- type: textarea
id: problem-description
attributes:
label: Problem Description
description: Is your feature request related to a problem? Please provide a clear and concise description of what the problem is.
placeholder: "Ex. I'm always frustrated when..."
validations:
required: true
- type: textarea
id: solution-description
attributes:
label: Desired Solution you'd like
description: Clearly describe what you want to happen.
validations:
required: true
- type: textarea
id: alternatives-considered
attributes:
label: Alternatives Considered
description: A clear and concise description of any alternative solutions or features you've considered.
- type: textarea
id: additional-context
attributes:
label: Additional Context
description: Add any other context or screenshots about the feature request here.

View File

@ -14,7 +14,7 @@ env:
jobs:
build-main-image:
runs-on: ubuntu-latest
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
permissions:
contents: read
packages: write
@ -111,7 +111,7 @@ jobs:
retention-days: 1
build-cuda-image:
runs-on: ubuntu-latest
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
permissions:
contents: read
packages: write
@ -211,7 +211,7 @@ jobs:
retention-days: 1
build-ollama-image:
runs-on: ubuntu-latest
runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
permissions:
contents: read
packages: write

View File

@ -19,7 +19,7 @@ jobs:
uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 18
node-version: 22
- uses: actions/setup-python@v5
with:
python-version: 3.11

View File

@ -5,6 +5,185 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.5.19] - 2024-03-04
### Added
- **📊 Logit Bias Parameter Support**: Fine-tune conversation dynamics by adjusting the Logit Bias parameter directly in chat settings, giving you more control over model responses.
- **⌨️ Customizable Enter Behavior**: You can now configure Enter to send messages only when combined with Ctrl (Ctrl+Enter) via Settings > Interface, preventing accidental message sends.
- **📝 Collapsible Code Blocks**: Easily collapse long code blocks to declutter your chat, making it easier to focus on important details.
- **🏷️ Tag Selector in Model Selector**: Quickly find and categorize models with the new tag filtering system in the Model Selector, streamlining model discovery.
- **📈 Experimental Elasticsearch Vector DB Support**: Now supports Elasticsearch as a vector database, offering more flexibility for data retrieval in Retrieval-Augmented Generation (RAG) workflows.
- **⚙️ General Reliability Enhancements**: Various stability improvements across the WebUI, ensuring a smoother, more consistent experience.
- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
### Fixed
- **🔄 "Stream" Hook Activation**: Fixed an issue where the "Stream" hook only worked when globally enabled, ensuring reliable real-time filtering.
- **📧 LDAP Email Case Sensitivity**: Resolved an issue where LDAP login failed due to email case sensitivity mismatches, improving authentication reliability.
- **💬 WebSocket Chat Event Registration**: Fixed a bug preventing chat event listeners from being registered upon sign-in, ensuring real-time updates work properly.
## [0.5.18] - 2025-02-27
### Fixed
- **🌐 Open WebUI Now Works Over LAN in Insecure Context**: Resolved an issue preventing Open WebUI from functioning when accessed over a local network in an insecure context, ensuring seamless connectivity.
- **🔄 UI Now Reflects Deleted Connections Instantly**: Fixed an issue where deleting a connection did not update the UI in real time, ensuring accurate system state visibility.
- **🛠️ Models Now Display Correctly with ENABLE_FORWARD_USER_INFO_HEADERS**: Addressed a bug where models were not visible when ENABLE_FORWARD_USER_INFO_HEADERS was set, restoring proper model listing.
## [0.5.17] - 2025-02-27
### Added
- **🚀 Instant Document Upload with Bypass Embedding & Retrieval**: Admins can now enable "Bypass Embedding & Retrieval" in Admin Settings > Documents, significantly speeding up document uploads and ensuring full document context is retained without chunking.
- **🔎 "Stream" Hook for Real-Time Filtering**: The new "stream" hook allows dynamic real-time message filtering. Learn more in our documentation (https://docs.openwebui.com/features/plugin/functions/filter).
- **☁️ OneDrive Integration**: Early support for OneDrive storage integration has been introduced, expanding file import options.
- **📈 Enhanced Logging with Loguru**: Backend logging has been improved with Loguru, making debugging and issue tracking far more efficient.
- **⚙️ General Stability Enhancements**: Backend and frontend refactoring improves performance, ensuring a smoother and more reliable user experience.
- **🌍 Updated Translations**: Refined multilingual support for better localization and accuracy across various languages.
### Fixed
- **🔄 Reliable Model Imports from the Community Platform**: Resolved import failures, allowing seamless integration of community-shared models without errors.
- **📊 OpenAI Usage Statistics Restored**: Fixed an issue where OpenAI usage metrics were not displaying correctly, ensuring accurate tracking of usage data.
- **🗂️ Deduplication for Retrieved Documents**: Documents retrieved during searches are now intelligently deduplicated, meaning no more redundant results—helping to keep information concise and relevant.
### Changed
- **📝 "Full Context Mode" Renamed for Clarity**: The "Full Context Mode" toggle in Web Search settings is now labeled "Bypass Embedding & Retrieval" for consistency across the UI.
## [0.5.16] - 2025-02-20
### Fixed
- **🔍 Web Search Retrieval Restored**: Resolved a critical issue that broke web search retrieval by reverting deduplication changes, ensuring complete and accurate search results once again.
## [0.5.15] - 2025-02-20
### Added
- **📄 Full Context Mode for Local Document Search (RAG)**: Toggle full context mode from Admin Settings > Documents to inject entire document content into context, improving accuracy for models with large context windows—ideal for deep context understanding.
- **🌍 Smarter Web Search with Agentic Workflows**: Web searches now intelligently gather and refine multiple relevant terms, similar to RAG handling, delivering significantly better search results for more accurate information retrieval.
- **🔎 Experimental Playwright Support for Web Loader**: Web content retrieval is taken to the next level with Playwright-powered scraping for enhanced accuracy in extracted web data.
- **☁️ Experimental Azure Storage Provider**: Early-stage support for Azure Storage allows more cloud storage flexibility directly within Open WebUI.
- **📊 Improved Jupyter Code Execution with Plots**: Interactive coding now properly displays inline plots, making data visualization more seamless inside chat interactions.
- **⏳ Adjustable Execution Timeout for Jupyter Interpreter**: Customize execution timeout (default: 60s) for Jupyter-based code execution, allowing longer or more constrained execution based on your needs.
- **▶️ "Running..." Indicator for Jupyter Code Execution**: A visual indicator now appears while code execution is in progress, providing real-time status updates on ongoing computations.
- **⚙️ General Backend & Frontend Stability Enhancements**: Extensive refactoring improves reliability, performance, and overall user experience for a more seamless Open WebUI.
- **🌍 Translation Updates**: Various international translation refinements ensure better localization and a more natural user interface experience.
### Fixed
- **📱 Mobile Hover Issue Resolved**: Users can now edit responses smoothly on mobile without interference, fixing a longstanding hover issue.
- **🔄 Temporary Chat Message Duplication Fixed**: Eliminated buggy behavior where messages were being unnecessarily repeated in temporary chat mode, ensuring a smooth and consistent conversation flow.
## [0.5.14] - 2025-02-17
### Fixed
- **🔧 Critical Import Error Resolved**: Fixed a circular import issue preventing 'override_static' from being correctly imported in 'open_webui.config', ensuring smooth system initialization and stability.
## [0.5.13] - 2025-02-17
### Added
- **🌐 Full Context Mode for Web Search**: Enable highly accurate web searches by utilizing full context mode—ideal for models with large context windows, ensuring more precise and insightful results.
- **⚡ Optimized Asynchronous Web Search**: Web searches now load significantly faster with optimized async support, providing users with quicker, more efficient information retrieval.
- **🔄 Auto Text Direction for RTL Languages**: Automatic text alignment based on language input, ensuring seamless conversation flow for Arabic, Hebrew, and other right-to-left scripts.
- **🚀 Jupyter Notebook Support for Code Execution**: The "Run" button in code blocks can now use Jupyter for execution, offering a powerful, dynamic coding experience directly in the chat.
- **🗑️ Message Delete Confirmation Dialog**: Prevent accidental deletions with a new confirmation prompt before removing messages, adding an additional layer of security to your chat history.
- **📥 Download Button for SVG Diagrams**: SVG diagrams generated within chat can now be downloaded instantly, making it easier to save and share complex visual data.
- **✨ General UI/UX Improvements and Backend Stability**: A refined interface with smoother interactions, improved layouts, and backend stability enhancements for a more reliable, polished experience.
### Fixed
- **🛠️ Temporary Chat Message Continue Button Fixed**: The "Continue Response" button for temporary chats now works as expected, ensuring an uninterrupted conversation flow.
### Changed
- **📝 Prompt Variable Update**: Deprecated square bracket '[]' indicators for prompt variables; now requires double curly brackets '{{}}' for consistency and clarity.
- **🔧 Stability Enhancements**: Error handling improved in chat history, ensuring smoother operations when reviewing previous messages.
## [0.5.12] - 2025-02-13
### Added
- **🛠️ Multiple Tool Calls Support for Native Function Mode**: Functions now can call multiple tools within a single response, unlocking better automation and workflow flexibility when using native function calling.
### Fixed
- **📝 Playground Text Completion Restored**: Addressed an issue where text completion in the Playground was not functioning.
- **🔗 Direct Connections Now Work for Regular Users**: Fixed a bug where users with the 'user' role couldn't establish direct API connections, enabling seamless model usage for all user tiers.
- **⚡ Landing Page Input No Longer Lags with Long Text**: Improved input responsiveness on the landing page, ensuring fast and smooth typing experiences even when entering long messages.
- **🔧 Parameter in Functions Fixed**: Fixed an issue where the reserved parameters wasnt recognized within functions, restoring full functionality for advanced task-based automation.
## [0.5.11] - 2025-02-13
### Added
- **🎤 Kokoro-JS TTS Support**: A new on-device, high-quality text-to-speech engine has been integrated, vastly improving voice generation quality—everything runs directly in your browser.
- **🐍 Jupyter Notebook Support in Code Interpreter**: Now, you can configure Code Interpreter to run Python code not only via Pyodide but also through Jupyter, offering a more robust coding environment for AI-driven computations and analysis.
- **🔗 Direct API Connections for Private & Local Inference**: You can now connect Open WebUI to your private or localhost API inference endpoints. CORS must be enabled, but this unlocks direct, on-device AI infrastructure support.
- **🔍 Advanced Domain Filtering for Web Search**: You can now specify which domains should be included or excluded from web searches, refining results for more relevant information retrieval.
- **🚀 Improved Image Generation Metadata Handling**: Generated images now retain metadata for better organization and future retrieval.
- **📂 S3 Key Prefix Support**: Fine-grained control over S3 storage file structuring with configurable key prefixes.
- **📸 Support for Image-Only Messages**: Send messages containing only images, facilitating more visual-centric interactions.
- **🌍 Updated Translations**: German, Spanish, Traditional Chinese, and Catalan translations updated for better multilingual support.
### Fixed
- **🔧 OAuth Debug Logs & Username Claim Fixes**: Debug logs have been added for OAuth role and group management, with fixes ensuring proper OAuth username retrieval and claim handling.
- **📌 Citations Formatting & Toggle Fixes**: Inline citation toggles now function correctly, and citations with more than three sources are now fully visible when expanded.
- **📸 ComfyUI Maximum Seed Value Constraint Fixed**: The maximum allowed seed value for ComfyUI has been corrected, preventing unintended behavior.
- **🔑 Connection Settings Stability**: Addressed connection settings issues that were causing instability when saving configurations.
- **📂 GGUF Model Upload Stability**: Fixed upload inconsistencies for GGUF models, ensuring reliable local model handling.
- **🔧 Web Search Configuration Bug**: Fixed issues where web search filters and settings weren't correctly applied.
- **💾 User Settings Persistence Fix**: Ensured user-specific settings are correctly saved and applied across sessions.
- **🔄 OpenID Username Retrieval Enhancement**: Usernames are now correctly picked up and assigned for OpenID Connect (OIDC) logins.
## [0.5.10] - 2025-02-05
### Fixed
- **⚙️ System Prompts Now Properly Templated via API**: Resolved an issue where system prompts were not being correctly processed when used through the API, ensuring template variables now function as expected.
- **📝 '<thinking>' Tag Display Issue Fixed**: Fixed a bug where the 'thinking' tag was disrupting content rendering, ensuring clean and accurate text display.
- **💻 Code Interpreter Stability with Custom Functions**: Addressed failures when using the Code Interpreter with certain custom functions like Anthropic, ensuring smoother execution and better compatibility.
## [0.5.9] - 2025-02-05
### Fixed
- **💡 "Think" Tag Display Issue**: Resolved a bug where the "Think" tag was not functioning correctly, ensuring proper visualization of the model's reasoning process before delivering responses.
## [0.5.8] - 2025-02-05
### Added
- **🖥️ Code Interpreter**: Models can now execute code in real time to refine their answers dynamically, running securely within a sandboxed browser environment using Pyodide. Perfect for calculations, data analysis, and AI-assisted coding tasks!
- **💬 Redesigned Chat Input UI**: Enjoy a sleeker and more intuitive message input with improved feature selection, making it easier than ever to toggle tools, enable search, and interact with AI seamlessly.
- **🛠️ Native Tool Calling Support (Experimental)**: Supported models can now call tools natively, reducing query latency and improving contextual responses. More enhancements coming soon!
- **🔗 Exa Search Engine Integration**: A new search provider has been added, allowing users to retrieve up-to-date and relevant information without leaving the chat interface.
- **🌍 Localized Dates & Times**: Date and time formats now match your system locale, ensuring a more natural, region-specific experience.
- **📎 User Headers for External Embedding APIs**: API calls to external embedding services now include user-related headers.
- **🌍 "Always On" Web Search Toggle**: A new option under Settings > Interface allows users to enable Web Search by default—transform Open WebUI into your go-to search engine, ensuring AI-powered results with every query.
- **🚀 General Performance & Stability**: Significant improvements across the platform for a faster, more reliable experience.
- **🖼️ UI/UX Enhancements**: Numerous design refinements improving readability, responsiveness, and accessibility.
- **🌍 Improved Translations**: Chinese, Korean, French, Ukrainian and Serbian translations have been updated with refined terminologies for better clarity.
### Fixed
- **🔄 OAuth Name Field Fallback**: Resolves OAuth login failures by using the email field as a fallback when a name is missing.
- **🔑 Google Drive Credentials Restriction**: Ensures only authenticated users can access Google Drive credentials for enhanced security.
- **🌐 DuckDuckGo Search Rate Limit Handling**: Fixes issues where users would encounter 202 errors due to rate limits when using DuckDuckGo for web search.
- **📁 File Upload Permission Indicator**: Users are now notified when they lack permission to upload files, improving clarity on system restrictions.
- **🔧 Max Tokens Issue**: Fixes cases where 'max_tokens' were not applied correctly, ensuring proper model behavior.
- **🔍 Validation for RAG Web Search URLs**: Filters out invalid or unsupported URLs when using web-based retrieval augmentation.
- **🖋️ Title Generation Bug**: Fixes inconsistencies in title generation, ensuring proper chat organization.
### Removed
- **⚡ Deprecated Non-Web Worker Pyodide Execution**: Moves entirely to browser sandboxing for better performance and security.
## [0.5.7] - 2025-01-23
### Added

View File

@ -27,10 +27,15 @@ git push origin main
**Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**.
For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
![Open WebUI Demo](./demo.gif)
> [!TIP]
> **Looking for an [Enterprise Plan](https://docs.openwebui.com/enterprise)?** **[Speak with Our Sales Team Today!](mailto:sales@openwebui.com)**
>
> Get **enhanced capabilities**, including **custom theming and branding**, **Service Level Agreement (SLA) support**, **Long-Term Support (LTS) versions**, and **more!**
For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).
## Key Features of Open WebUI ⭐
- 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.
@ -188,7 +193,7 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/wa
In the last part of the command, replace `open-webui` with your container name if it is different.
Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/tutorials/migration/).
Check our Updating Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/updating).
### Using the Dev Branch 🌙

View File

@ -2,12 +2,13 @@ import json
import logging
import os
import shutil
import base64
from datetime import datetime
from pathlib import Path
from typing import Generic, Optional, TypeVar
from urllib.parse import urlparse
import chromadb
import requests
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func
@ -42,7 +43,7 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
# Function to run the alembic migrations
def run_migrations():
print("Running migrations")
log.info("Running migrations")
try:
from alembic import command
from alembic.config import Config
@ -55,7 +56,7 @@ def run_migrations():
command.upgrade(alembic_cfg, "head")
except Exception as e:
print(f"Error: {e}")
log.exception(f"Error running migrations: {e}")
run_migrations()
@ -586,6 +587,14 @@ load_oauth_providers()
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()
for file_path in (FRONTEND_BUILD_DIR / "static").glob("**/*"):
if file_path.is_file():
target_path = STATIC_DIR / file_path.relative_to(
(FRONTEND_BUILD_DIR / "static")
)
target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(file_path, target_path)
frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
if frontend_favicon.exists():
@ -593,8 +602,6 @@ if frontend_favicon.exists():
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
@ -603,12 +610,18 @@ if frontend_splash.exists():
shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend splash not found at {frontend_splash}")
frontend_loader = FRONTEND_BUILD_DIR / "static" / "loader.js"
if frontend_loader.exists():
try:
shutil.copyfile(frontend_loader, STATIC_DIR / "loader.js")
except Exception as e:
logging.error(f"An error occurred: {e}")
####################################
# CUSTOM_NAME
# CUSTOM_NAME (Legacy)
####################################
CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")
@ -650,6 +663,12 @@ if CUSTOM_NAME:
pass
####################################
# LICENSE_KEY
####################################
LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
####################################
# STORAGE PROVIDER
####################################
@ -660,27 +679,47 @@ S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
S3_USE_ACCELERATE_ENDPOINT = (
os.environ.get("S3_USE_ACCELERATE_ENDPOINT", "False").lower() == "true"
)
S3_ADDRESSING_STYLE = os.environ.get("S3_ADDRESSING_STYLE", None)
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.environ.get(
"GOOGLE_APPLICATION_CREDENTIALS_JSON", None
)
AZURE_STORAGE_ENDPOINT = os.environ.get("AZURE_STORAGE_ENDPOINT", None)
AZURE_STORAGE_CONTAINER_NAME = os.environ.get("AZURE_STORAGE_CONTAINER_NAME", None)
AZURE_STORAGE_KEY = os.environ.get("AZURE_STORAGE_KEY", None)
####################################
# File Upload DIR
####################################
UPLOAD_DIR = f"{DATA_DIR}/uploads"
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
UPLOAD_DIR = DATA_DIR / "uploads"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
####################################
# Cache DIR
####################################
CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
CACHE_DIR = DATA_DIR / "cache"
CACHE_DIR.mkdir(parents=True, exist_ok=True)
####################################
# DIRECT CONNECTIONS
####################################
ENABLE_DIRECT_CONNECTIONS = PersistentConfig(
"ENABLE_DIRECT_CONNECTIONS",
"direct.enable",
os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true",
)
####################################
# OLLAMA_BASE_URL
@ -755,6 +794,9 @@ ENABLE_OPENAI_API = PersistentConfig(
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_API_BASE_URL = os.environ.get("GEMINI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
@ -927,6 +969,12 @@ USER_PERMISSIONS_FEATURES_IMAGE_GENERATION = (
== "true"
)
USER_PERMISSIONS_FEATURES_CODE_INTERPRETER = (
os.environ.get("USER_PERMISSIONS_FEATURES_CODE_INTERPRETER", "True").lower()
== "true"
)
DEFAULT_USER_PERMISSIONS = {
"workspace": {
"models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
@ -944,6 +992,7 @@ DEFAULT_USER_PERMISSIONS = {
"features": {
"web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH,
"image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION,
"code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER,
},
}
@ -1052,7 +1101,7 @@ try:
banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
print(f"Error loading WEBUI_BANNERS: {e}")
log.exception(f"Error loading WEBUI_BANNERS: {e}")
banners = []
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
@ -1094,21 +1143,27 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """### Task:
Generate a concise, 3-5 word title with an emoji summarizing the chat history.
### Guidelines:
- The title should clearly represent the main theme or subject of the conversation.
- Use emojis that enhance understanding of the topic, but avoid quotation marks or special formatting.
- Write the title in the chat's primary language; default to English if multilingual.
- Prioritize accuracy over excessive creativity; keep it clear and simple.
### Output:
JSON format: { "title": "your concise title here" }
### Examples:
- { "title": "📉 Stock Market Trends" },
- { "title": "🍪 Perfect Chocolate Chip Recipe" },
- { "title": "Evolution of Music Streaming" },
- { "title": "Remote Work Productivity Tips" },
- { "title": "Artificial Intelligence in Healthcare" },
- { "title": "🎮 Video Game Development Insights" }
### Chat History:
<chat_history>
{{MESSAGES:END:2}}
</chat_history>"""
TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TAGS_GENERATION_PROMPT_TEMPLATE",
"task.tags.prompt_template",
@ -1165,6 +1220,12 @@ ENABLE_TAGS_GENERATION = PersistentConfig(
os.environ.get("ENABLE_TAGS_GENERATION", "True").lower() == "true",
)
ENABLE_TITLE_GENERATION = PersistentConfig(
"ENABLE_TITLE_GENERATION",
"task.title.enable",
os.environ.get("ENABLE_TITLE_GENERATION", "True").lower() == "true",
)
ENABLE_SEARCH_QUERY_GENERATION = PersistentConfig(
"ENABLE_SEARCH_QUERY_GENERATION",
@ -1277,7 +1338,28 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
)
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}
Your task is to choose and return the correct tool(s) from the list of available tools based on the query. Follow these guidelines:
- Return only the JSON object, without any additional text or explanation.
- If no tools match the query, return an empty array:
{
"tool_calls": []
}
- If one or more tools match the query, construct a JSON response containing a "tool_calls" array with objects that include:
- "name": The tool's name.
- "parameters": A dictionary of required parameters and their corresponding values.
The format for the JSON response is strictly:
{
"tool_calls": [
{"name": "toolName1", "parameters": {"key1": "value1"}},
{"name": "toolName2", "parameters": {"key2": "value2"}}
]
}"""
DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@ -1290,6 +1372,131 @@ Your task is to synthesize these responses into a single, high-quality response.
Responses from models: {{responses}}"""
####################################
# Code Interpreter
####################################
CODE_EXECUTION_ENGINE = PersistentConfig(
"CODE_EXECUTION_ENGINE",
"code_execution.engine",
os.environ.get("CODE_EXECUTION_ENGINE", "pyodide"),
)
CODE_EXECUTION_JUPYTER_URL = PersistentConfig(
"CODE_EXECUTION_JUPYTER_URL",
"code_execution.jupyter.url",
os.environ.get("CODE_EXECUTION_JUPYTER_URL", ""),
)
CODE_EXECUTION_JUPYTER_AUTH = PersistentConfig(
"CODE_EXECUTION_JUPYTER_AUTH",
"code_execution.jupyter.auth",
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
)
CODE_EXECUTION_JUPYTER_AUTH_TOKEN = PersistentConfig(
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN",
"code_execution.jupyter.auth_token",
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
)
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = PersistentConfig(
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD",
"code_execution.jupyter.auth_password",
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
)
CODE_EXECUTION_JUPYTER_TIMEOUT = PersistentConfig(
"CODE_EXECUTION_JUPYTER_TIMEOUT",
"code_execution.jupyter.timeout",
int(os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60")),
)
ENABLE_CODE_INTERPRETER = PersistentConfig(
"ENABLE_CODE_INTERPRETER",
"code_interpreter.enable",
os.environ.get("ENABLE_CODE_INTERPRETER", "True").lower() == "true",
)
CODE_INTERPRETER_ENGINE = PersistentConfig(
"CODE_INTERPRETER_ENGINE",
"code_interpreter.engine",
os.environ.get("CODE_INTERPRETER_ENGINE", "pyodide"),
)
CODE_INTERPRETER_PROMPT_TEMPLATE = PersistentConfig(
"CODE_INTERPRETER_PROMPT_TEMPLATE",
"code_interpreter.prompt_template",
os.environ.get("CODE_INTERPRETER_PROMPT_TEMPLATE", ""),
)
CODE_INTERPRETER_JUPYTER_URL = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_URL",
"code_interpreter.jupyter.url",
os.environ.get(
"CODE_INTERPRETER_JUPYTER_URL", os.environ.get("CODE_EXECUTION_JUPYTER_URL", "")
),
)
CODE_INTERPRETER_JUPYTER_AUTH = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_AUTH",
"code_interpreter.jupyter.auth",
os.environ.get(
"CODE_INTERPRETER_JUPYTER_AUTH",
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH", ""),
),
)
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
"code_interpreter.jupyter.auth_token",
os.environ.get(
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN",
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_TOKEN", ""),
),
)
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
"code_interpreter.jupyter.auth_password",
os.environ.get(
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD",
os.environ.get("CODE_EXECUTION_JUPYTER_AUTH_PASSWORD", ""),
),
)
CODE_INTERPRETER_JUPYTER_TIMEOUT = PersistentConfig(
"CODE_INTERPRETER_JUPYTER_TIMEOUT",
"code_interpreter.jupyter.timeout",
int(
os.environ.get(
"CODE_INTERPRETER_JUPYTER_TIMEOUT",
os.environ.get("CODE_EXECUTION_JUPYTER_TIMEOUT", "60"),
)
),
)
DEFAULT_CODE_INTERPRETER_PROMPT = """
#### Tools Available
1. **Code Interpreter**: `<code_interpreter type="code" lang="python"></code_interpreter>`
- You have access to a Python shell that runs directly in the user's browser, enabling fast execution of code for analysis, calculations, or problem-solving. Use it in this response.
- The Python code you write can incorporate a wide array of libraries, handle data manipulation or visualization, perform API calls for web-related tasks, or tackle virtually any computational challenge. Use this flexibility to **think outside the box, craft elegant solutions, and harness Python's full potential**.
- To use it, **you must enclose your code within `<code_interpreter type="code" lang="python">` XML tags** and stop right away. If you don't, the code won't execute. Do NOT use triple backticks.
- When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.
- After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**
- If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.
- **If a link to an image, audio, or any file is provided in markdown format in the output, ALWAYS regurgitate word for word, explicitly display it as part of the response to ensure the user can access it easily, do NOT change the link.**
- All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
####################################
# Vector Database
####################################
@ -1298,27 +1505,34 @@ VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get("CHROMA_CLIENT_AUTH_CREDENTIALS", "")
# Comma-separated list of header=value pairs
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
CHROMA_HTTP_HEADERS = dict(
[pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
if VECTOR_DB == "chroma":
import chromadb
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
CHROMA_CLIENT_AUTH_PROVIDER = os.environ.get("CHROMA_CLIENT_AUTH_PROVIDER", "")
CHROMA_CLIENT_AUTH_CREDENTIALS = os.environ.get(
"CHROMA_CLIENT_AUTH_CREDENTIALS", ""
)
else:
CHROMA_HTTP_HEADERS = None
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# Comma-separated list of header=value pairs
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
CHROMA_HTTP_HEADERS = dict(
[pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
)
else:
CHROMA_HTTP_HEADERS = None
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
# Milvus
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
MILVUS_DB = os.environ.get("MILVUS_DB", "default")
MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None)
# Qdrant
QDRANT_URI = os.environ.get("QDRANT_URI", None)
@ -1331,6 +1545,15 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
# ElasticSearch
ELASTICSEARCH_URL = os.environ.get("ELASTICSEARCH_URL", "https://localhost:9200")
ELASTICSEARCH_CA_CERTS = os.environ.get("ELASTICSEARCH_CA_CERTS", None)
ELASTICSEARCH_API_KEY = os.environ.get("ELASTICSEARCH_API_KEY", None)
ELASTICSEARCH_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME", None)
ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", None)
ELASTICSEARCH_CLOUD_ID = os.environ.get("ELASTICSEARCH_CLOUD_ID", None)
SSL_ASSERT_FINGERPRINT = os.environ.get("SSL_ASSERT_FINGERPRINT", None)
# Pgvector
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
@ -1365,6 +1588,18 @@ GOOGLE_DRIVE_API_KEY = PersistentConfig(
os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
)
ENABLE_ONEDRIVE_INTEGRATION = PersistentConfig(
"ENABLE_ONEDRIVE_INTEGRATION",
"onedrive.enable",
os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true",
)
ONEDRIVE_CLIENT_ID = PersistentConfig(
"ONEDRIVE_CLIENT_ID",
"onedrive.client_id",
os.environ.get("ONEDRIVE_CLIENT_ID", ""),
)
# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
@ -1384,6 +1619,26 @@ DOCLING_SERVER_URL = PersistentConfig(
os.getenv("DOCLING_SERVER_URL", "http://docling:5001"),
)
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
"DOCUMENT_INTELLIGENCE_ENDPOINT",
"rag.document_intelligence_endpoint",
os.getenv("DOCUMENT_INTELLIGENCE_ENDPOINT", ""),
)
DOCUMENT_INTELLIGENCE_KEY = PersistentConfig(
"DOCUMENT_INTELLIGENCE_KEY",
"rag.document_intelligence_key",
os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""),
)
BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
"BYPASS_EMBEDDING_AND_RETRIEVAL",
"rag.bypass_embedding_and_retrieval",
os.environ.get("BYPASS_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
)
RAG_TOP_K = PersistentConfig(
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
)
@ -1399,6 +1654,12 @@ ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)
RAG_FULL_CONTEXT = PersistentConfig(
"RAG_FULL_CONTEXT",
"rag.full_context",
os.getenv("RAG_FULL_CONTEXT", "False").lower() == "true",
)
RAG_FILE_MAX_COUNT = PersistentConfig(
"RAG_FILE_MAX_COUNT",
"rag.file.max_count",
@ -1513,7 +1774,7 @@ Respond to the user query using the provided context, incorporating inline citat
- Respond in the same language as the user's query.
- If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
- If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding.
- **Only include inline citations using [source_id] when a <source_id> tag is explicitly provided in the context.**
- **Only include inline citations using [source_id] (e.g., [1], [2]) when a `<source_id>` tag is explicitly provided in the context.**
- Do not cite if the <source_id> tag is not provided in the context.
- Do not use XML tags in your response.
- Ensure citations are concise and directly related to the information provided.
@ -1594,11 +1855,17 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL",
"rag.web.search.bypass_embedding_and_retrieval",
os.getenv("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", "False").lower() == "true",
)
# You can provide a list of your own websites to filter after performing a web search.
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
"rag.rag.web.search.domain.filter_list",
"rag.web.search.domain.filter_list",
[
# "wikipedia.com",
# "wikimedia.org",
@ -1643,6 +1910,12 @@ MOJEEK_SEARCH_API_KEY = PersistentConfig(
os.getenv("MOJEEK_SEARCH_API_KEY", ""),
)
BOCHA_SEARCH_API_KEY = PersistentConfig(
"BOCHA_SEARCH_API_KEY",
"rag.web.search.bocha_search_api_key",
os.getenv("BOCHA_SEARCH_API_KEY", ""),
)
SERPSTACK_API_KEY = PersistentConfig(
"SERPSTACK_API_KEY",
"rag.web.search.serpstack_api_key",
@ -1691,6 +1964,18 @@ SEARCHAPI_ENGINE = PersistentConfig(
os.getenv("SEARCHAPI_ENGINE", ""),
)
SERPAPI_API_KEY = PersistentConfig(
"SERPAPI_API_KEY",
"rag.web.search.serpapi_api_key",
os.getenv("SERPAPI_API_KEY", ""),
)
SERPAPI_ENGINE = PersistentConfig(
"SERPAPI_ENGINE",
"rag.web.search.serpapi_engine",
os.getenv("SERPAPI_ENGINE", ""),
)
BING_SEARCH_V7_ENDPOINT = PersistentConfig(
"BING_SEARCH_V7_ENDPOINT",
"rag.web.search.bing_search_v7_endpoint",
@ -1705,6 +1990,17 @@ BING_SEARCH_V7_SUBSCRIPTION_KEY = PersistentConfig(
os.environ.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", ""),
)
EXA_API_KEY = PersistentConfig(
"EXA_API_KEY",
"rag.web.search.exa_api_key",
os.getenv("EXA_API_KEY", ""),
)
PERPLEXITY_API_KEY = PersistentConfig(
"PERPLEXITY_API_KEY",
"rag.web.search.perplexity_api_key",
os.getenv("PERPLEXITY_API_KEY", ""),
)
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",
@ -1718,6 +2014,35 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
RAG_WEB_LOADER_ENGINE = PersistentConfig(
"RAG_WEB_LOADER_ENGINE",
"rag.web.loader.engine",
os.environ.get("RAG_WEB_LOADER_ENGINE", "safe_web"),
)
RAG_WEB_SEARCH_TRUST_ENV = PersistentConfig(
"RAG_WEB_SEARCH_TRUST_ENV",
"rag.web.search.trust_env",
os.getenv("RAG_WEB_SEARCH_TRUST_ENV", "False").lower() == "true",
)
PLAYWRIGHT_WS_URI = PersistentConfig(
"PLAYWRIGHT_WS_URI",
"rag.web.loader.engine.playwright.ws.uri",
os.environ.get("PLAYWRIGHT_WS_URI", None),
)
FIRECRAWL_API_KEY = PersistentConfig(
"FIRECRAWL_API_KEY",
"firecrawl.api_key",
os.environ.get("FIRECRAWL_API_KEY", ""),
)
FIRECRAWL_API_BASE_URL = PersistentConfig(
"FIRECRAWL_API_BASE_URL",
"firecrawl.api_url",
os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"),
)
####################################
# Images
@ -1929,6 +2254,17 @@ IMAGES_OPENAI_API_KEY = PersistentConfig(
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
IMAGES_GEMINI_API_BASE_URL = PersistentConfig(
"IMAGES_GEMINI_API_BASE_URL",
"image_generation.gemini.api_base_url",
os.getenv("IMAGES_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
)
IMAGES_GEMINI_API_KEY = PersistentConfig(
"IMAGES_GEMINI_API_KEY",
"image_generation.gemini.api_key",
os.getenv("IMAGES_GEMINI_API_KEY", GEMINI_API_KEY),
)
IMAGE_SIZE = PersistentConfig(
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
)
@ -1960,6 +2296,12 @@ WHISPER_MODEL_AUTO_UPDATE = (
and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)
# Add Deepgram configuration
DEEPGRAM_API_KEY = PersistentConfig(
"DEEPGRAM_API_KEY",
"audio.stt.deepgram.api_key",
os.getenv("DEEPGRAM_API_KEY", ""),
)
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
@ -2099,7 +2441,7 @@ LDAP_SEARCH_BASE = PersistentConfig(
LDAP_SEARCH_FILTERS = PersistentConfig(
"LDAP_SEARCH_FILTER",
"ldap.server.search_filter",
os.environ.get("LDAP_SEARCH_FILTER", ""),
os.environ.get("LDAP_SEARCH_FILTER", os.environ.get("LDAP_SEARCH_FILTERS", "")),
)
LDAP_USE_TLS = PersistentConfig(

View File

@ -57,7 +57,7 @@ class ERROR_MESSAGES(str, Enum):
)
FILE_NOT_SENT = "FILE_NOT_SENT"
FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format (e.g., JPG, PNG, PDF, TXT) and try again."
FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again."
NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/"

View File

@ -65,10 +65,8 @@ except Exception:
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
if GLOBAL_LOG_LEVEL in logging.getLevelNamesMapping():
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
@ -78,6 +76,7 @@ log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
if "cuda_error" in locals():
log.exception(cuda_error)
del cuda_error
log_sources = [
"AUDIO",
@ -92,6 +91,7 @@ log_sources = [
"RAG",
"WEBHOOK",
"SOCKET",
"OAUTH",
]
SRC_LOG_LEVELS = {}
@ -99,7 +99,7 @@ SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
if SRC_LOG_LEVELS[source] not in logging.getLevelNamesMapping():
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
@ -112,6 +112,7 @@ if WEBUI_NAME != "Open WebUI":
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")
####################################
# ENV (dev,test,prod)
@ -356,14 +357,22 @@ WEBUI_SECRET_KEY = os.environ.get(
), # DEPRECATED: remove at next major version
)
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
"WEBUI_SESSION_COOKIE_SAME_SITE",
os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax")
WEBUI_SESSION_COOKIE_SECURE = (
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true"
)
WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
"WEBUI_SESSION_COOKIE_SECURE",
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
WEBUI_AUTH_COOKIE_SAME_SITE = os.environ.get(
"WEBUI_AUTH_COOKIE_SAME_SITE", WEBUI_SESSION_COOKIE_SAME_SITE
)
WEBUI_AUTH_COOKIE_SECURE = (
os.environ.get(
"WEBUI_AUTH_COOKIE_SECURE",
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false"),
).lower()
== "true"
)
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
@ -376,6 +385,7 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@ -387,19 +397,20 @@ else:
except Exception:
AIOHTTP_CLIENT_TIMEOUT = 300
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""),
)
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "":
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None
else:
try:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = int(AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
except Exception:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 5
####################################
# OFFLINE_MODE
@ -409,3 +420,25 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
####################################
# AUDIT LOGGING
####################################
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
# Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
try:
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
MAX_BODY_LOG_SIZE = 2048
# Comma separated list for urls to exclude from audit
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
","
)
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]

View File

@ -2,6 +2,7 @@ import logging
import sys
import inspect
import json
import asyncio
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
@ -76,11 +77,13 @@ async def get_function_models(request):
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if pipes is a function or a list
# Handle pipes being a list, sync function, or async function
try:
if callable(function_module.pipes):
sub_pipes = function_module.pipes()
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
@ -250,7 +253,7 @@ async def generate_function_chat_completion(
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user)
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)

View File

@ -45,6 +45,9 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
from open_webui.utils import logger
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
from open_webui.utils.logger import start_logger
from open_webui.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
@ -88,15 +91,34 @@ from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users
from open_webui.config import (
LICENSE_KEY,
# Ollama
ENABLE_OLLAMA_API,
OLLAMA_BASE_URLS,
OLLAMA_API_CONFIGS,
# OpenAI
ENABLE_OPENAI_API,
ONEDRIVE_CLIENT_ID,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
# Direct Connections
ENABLE_DIRECT_CONNECTIONS,
# Code Execution
CODE_EXECUTION_ENGINE,
CODE_EXECUTION_JUPYTER_URL,
CODE_EXECUTION_JUPYTER_AUTH,
CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
CODE_EXECUTION_JUPYTER_TIMEOUT,
ENABLE_CODE_INTERPRETER,
CODE_INTERPRETER_ENGINE,
CODE_INTERPRETER_PROMPT_TEMPLATE,
CODE_INTERPRETER_JUPYTER_URL,
CODE_INTERPRETER_JUPYTER_AUTH,
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
CODE_INTERPRETER_JUPYTER_TIMEOUT,
# Image
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
@ -115,6 +137,8 @@ from open_webui.config import (
IMAGE_STEPS,
IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
IMAGES_GEMINI_API_BASE_URL,
IMAGES_GEMINI_API_KEY,
# Audio
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
@ -129,12 +153,19 @@ from open_webui.config import (
AUDIO_TTS_VOICE,
AUDIO_TTS_AZURE_SPEECH_REGION,
AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
PLAYWRIGHT_WS_URI,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
RAG_WEB_LOADER_ENGINE,
WHISPER_MODEL,
DEEPGRAM_API_KEY,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
# Retrieval
RAG_TEMPLATE,
DEFAULT_RAG_TEMPLATE,
RAG_FULL_CONTEXT,
BYPASS_EMBEDDING_AND_RETRIEVAL,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
@ -155,6 +186,8 @@ from open_webui.config import (
CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL,
DOCLING_SERVER_URL,
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
RAG_TOP_K,
RAG_TEXT_SPLITTER,
TIKTOKEN_ENCODING_NAME,
@ -163,12 +196,16 @@ from open_webui.config import (
YOUTUBE_LOADER_PROXY_URL,
# Retrieval (Web Search)
RAG_WEB_SEARCH_ENGINE,
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
RAG_WEB_SEARCH_TRUST_ENV,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
JINA_API_KEY,
SEARCHAPI_API_KEY,
SEARCHAPI_ENGINE,
SERPAPI_API_KEY,
SERPAPI_ENGINE,
SEARXNG_QUERY_URL,
SERPER_API_KEY,
SERPLY_API_KEY,
@ -178,17 +215,22 @@ from open_webui.config import (
BING_SEARCH_V7_ENDPOINT,
BING_SEARCH_V7_SUBSCRIPTION_KEY,
BRAVE_SEARCH_API_KEY,
EXA_API_KEY,
PERPLEXITY_API_KEY,
KAGI_SEARCH_API_KEY,
MOJEEK_SEARCH_API_KEY,
BOCHA_SEARCH_API_KEY,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
GOOGLE_DRIVE_CLIENT_ID,
GOOGLE_DRIVE_API_KEY,
ONEDRIVE_CLIENT_ID,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
ENABLE_GOOGLE_DRIVE_INTEGRATION,
ENABLE_ONEDRIVE_INTEGRATION,
UPLOAD_DIR,
# WebUI
WEBUI_AUTH,
@ -252,6 +294,7 @@ from open_webui.config import (
TASK_MODEL,
TASK_MODEL_EXTERNAL,
ENABLE_TAGS_GENERATION,
ENABLE_TITLE_GENERATION,
ENABLE_SEARCH_QUERY_GENERATION,
ENABLE_RETRIEVAL_QUERY_GENERATION,
ENABLE_AUTOCOMPLETE_GENERATION,
@ -266,8 +309,11 @@ from open_webui.config import (
reset_config,
)
from open_webui.env import (
AUDIT_EXCLUDED_PATHS,
AUDIT_LOG_LEVEL,
CHANGELOG,
GLOBAL_LOG_LEVEL,
MAX_BODY_LOG_SIZE,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
@ -298,15 +344,17 @@ from open_webui.utils.middleware import process_chat_payload, process_chat_respo
from open_webui.utils.access_control import has_access
from open_webui.utils.auth import (
get_license_data,
decode_token,
get_admin_user,
get_verified_user,
)
from open_webui.utils.oauth import oauth_manager
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.tasks import stop_task, list_tasks # Import from tasks.py
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
@ -322,19 +370,23 @@ class SPAStaticFiles(StaticFiles):
return await super().get_response(path, scope)
except (HTTPException, StarletteHTTPException) as ex:
if ex.status_code == 404:
return await super().get_response("index.html", scope)
if path.endswith(".js"):
# Return 404 for javascript files
raise ex
else:
return await super().get_response("index.html", scope)
else:
raise ex
print(
rf"""
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|_|
v{VERSION} - building the best open-source AI user interface.
@ -346,9 +398,13 @@ https://github.com/open-webui/open-webui
@asynccontextmanager
async def lifespan(app: FastAPI):
start_logger()
if RESET_CONFIG_ON_START:
reset_config()
if LICENSE_KEY:
get_license_data(app, LICENSE_KEY)
asyncio.create_task(periodic_usage_pool_cleanup())
yield
@ -360,8 +416,12 @@ app = FastAPI(
lifespan=lifespan,
)
oauth_manager = OAuthManager(app)
app.state.config = AppConfig()
app.state.WEBUI_NAME = WEBUI_NAME
app.state.LICENSE_METADATA = None
########################################
#
@ -389,6 +449,14 @@ app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
app.state.OPENAI_MODELS = {}
########################################
#
# DIRECT CONNECTIONS
#
########################################
app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
########################################
#
# WEBUI
@ -455,10 +523,10 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.USER_COUNT = None
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
########################################
#
# RETRIEVAL
@ -471,6 +539,9 @@ app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT
app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = BYPASS_EMBEDDING_AND_RETRIEVAL
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
@ -479,6 +550,8 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
@ -506,15 +579,20 @@ app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
app.state.config.BOCHA_SEARCH_API_KEY = BOCHA_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
@ -522,12 +600,21 @@ app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.SERPAPI_API_KEY = SERPAPI_API_KEY
app.state.config.SERPAPI_ENGINE = SERPAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.config.RAG_WEB_LOADER_ENGINE = RAG_WEB_LOADER_ENGINE
app.state.config.RAG_WEB_SEARCH_TRUST_ENV = RAG_WEB_SEARCH_TRUST_ENV
app.state.config.PLAYWRIGHT_WS_URI = PLAYWRIGHT_WS_URI
app.state.config.FIRECRAWL_API_BASE_URL = FIRECRAWL_API_BASE_URL
app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
app.state.EMBEDDING_FUNCTION = None
app.state.ef = None
@ -569,6 +656,34 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
########################################
#
# CODE EXECUTION
#
########################################
app.state.config.CODE_EXECUTION_ENGINE = CODE_EXECUTION_ENGINE
app.state.config.CODE_EXECUTION_JUPYTER_URL = CODE_EXECUTION_JUPYTER_URL
app.state.config.CODE_EXECUTION_JUPYTER_AUTH = CODE_EXECUTION_JUPYTER_AUTH
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = CODE_EXECUTION_JUPYTER_AUTH_TOKEN
app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
)
app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = CODE_EXECUTION_JUPYTER_TIMEOUT
app.state.config.ENABLE_CODE_INTERPRETER = ENABLE_CODE_INTERPRETER
app.state.config.CODE_INTERPRETER_ENGINE = CODE_INTERPRETER_ENGINE
app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = CODE_INTERPRETER_PROMPT_TEMPLATE
app.state.config.CODE_INTERPRETER_JUPYTER_URL = CODE_INTERPRETER_JUPYTER_URL
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = CODE_INTERPRETER_JUPYTER_AUTH
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
)
app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
)
app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = CODE_INTERPRETER_JUPYTER_TIMEOUT
########################################
#
@ -583,6 +698,9 @@ app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
@ -611,6 +729,7 @@ app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
@ -645,6 +764,7 @@ app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@ -753,6 +873,7 @@ app.include_router(openai.router, prefix="/openai", tags=["openai"])
app.include_router(pipelines.router, prefix="/api/v1/pipelines", tags=["pipelines"])
app.include_router(tasks.router, prefix="/api/v1/tasks", tags=["tasks"])
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
@ -781,6 +902,19 @@ app.include_router(
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
try:
audit_level = AuditLevel(AUDIT_LOG_LEVEL)
except ValueError as e:
logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}")
audit_level = AuditLevel.NONE
if audit_level != AuditLevel.NONE:
app.add_middleware(
AuditLoggingMiddleware,
audit_level=audit_level,
excluded_paths=AUDIT_EXCLUDED_PATHS,
max_body_size=MAX_BODY_LOG_SIZE,
)
##################################
#
# Chat Endpoints
@ -813,7 +947,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
return filtered_models
models = await get_all_models(request)
models = await get_all_models(request, user=user)
# Filter out filter pipelines
models = [
@ -842,7 +976,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
@app.get("/api/models/base")
async def get_base_models(request: Request, user=Depends(get_admin_user)):
models = await get_all_base_models(request)
models = await get_all_base_models(request, user=user)
return {"data": models}
@ -853,21 +987,32 @@ async def chat_completion(
user=Depends(get_verified_user),
):
if not request.app.state.MODELS:
await get_all_models(request)
await get_all_models(request, user=user)
model_item = form_data.pop("model_item", {})
tasks = form_data.pop("background_tasks", None)
try:
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
raise Exception("Model not found")
model = request.app.state.MODELS[model_id]
# Check if user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
try:
if not model_item.get("direct", False):
model_id = form_data.get("model", None)
if model_id not in request.app.state.MODELS:
raise Exception("Model not found")
model = request.app.state.MODELS[model_id]
model_info = Models.get_model_by_id(model_id)
# Check if user has access to the model
if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
else:
model = model_item
model_info = None
request.state.direct = True
request.state.model = model
metadata = {
"user_id": user.id,
@ -877,13 +1022,30 @@ async def chat_completion(
"tool_ids": form_data.get("tool_ids", None),
"files": form_data.get("files", None),
"features": form_data.get("features", None),
"variables": form_data.get("variables", None),
"model": model,
"direct": model_item.get("direct", False),
**(
{"function_calling": "native"}
if form_data.get("params", {}).get("function_calling") == "native"
or (
model_info
and model_info.params.model_dump().get("function_calling")
== "native"
)
else {}
),
}
request.state.metadata = metadata
form_data["metadata"] = metadata
form_data, events = await process_chat_payload(
request, form_data, metadata, user, model
form_data, metadata, events = await process_chat_payload(
request, form_data, user, metadata, model
)
except Exception as e:
log.debug(f"Error processing chat payload: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
@ -891,8 +1053,9 @@ async def chat_completion(
try:
response = await chat_completion_handler(request, form_data, user)
return await process_chat_response(
request, response, form_data, user, events, metadata, tasks
request, response, form_data, user, metadata, model, events, tasks
)
except Exception as e:
raise HTTPException(
@ -911,6 +1074,12 @@ async def chat_completed(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
try:
model_item = form_data.pop("model_item", {})
if model_item.get("direct", False):
request.state.direct = True
request.state.model = model_item
return await chat_completed_handler(request, form_data, user)
except Exception as e:
raise HTTPException(
@ -924,6 +1093,12 @@ async def chat_action(
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
):
try:
model_item = form_data.pop("model_item", {})
if model_item.get("direct", False):
request.state.direct = True
request.state.model = model_item
return await chat_action_handler(request, action_id, form_data, user)
except Exception as e:
raise HTTPException(
@ -969,15 +1144,16 @@ async def get_app_config(request: Request):
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
user_count = Users.get_num_users()
onboarding = False
if user is None:
user_count = Users.get_num_users()
onboarding = user_count == 0
return {
**({"onboarding": True} if onboarding else {}),
"status": True,
"name": WEBUI_NAME,
"name": app.state.WEBUI_NAME,
"version": VERSION,
"default_locale": str(DEFAULT_LOCALE),
"oauth": {
@ -996,27 +1172,31 @@ async def get_app_config(request: Request):
"enable_websocket": ENABLE_WEBSOCKET_SUPPORT,
**(
{
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
"enable_channels": app.state.config.ENABLE_CHANNELS,
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
"enable_autocomplete_generation": app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS,
"enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
}
if user is not None
else {}
),
},
"google_drive": {
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
"api_key": GOOGLE_DRIVE_API_KEY.value,
},
**(
{
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"user_count": user_count,
"code": {
"engine": app.state.config.CODE_EXECUTION_ENGINE,
},
"audio": {
"tts": {
"engine": app.state.config.TTS_ENGINE,
@ -1032,6 +1212,19 @@ async def get_app_config(request: Request):
"max_count": app.state.config.FILE_MAX_COUNT,
},
"permissions": {**app.state.config.USER_PERMISSIONS},
"google_drive": {
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
"api_key": GOOGLE_DRIVE_API_KEY.value,
},
"onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value},
"license_metadata": app.state.LICENSE_METADATA,
**(
{
"active_entries": app.state.USER_COUNT,
}
if user.role == "admin"
else {}
),
}
if user is not None
else {}
@ -1065,7 +1258,7 @@ async def get_app_version():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
async def get_app_latest_release_version(user=Depends(get_verified_user)):
if OFFLINE_MODE:
log.debug(
f"Offline mode is enabled, returning current version as latest version"
@ -1109,7 +1302,7 @@ if len(OAUTH_PROVIDERS) > 0:
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
return await oauth_manager.handle_login(provider, request)
return await oauth_manager.handle_login(request, provider)
# OAuth login logic is as follows:
@ -1120,14 +1313,14 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
return await oauth_manager.handle_callback(provider, request, response)
return await oauth_manager.handle_callback(request, provider, response)
@app.get("/manifest.json")
async def get_manifest_json():
return {
"name": WEBUI_NAME,
"short_name": WEBUI_NAME,
"name": app.state.WEBUI_NAME,
"short_name": app.state.WEBUI_NAME,
"description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
"start_url": "/",
"display": "standalone",
@ -1154,8 +1347,8 @@ async def get_manifest_json():
async def get_opensearch_xml():
xml_content = rf"""
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
<ShortName>{WEBUI_NAME}</ShortName>
<Description>Search {WEBUI_NAME}</Description>
<ShortName>{app.state.WEBUI_NAME}</ShortName>
<Description>Search {app.state.WEBUI_NAME}</Description>
<InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>

View File

@ -1,3 +1,4 @@
import logging
import json
import time
import uuid
@ -5,7 +6,7 @@ from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
@ -16,6 +17,9 @@ from sqlalchemy.sql import exists
# Chat DB Schema
####################
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base):
__tablename__ = "chat"
@ -470,7 +474,7 @@ class ChatTable:
try:
with get_db() as db:
# it is possible that the shared link was deleted. hence,
# we check if the chat is still shared by checkng if a chat with the share_id exists
# we check if the chat is still shared by checking if a chat with the share_id exists
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:
@ -670,7 +674,7 @@ class ChatTable:
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
print(len(all_chats))
log.info(f"The number of chats: {len(all_chats)}")
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
@ -731,7 +735,7 @@ class ChatTable:
query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower()
print(db.bind.dialect.name)
log.info(f"DB dialect name: {db.bind.dialect.name}")
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 querying for tags within the meta JSON field
query = query.filter(
@ -752,7 +756,7 @@ class ChatTable:
)
all_chats = query.all()
print("all_chats", all_chats)
log.debug(f"all_chats: {all_chats}")
return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name(
@ -810,7 +814,7 @@ class ChatTable:
count = query.count()
# Debugging output for inspection
print(f"Count of chats for tag '{tag_name}':", count)
log.info(f"Count of chats for tag '{tag_name}': {count}")
return count

View File

@ -118,7 +118,7 @@ class FeedbackTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Error creating a new feedback: {e}")
return None
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:

View File

@ -119,7 +119,7 @@ class FilesTable:
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
log.exception(f"Error inserting a new file: {e}")
return None
def get_file_by_id(self, id: str) -> Optional[FileModel]:

View File

@ -82,7 +82,7 @@ class FolderTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Error inserting a new folder: {e}")
return None
def get_folder_by_id_and_user_id(

View File

@ -105,7 +105,7 @@ class FunctionsTable:
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
log.exception(f"Error creating a new function: {e}")
return None
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
@ -170,7 +170,7 @@ class FunctionsTable:
function = db.get(Function, id)
return function.valves if function.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
log.exception(f"Error getting function valves by id {id}: {e}")
return None
def update_function_valves_by_id(
@ -202,7 +202,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error getting user values by id {id} and user id {user_id}: {e}"
)
return None
def update_user_valves_by_id_and_user_id(
@ -225,7 +227,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:

5
backend/open_webui/models/models.py Normal file → Executable file
View File

@ -166,7 +166,7 @@ class ModelsTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Failed to insert a new model: {e}")
return None
def get_all_models(self) -> list[ModelModel]:
@ -246,8 +246,7 @@ class ModelsTable:
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e:
print(e)
log.exception(f"Failed to update the model by id {id}: {e}")
return None
def delete_model_by_id(self, id: str) -> bool:

View File

@ -61,7 +61,7 @@ class TagTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Error inserting a new tag: {e}")
return None
def get_tag_by_name_and_user_id(

View File

@ -131,7 +131,7 @@ class ToolsTable:
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
log.exception(f"Error creating a new tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
@ -175,7 +175,7 @@ class ToolsTable:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
log.exception(f"Error getting tool valves by id {id}: {e}")
return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
@ -204,7 +204,9 @@ class ToolsTable:
return user_settings["tools"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error getting user values by id {id} and user_id {user_id}: {e}"
)
return None
def update_user_valves_by_id_and_user_id(
@ -227,7 +229,9 @@ class ToolsTable:
return user_settings["tools"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:

View File

@ -271,6 +271,24 @@ class UsersTable:
except Exception:
return None
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
with get_db() as db:
user_settings = db.query(User).filter_by(id=id).first().settings
if user_settings is None:
user_settings = {}
user_settings.update(updated)
db.query(User).filter_by(id=id).update({"settings": user_settings})
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def delete_user_by_id(self, id: str) -> bool:
try:
# Remove User from Groups

View File

@ -4,6 +4,7 @@ import ftfy
import sys
from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
BSHTMLLoader,
CSVLoader,
Docx2txtLoader,
@ -76,6 +77,7 @@ known_source_ext = [
"jsx",
"hs",
"lhs",
"json",
]
@ -221,12 +223,33 @@ class Loader:
file_path=file_path,
mime_type=file_content_type,
)
elif self.engine == "docling":
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
)
elif (
self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
and (
file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
or file_content_type
in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
]
)
):
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
)
else:
if file_ext == "pdf":
loader = PyPDFLoader(

View File

@ -1,13 +1,19 @@
import os
import logging
import torch
import numpy as np
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ColBERT:
def __init__(self, name, **kwargs) -> None:
print("ColBERT: Loading model", name)
log.info("ColBERT: Loading model", name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
DOCKER = kwargs.get("env") == "docker"

View File

@ -5,6 +5,7 @@ from typing import Optional, Union
import asyncio
import requests
import hashlib
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
@ -14,9 +15,16 @@ from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.env import (
SRC_LOG_LEVELS,
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -61,9 +69,7 @@ class VectorSearchRetriever(BaseRetriever):
def query_doc(
collection_name: str,
query_embedding: list[float],
k: int,
collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
try:
result = VECTOR_DB_CLIENT.search(
@ -77,7 +83,20 @@ def query_doc(
return result
except Exception as e:
print(e)
log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
raise e
def get_doc(collection_name: str, user: UserModel = None):
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
log.exception(f"Error getting doc {collection_name}: {e}")
raise e
@ -134,47 +153,80 @@ def query_doc_with_hybrid_search(
raise e
def merge_and_sort_query_results(
query_results: list[dict], k: int, reverse: bool = False
) -> list[dict]:
def merge_get_results(get_results: list[dict]) -> dict:
# Initialize lists to store combined data
combined_distances = []
combined_documents = []
combined_metadatas = []
combined_ids = []
for data in query_results:
combined_distances.extend(data["distances"][0])
for data in get_results:
combined_documents.extend(data["documents"][0])
combined_metadatas.extend(data["metadatas"][0])
combined_ids.extend(data["ids"][0])
# Create a list of tuples (distance, document, metadata)
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
# Create the output dictionary
result = {
"documents": [combined_documents],
"metadatas": [combined_metadatas],
"ids": [combined_ids],
}
return result
def merge_and_sort_query_results(
query_results: list[dict], k: int, reverse: bool = False
) -> dict:
# Initialize lists to store combined data
combined = []
seen_hashes = set() # To store unique document hashes
for data in query_results:
distances = data["distances"][0]
documents = data["documents"][0]
metadatas = data["metadatas"][0]
for distance, document, metadata in zip(distances, documents, metadatas):
if isinstance(document, str):
doc_hash = hashlib.md5(
document.encode()
).hexdigest() # Compute a hash for uniqueness
if doc_hash not in seen_hashes:
seen_hashes.add(doc_hash)
combined.append((distance, document, metadata))
# Sort the list based on distances
combined.sort(key=lambda x: x[0], reverse=reverse)
# We don't have anything :-(
if not combined:
sorted_distances = []
sorted_documents = []
sorted_metadatas = []
else:
# Unzip the sorted list
sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
# Slice to keep only the top k elements
sorted_distances, sorted_documents, sorted_metadatas = (
zip(*combined[:k]) if combined else ([], [], [])
)
# Slicing the lists to include only k elements
sorted_distances = list(sorted_distances)[:k]
sorted_documents = list(sorted_documents)[:k]
sorted_metadatas = list(sorted_metadatas)[:k]
# Create the output dictionary
result = {
"distances": [sorted_distances],
"documents": [sorted_documents],
"metadatas": [sorted_metadatas],
# Create and return the output dictionary
return {
"distances": [list(sorted_distances)],
"documents": [list(sorted_documents)],
"metadatas": [list(sorted_metadatas)],
}
return result
def get_all_items_from_collections(collection_names: list[str]) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
try:
result = get_doc(collection_name=collection_name)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
return merge_get_results(results)
def query_collection(
@ -259,29 +311,35 @@ def get_embedding_function(
embedding_batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
return lambda query, user=None: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
func = lambda query: generate_embeddings(
func = lambda query, user=None: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
url=url,
key=key,
user=user,
)
def generate_multiple(query, func):
def generate_multiple(query, user, func):
if isinstance(query, list):
embeddings = []
for i in range(0, len(query), embedding_batch_size):
embeddings.extend(func(query[i : i + embedding_batch_size]))
embeddings.extend(
func(query[i : i + embedding_batch_size], user=user)
)
return embeddings
else:
return func(query)
return func(query, user)
return lambda query: generate_multiple(query, func)
return lambda query, user=None: generate_multiple(query, user, func)
else:
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
def get_sources_from_files(
request,
files,
queries,
embedding_function,
@ -289,21 +347,81 @@ def get_sources_from_files(
reranking_function,
r,
hybrid_search,
full_context=False,
):
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = []
relevant_contexts = []
for file in files:
if file.get("context") == "full":
context = None
if file.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
context = {
"documents": [[doc.get("content") for doc in file.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
}
elif file.get("context") == "full":
# Manual Full Mode Toggle
context = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
}
else:
context = None
elif (
file.get("type") != "web_search"
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
):
# BYPASS_EMBEDDING_AND_RETRIEVAL
if file.get("type") == "collection":
file_ids = file.get("data", {}).get("file_ids", [])
documents = []
metadatas = []
for file_id in file_ids:
file_object = Files.get_file_by_id(file_id)
if file_object:
documents.append(file_object.data.get("content", ""))
metadatas.append(
{
"file_id": file_id,
"name": file_object.filename,
"source": file_object.filename,
}
)
context = {
"documents": [documents],
"metadatas": [metadatas],
}
elif file.get("id"):
file_object = Files.get_file_by_id(file.get("id"))
if file_object:
context = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
{
"file_id": file.get("id"),
"name": file_object.filename,
"source": file_object.filename,
}
]
],
}
elif file.get("file").get("data"):
context = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [
[file.get("file").get("data", {}).get("metadata", {})]
],
}
else:
collection_names = []
if file.get("type") == "collection":
if file.get("legacy"):
@ -323,42 +441,50 @@ def get_sources_from_files(
log.debug(f"skipping {file} as it has already been extracted")
continue
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
if full_context:
try:
context = get_all_items_from_collections(collection_names)
except Exception as e:
log.exception(e)
else:
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
)
except Exception as e:
log.exception(e)
except Exception as e:
log.exception(e)
extracted_collections.extend(collection_names)
if context:
if "data" in file:
del file["data"]
relevant_contexts.append({**context, "file": file})
sources = []
@ -423,7 +549,11 @@ def get_model_path(model: str, update_model: bool = False):
def generate_openai_batch_embeddings(
model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
model: str,
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
r = requests.post(
@ -431,6 +561,16 @@ def generate_openai_batch_embeddings(
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json={"input": texts, "model": model},
)
@ -441,12 +581,12 @@ def generate_openai_batch_embeddings(
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
log.exception(f"Error generating openai batch embeddings: {e}")
return None
def generate_ollama_batch_embeddings(
model: str, texts: list[str], url: str, key: str = ""
model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
) -> Optional[list[list[float]]]:
try:
r = requests.post(
@ -454,6 +594,16 @@ def generate_ollama_batch_embeddings(
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
json={"input": texts, "model": model},
)
@ -465,29 +615,36 @@ def generate_ollama_batch_embeddings(
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
log.exception(f"Error generating ollama batch embeddings: {e}")
return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
user = kwargs.get("user")
if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
**{"model": model, "texts": text, "url": url, "key": key}
**{"model": model, "texts": text, "url": url, "key": key, "user": user}
)
else:
embeddings = generate_ollama_batch_embeddings(
**{"model": model, "texts": [text], "url": url, "key": key}
**{
"model": model,
"texts": [text],
"url": url,
"key": key,
"user": user,
}
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, url, key)
embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
else:
embeddings = generate_openai_batch_embeddings(model, [text], url, key)
embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
return embeddings[0] if isinstance(text, str) else embeddings

View File

@ -16,6 +16,10 @@ elif VECTOR_DB == "pgvector":
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
VECTOR_DB_CLIENT = PgvectorClient()
elif VECTOR_DB == "elasticsearch":
from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
VECTOR_DB_CLIENT = ElasticsearchClient()
else:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient

8
backend/open_webui/retrieval/vector/dbs/chroma.py Normal file → Executable file
View File

@ -1,4 +1,5 @@
import chromadb
import logging
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches
@ -16,6 +17,10 @@ from open_webui.config import (
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient:
@ -102,8 +107,7 @@ class ChromaClient:
}
)
return None
except Exception as e:
print(e)
except:
return None
def get(self, collection_name: str) -> Optional[GetResult]:

View File

@ -0,0 +1,274 @@
from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
ELASTICSEARCH_URL,
ELASTICSEARCH_CA_CERTS,
ELASTICSEARCH_API_KEY,
ELASTICSEARCH_USERNAME,
ELASTICSEARCH_PASSWORD,
ELASTICSEARCH_CLOUD_ID,
SSL_ASSERT_FINGERPRINT,
)
class ElasticsearchClient:
"""
Important:
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
an index for each file but store it as a text field, while seperating to different index
baesd on the embedding length.
"""
def __init__(self):
self.index_prefix = "open_webui_collections"
self.client = Elasticsearch(
hosts=[ELASTICSEARCH_URL],
ca_certs=ELASTICSEARCH_CA_CERTS,
api_key=ELASTICSEARCH_API_KEY,
cloud_id=ELASTICSEARCH_CLOUD_ID,
basic_auth=(
(ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
else None
),
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
)
# Status: works
def _get_index_name(self, dimension: int) -> str:
return f"{self.index_prefix}_d{str(dimension)}"
# Status: works
def _scan_result_to_get_result(self, result) -> GetResult:
if not result:
return None
ids = []
documents = []
metadatas = []
for hit in result:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
# Status: works
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
return None
ids = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
# Status: works
def _result_to_search_result(self, result) -> SearchResult:
ids = []
distances = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=[ids],
distances=[distances],
documents=[documents],
metadatas=[metadatas],
)
# Status: works
def _create_index(self, dimension: int):
body = {
"mappings": {
"properties": {
"collection": {"type": "keyword"},
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": True,
"similarity": "cosine",
},
"text": {"type": "text"},
"metadata": {"type": "object"},
}
}
}
self.client.indices.create(index=self._get_index_name(dimension), body=body)
# Status: works
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : min(i + batch_size, len(items))]
# Status: works
def has_collection(self, collection_name) -> bool:
query_body = {"query": {"bool": {"filter": []}}}
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
try:
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
return result.body["count"] > 0
except Exception as e:
return None
# @TODO: Make this delete a collection and not an index
def delete_colleciton(self, collection_name: str):
# TODO: fix this to include the dimension or a * prefix
# delete_collection here means delete a bunch of documents for an index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=self._get_collection_name(collection_name))
# Status: works
def search(
self, collection_name: str, vectors: list[list[float]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {
"bool": {"filter": [{"term": {"collection": collection_name}}]}
},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
},
}
},
}
result = self.client.search(
index=self._get_index_name(len(vectors[0])), body=query
)
return self._result_to_search_result(result)
# Status: only tested halfwat
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
if not self.has_collection(collection_name):
return None
query_body = {
"query": {"bool": {"filter": []}},
"_source": ["text", "metadata"],
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
size = limit if limit else 10
try:
result = self.client.search(
index=f"{self.index_prefix}*",
body=query_body,
size=size,
)
return self._result_to_get_result(result)
except Exception as e:
return None
# Status: works
def _has_index(self, dimension: int):
return self.client.indices.exists(
index=self._get_index_name(dimension=dimension)
)
def get_or_create_index(self, dimension: int):
if not self._has_index(dimension=dimension):
self._create_index(dimension=dimension)
# Status: works
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
"_source": ["text", "metadata"],
}
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
return self._scan_result_to_get_result(results)
# Status: works
def insert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
for item in batch
]
bulk(self.client, actions)
# Status: should work
def upsert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(collection_name, dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
for item in batch
]
self.client.bulk(actions)
# TODO: This currently deletes by * which is not always supported in ElasticSearch.
# Need to read a bit before changing. Also, need to delete from a specific collection
def delete(self, collection_name: str, ids: list[str]):
# Assuming ID is unique across collections and indexes
actions = [
{"delete": {"_index": f"{self.index_prefix}*", "_id": id}} for id in ids
]
self.client.bulk(body=actions)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}*")
for index in indices:
self.client.indices.delete(index=index)

View File

@ -1,20 +1,28 @@
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
import json
import logging
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
MILVUS_URI,
MILVUS_DB,
MILVUS_TOKEN,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient:
def __init__(self):
self.collection_prefix = "open_webui"
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB)
if MILVUS_TOKEN is None:
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
else:
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)
def _result_to_get_result(self, result) -> GetResult:
ids = []
@ -164,7 +172,7 @@ class MilvusClient:
try:
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
print("remaining", remaining)
log.info(f"remaining: {remaining}")
current_fetch = min(
max_limit, remaining
) # Determine how many items to fetch in this iteration
@ -191,10 +199,12 @@ class MilvusClient:
if results_count < current_fetch:
break
print(all_results)
log.debug(all_results)
return self._result_to_get_result([all_results])
except Exception as e:
print(e)
log.exception(
f"Error querying collection {collection_name} with limit {limit}: {e}"
)
return None
def get(self, collection_name: str) -> Optional[GetResult]:

View File

@ -49,7 +49,7 @@ class OpenSearchClient:
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
def _create_index(self, index_name: str, dimension: int):
def _create_index(self, collection_name: str, dimension: int):
body = {
"mappings": {
"properties": {
@ -72,24 +72,28 @@ class OpenSearchClient:
}
}
}
self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
self.client.indices.create(
index=f"{self.index_prefix}_{collection_name}", body=body
)
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
def has_collection(self, index_name: str) -> bool:
def has_collection(self, collection_name: str) -> bool:
# has_collection here means has index.
# We are simply adapting to the norms of the other DBs.
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
return self.client.indices.exists(
index=f"{self.index_prefix}_{collection_name}"
)
def delete_colleciton(self, index_name: str):
def delete_colleciton(self, collection_name: str):
# delete_collection here means delete index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
self.client.indices.delete(index=f"{self.index_prefix}_{collection_name}")
def search(
self, index_name: str, vectors: list[list[float]], limit: int
self, collection_name: str, vectors: list[list[float]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
@ -108,26 +112,55 @@ class OpenSearchClient:
}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
index=f"{self.index_prefix}_{collection_name}", body=query
)
return self._result_to_search_result(result)
def get_or_create_index(self, index_name: str, dimension: int):
if not self.has_index(index_name):
self._create_index(index_name, dimension)
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
if not self.has_collection(collection_name):
return None
def get(self, index_name: str) -> Optional[GetResult]:
query_body = {
"query": {"bool": {"filter": []}},
"_source": ["text", "metadata"],
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
size = limit if limit else 10
try:
result = self.client.search(
index=f"{self.index_prefix}_{collection_name}",
body=query_body,
size=size,
)
return self._result_to_get_result(result)
except Exception as e:
return None
def _create_index_if_not_exists(self, collection_name: str, dimension: int):
if not self.has_index(collection_name):
self._create_index(collection_name, dimension)
def get(self, collection_name: str) -> Optional[GetResult]:
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
index=f"{self.index_prefix}_{collection_name}", body=query
)
return self._result_to_get_result(result)
def insert(self, index_name: str, items: list[VectorItem]):
if not self.has_index(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
def insert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
for batch in self._create_batches(items):
actions = [
@ -145,15 +178,17 @@ class OpenSearchClient:
]
self.client.bulk(actions)
def upsert(self, index_name: str, items: list[VectorItem]):
if not self.has_index(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
for batch in self._create_batches(items):
actions = [
{
"index": {
"_id": item["id"],
"_index": f"{self.index_prefix}_{collection_name}",
"_source": {
"vector": item["vector"],
"text": item["text"],
@ -165,9 +200,9 @@ class OpenSearchClient:
]
self.client.bulk(actions)
def delete(self, index_name: str, ids: list[str]):
def delete(self, collection_name: str, ids: list[str]):
actions = [
{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
{"delete": {"_index": f"{self.index_prefix}_{collection_name}", "_id": id}}
for id in ids
]
self.client.bulk(body=actions)

View File

@ -1,4 +1,5 @@
from typing import Optional, List, Dict, Any
import logging
from sqlalchemy import (
cast,
column,
@ -24,9 +25,14 @@ from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class DocumentChunk(Base):
__tablename__ = "document_chunk"
@ -82,10 +88,10 @@ class PgvectorClient:
)
)
self.session.commit()
print("Initialization complete.")
log.info("Initialization complete.")
except Exception as e:
self.session.rollback()
print(f"Error during initialization: {e}")
log.exception(f"Error during initialization: {e}")
raise
def check_vector_length(self) -> None:
@ -150,12 +156,12 @@ class PgvectorClient:
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
print(
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
print(f"Error during insert: {e}")
log.exception(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@ -184,10 +190,12 @@ class PgvectorClient:
)
self.session.add(new_chunk)
self.session.commit()
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
print(f"Error during upsert: {e}")
log.exception(f"Error during upsert: {e}")
raise
def search(
@ -278,7 +286,7 @@ class PgvectorClient:
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
print(f"Error during search: {e}")
log.exception(f"Error during search: {e}")
return None
def query(
@ -310,7 +318,7 @@ class PgvectorClient:
metadatas=metadatas,
)
except Exception as e:
print(f"Error during query: {e}")
log.exception(f"Error during query: {e}")
return None
def get(
@ -334,7 +342,7 @@ class PgvectorClient:
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
print(f"Error during get: {e}")
log.exception(f"Error during get: {e}")
return None
def delete(
@ -356,22 +364,22 @@ class PgvectorClient:
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
print(f"Deleted {deleted} items from collection '{collection_name}'.")
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
print(f"Error during delete: {e}")
log.exception(f"Error during delete: {e}")
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
print(
log.info(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
except Exception as e:
self.session.rollback()
print(f"Error during reset: {e}")
log.exception(f"Error during reset: {e}")
raise
def close(self) -> None:
@ -387,9 +395,9 @@ class PgvectorClient:
)
return exists
except Exception as e:
print(f"Error checking collection existence: {e}")
log.exception(f"Error checking collection existence: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
self.delete(collection_name)
print(f"Collection '{collection_name}' deleted.")
log.info(f"Collection '{collection_name}' deleted.")

View File

@ -1,4 +1,5 @@
from typing import Optional
import logging
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
@ -6,9 +7,13 @@ from qdrant_client.models import models
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient:
def __init__(self):
@ -49,7 +54,7 @@ class QdrantClient:
),
)
print(f"collection {collection_name_with_prefix} successfully created!")
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(collection_name=collection_name):
@ -120,7 +125,7 @@ class QdrantClient:
)
return self._result_to_get_result(points.points)
except Exception as e:
print(e)
log.exception(f"Error querying a collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:

View File

@ -0,0 +1,65 @@
import logging
from typing import Optional
import requests
import json
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def _parse_response(response):
result = {}
if "data" in response:
data = response["data"]
if "webPages" in data:
webPages = data["webPages"]
if "value" in webPages:
result["webpage"] = [
{
"id": item.get("id", ""),
"name": item.get("name", ""),
"url": item.get("url", ""),
"snippet": item.get("snippet", ""),
"summary": item.get("summary", ""),
"siteName": item.get("siteName", ""),
"siteIcon": item.get("siteIcon", ""),
"datePublished": item.get("datePublished", "")
or item.get("dateLastCrawled", ""),
}
for item in webPages["value"]
]
return result
def search_bocha(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Bocha's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Bocha Search API key
query (str): The query to search for
"""
url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
payload = json.dumps(
{"query": query, "summary": True, "freshness": "noLimit", "count": count}
)
response = requests.post(url, headers=headers, data=payload, timeout=5)
response.raise_for_status()
results = _parse_response(response.json())
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("name"), snippet=result.get("summary")
)
for result in results.get("webpage", [])[:count]
]

View File

@ -32,19 +32,15 @@ def search_duckduckgo(
# Convert the search results into a list
search_results = [r for r in ddgs_gen]
# Create an empty list to store the SearchResult objects
results = []
# Iterate over each search result
for result in search_results:
# Create a SearchResult object and append it to the results list
results.append(
SearchResult(
link=result["href"],
title=result.get("title"),
snippet=result.get("body"),
)
)
if filter_list:
results = get_filtered_results(results, filter_list)
search_results = get_filtered_results(search_results, filter_list)
# Return the list of search results
return results
return [
SearchResult(
link=result["href"],
title=result.get("title"),
snippet=result.get("body"),
)
for result in search_results
]

View File

@ -0,0 +1,76 @@
import logging
from dataclasses import dataclass
from typing import Optional
import requests
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.web.main import SearchResult
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
EXA_API_BASE = "https://api.exa.ai"
@dataclass
class ExaResult:
url: str
title: str
text: str
def search_exa(
api_key: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Exa Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Exa Search API key
query (str): The query to search for
count (int): Number of results to return
filter_list (Optional[list[str]]): List of domains to filter results by
"""
log.info(f"Searching with Exa for query: {query}")
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
payload = {
"query": query,
"numResults": count or 5,
"includeDomains": filter_list,
"contents": {"text": True, "highlights": True},
"type": "auto", # Use the auto search type (keyword or neural)
}
try:
response = requests.post(
f"{EXA_API_BASE}/search", headers=headers, json=payload
)
response.raise_for_status()
data = response.json()
results = []
for result in data["results"]:
results.append(
ExaResult(
url=result["url"],
title=result["title"],
text=result["text"],
)
)
log.info(f"Found {len(results)} results")
return [
SearchResult(
link=result.url,
title=result.title,
snippet=result.text,
)
for result in results
]
except Exception as e:
log.error(f"Error searching Exa: {e}")
return []

View File

@ -17,34 +17,53 @@ def search_google_pse(
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
Handles pagination for counts greater than 10.
Args:
api_key (str): A Programmable Search Engine API key
search_engine_id (str): A Programmable Search Engine ID
query (str): The query to search for
count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
filter_list (Optional[list[str]], optional): A list of keywords to filter out from results. Defaults to None.
Returns:
list[SearchResult]: A list of SearchResult objects.
"""
url = "https://www.googleapis.com/customsearch/v1"
headers = {"Content-Type": "application/json"}
params = {
"cx": search_engine_id,
"q": query,
"key": api_key,
"num": count,
}
all_results = []
start_index = 1 # Google PSE start parameter is 1-based
response = requests.request("GET", url, headers=headers, params=params)
response.raise_for_status()
while count > 0:
num_results_this_page = min(count, 10) # Google PSE max results per page is 10
params = {
"cx": search_engine_id,
"q": query,
"key": api_key,
"num": num_results_this_page,
"start": start_index,
}
response = requests.request("GET", url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("items", [])
if results: # check if results are returned. If not, no more pages to fetch.
all_results.extend(results)
count -= len(
results
) # Decrement count by the number of results fetched in this page.
start_index += 10 # Increment start index for the next page
else:
break # No more results from Google PSE, break the loop
json_response = response.json()
results = json_response.get("items", [])
if filter_list:
results = get_filtered_results(results, filter_list)
all_results = get_filtered_results(all_results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results
for result in all_results
]

View File

@ -20,14 +20,23 @@ def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
list[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
url = str(URL(jina_search_endpoint + query))
response = requests.get(url, headers=headers)
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": api_key,
"X-Retain-Images": "none",
}
payload = {"q": query, "count": count if count <= 10 else 10}
url = str(URL(jina_search_endpoint))
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
for result in data["data"][:count]:
for result in data["data"]:
results.append(
SearchResult(
link=result["url"],

View File

@ -1,3 +1,5 @@
import validators
from typing import Optional
from urllib.parse import urlparse
@ -10,6 +12,8 @@ def get_filtered_results(results, filter_list):
filtered_results = []
for result in results:
url = result.get("url") or result.get("link", "")
if not validators.url(url):
continue
domain = urlparse(url).netloc
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
filtered_results.append(result)

View File

@ -0,0 +1,87 @@
import logging
from typing import Optional, List
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_perplexity(
api_key: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Perplexity API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Perplexity API key
query (str): The query to search for
count (int): Maximum number of results to return
"""
# Handle PersistentConfig object
if hasattr(api_key, "__str__"):
api_key = str(api_key)
try:
url = "https://api.perplexity.ai/chat/completions"
# Create payload for the API call
payload = {
"model": "sonar",
"messages": [
{
"role": "system",
"content": "You are a search assistant. Provide factual information with citations.",
},
{"role": "user", "content": query},
],
"temperature": 0.2, # Lower temperature for more factual responses
"stream": False,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
# Make the API request
response = requests.request("POST", url, json=payload, headers=headers)
# Parse the JSON response
json_response = response.json()
# Extract citations from the response
citations = json_response.get("citations", [])
# Create search results from citations
results = []
for i, citation in enumerate(citations[:count]):
# Extract content from the response to use as snippet
content = ""
if "choices" in json_response and json_response["choices"]:
if i == 0:
content = json_response["choices"][0]["message"]["content"]
result = {"link": citation, "title": f"Source {i+1}", "snippet": content}
results.append(result)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
)
for result in results[:count]
]
except Exception as e:
log.error(f"Error searching with Perplexity API: {e}")
return []

View File

@ -0,0 +1,48 @@
import logging
from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpapi(
api_key: str,
engine: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using serpapi.com's API and return the results as a list of SearchResult objects.
Args:
api_key (str): A serpapi.com API key
query (str): The query to search for
"""
url = "https://serpapi.com/search"
engine = engine or "google"
payload = {"engine": engine, "q": query, "api_key": api_key}
url = f"{url}?{urlencode(payload)}"
response = requests.request("GET", url)
json_response = response.json()
log.info(f"results from serpapi search: {json_response}")
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
)
for result in results[:count]
]

View File

@ -1,4 +1,5 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult
@ -8,7 +9,13 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
def search_tavily(
api_key: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
# **kwargs,
) -> list[SearchResult]:
"""Search using Tavily's Search API and return the results as a list of SearchResult objects.
Args:
@ -20,7 +27,6 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
"""
url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key}
response = requests.post(url, json=data)
response.raise_for_status()

View File

@ -1,19 +1,38 @@
import socket
import urllib.parse
import validators
from typing import Union, Sequence, Iterator
from langchain_community.document_loaders import (
WebBaseLoader,
)
from langchain_core.documents import Document
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
from open_webui.env import SRC_LOG_LEVELS
import asyncio
import logging
import socket
import ssl
import urllib.parse
import urllib.request
from collections import defaultdict
from datetime import datetime, time, timedelta
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
Literal,
)
import aiohttp
import certifi
import validators
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
PLAYWRIGHT_WS_URI,
RAG_WEB_LOADER_ENGINE,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -43,6 +62,17 @@ def validate_url(url: Union[str, Sequence[str]]):
return False
def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
valid_urls = []
for u in url:
try:
if validate_url(u):
valid_urls.append(u)
except ValueError:
continue
return valid_urls
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
@ -54,9 +84,381 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses
def extract_metadata(soup, url):
metadata = {"source": url}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get("content", "No description found.")
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
return metadata
def verify_ssl_cert(url: str) -> bool:
"""Verify SSL certificate for the given URL."""
if not url.startswith("https://"):
return True
try:
hostname = url.split("://")[-1].split("/")[0]
context = ssl.create_default_context(cafile=certifi.where())
with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
s.connect((hostname, 443))
return True
except ssl.SSLError:
return False
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
return False
class SafeFireCrawlLoader(BaseLoader):
def __init__(
self,
web_paths,
verify_ssl: bool = True,
trust_env: bool = False,
requests_per_second: Optional[float] = None,
continue_on_failure: bool = True,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
mode: Literal["crawl", "scrape", "map"] = "crawl",
proxy: Optional[Dict[str, str]] = None,
params: Optional[Dict] = None,
):
"""Concurrent document loader for FireCrawl operations.
Executes multiple FireCrawlLoader instances concurrently using thread pooling
to improve bulk processing efficiency.
Args:
web_paths: List of URLs/paths to process.
verify_ssl: If True, verify SSL certificates.
trust_env: If True, use proxy settings from environment variables.
requests_per_second: Number of requests per second to limit to.
continue_on_failure (bool): If True, continue loading other URLs on failure.
api_key: API key for FireCrawl service. Defaults to None
(uses FIRE_CRAWL_API_KEY environment variable if not provided).
api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
mode: Operation mode selection:
- 'crawl': Website crawling mode (default)
- 'scrape': Direct page scraping
- 'map': Site map generation
proxy: Proxy override settings for the FireCrawl API.
params: The parameters to pass to the Firecrawl API.
Examples include crawlerOptions.
For more details, visit: https://github.com/mendableai/firecrawl-py
"""
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
self.web_paths = web_paths
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
self.last_request_time = None
self.trust_env = trust_env
self.continue_on_failure = continue_on_failure
self.api_key = api_key
self.api_url = api_url
self.mode = mode
self.params = params
def lazy_load(self) -> Iterator[Document]:
"""Load documents concurrently using FireCrawl."""
for url in self.web_paths:
try:
self._safe_process_url_sync(url)
loader = FireCrawlLoader(
url=url,
api_key=self.api_key,
api_url=self.api_url,
mode=self.mode,
params=self.params,
)
yield from loader.lazy_load()
except Exception as e:
if self.continue_on_failure:
log.exception(e, "Error loading %s", url)
continue
raise e
async def alazy_load(self):
"""Async version of lazy_load."""
for url in self.web_paths:
try:
await self._safe_process_url(url)
loader = FireCrawlLoader(
url=url,
api_key=self.api_key,
api_url=self.api_url,
mode=self.mode,
params=self.params,
)
async for document in loader.alazy_load():
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(e, "Error loading %s", url)
continue
raise e
def _verify_ssl_cert(self, url: str) -> bool:
return verify_ssl_cert(url)
async def _wait_for_rate_limit(self):
"""Wait to respect the rate limit if specified."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
await asyncio.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
def _sync_wait_for_rate_limit(self):
"""Synchronous version of rate limit wait."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
time.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
async def _safe_process_url(self, url: str) -> bool:
"""Perform safety checks before processing a URL."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
await self._wait_for_rate_limit()
return True
def _safe_process_url_sync(self, url: str) -> bool:
"""Synchronous version of safety checks."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
self._sync_wait_for_rate_limit()
return True
class SafePlaywrightURLLoader(PlaywrightURLLoader):
"""Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
Attributes:
web_paths (List[str]): List of URLs to load.
verify_ssl (bool): If True, verify SSL certificates.
trust_env (bool): If True, use proxy settings from environment variables.
requests_per_second (Optional[float]): Number of requests per second to limit to.
continue_on_failure (bool): If True, continue loading other URLs on failure.
headless (bool): If True, the browser will run in headless mode.
proxy (dict): Proxy override settings for the Playwright session.
playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
"""
def __init__(
self,
web_paths: List[str],
verify_ssl: bool = True,
trust_env: bool = False,
requests_per_second: Optional[float] = None,
continue_on_failure: bool = True,
headless: bool = True,
remove_selectors: Optional[List[str]] = None,
proxy: Optional[Dict[str, str]] = None,
playwright_ws_url: Optional[str] = None,
):
"""Initialize with additional safety parameters and remote browser support."""
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
# We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
super().__init__(
urls=web_paths,
continue_on_failure=continue_on_failure,
headless=headless if playwright_ws_url is None else False,
remove_selectors=remove_selectors,
proxy=proxy,
)
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
self.last_request_time = None
self.playwright_ws_url = playwright_ws_url
self.trust_env = trust_env
def lazy_load(self) -> Iterator[Document]:
"""Safely load URLs synchronously with support for remote browser."""
from playwright.sync_api import sync_playwright
with sync_playwright() as p:
# Use remote browser if ws_endpoint is provided, otherwise use local browser
if self.playwright_ws_url:
browser = p.chromium.connect(self.playwright_ws_url)
else:
browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)
for url in self.urls:
try:
self._safe_process_url_sync(url)
page = browser.new_page()
response = page.goto(url)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
text = self.evaluator.evaluate(page, browser, response)
metadata = {"source": url}
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(e, "Error loading %s", url)
continue
raise e
browser.close()
async def alazy_load(self) -> AsyncIterator[Document]:
"""Safely load URLs asynchronously with support for remote browser."""
from playwright.async_api import async_playwright
async with async_playwright() as p:
# Use remote browser if ws_endpoint is provided, otherwise use local browser
if self.playwright_ws_url:
browser = await p.chromium.connect(self.playwright_ws_url)
else:
browser = await p.chromium.launch(
headless=self.headless, proxy=self.proxy
)
for url in self.urls:
try:
await self._safe_process_url(url)
page = await browser.new_page()
response = await page.goto(url)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
text = await self.evaluator.evaluate_async(page, browser, response)
metadata = {"source": url}
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(e, "Error loading %s", url)
continue
raise e
await browser.close()
def _verify_ssl_cert(self, url: str) -> bool:
return verify_ssl_cert(url)
async def _wait_for_rate_limit(self):
"""Wait to respect the rate limit if specified."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
await asyncio.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
def _sync_wait_for_rate_limit(self):
"""Synchronous version of rate limit wait."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
time.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
async def _safe_process_url(self, url: str) -> bool:
"""Perform safety checks before processing a URL."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
await self._wait_for_rate_limit()
return True
def _safe_process_url_sync(self, url: str) -> bool:
"""Synchronous version of safety checks."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
self._sync_wait_for_rate_limit()
return True
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
def __init__(self, trust_env: bool = False, *args, **kwargs):
"""Initialize SafeWebBaseLoader
Args:
trust_env (bool, optional): set to True if using proxy to make web requests, for example
using http(s)_proxy environment variables. Defaults to False.
"""
super().__init__(*args, **kwargs)
self.trust_env = trust_env
async def _fetch(
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
) -> str:
async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
for i in range(retries):
try:
kwargs: Dict = dict(
headers=self.session.headers,
cookies=self.session.cookies.get_dict(),
)
if not self.session.verify:
kwargs["ssl"] = False
async with session.get(
url, **(self.requests_kwargs | kwargs)
) as response:
if self.raise_for_status:
response.raise_for_status()
return await response.text()
except aiohttp.ClientConnectionError as e:
if i == retries - 1:
raise
else:
log.warning(
f"Error fetching {url} with attempt "
f"{i + 1}/{retries}: {e}. Retrying..."
)
await asyncio.sleep(cooldown * backoff**i)
raise ValueError("retry count exceeded")
def _unpack_fetch_results(
self, results: Any, urls: List[str], parser: Union[str, None] = None
) -> List[Any]:
"""Unpack fetch results into BeautifulSoup objects."""
from bs4 import BeautifulSoup
final_results = []
for i, result in enumerate(results):
url = urls[i]
if parser is None:
if url.endswith(".xml"):
parser = "xml"
else:
parser = self.default_parser
self._check_parser(parser)
final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
return final_results
async def ascrape_all(
self, urls: List[str], parser: Union[str, None] = None
) -> List[Any]:
"""Async fetch all urls, then return soups for all results."""
results = await self.fetch_all(urls)
return self._unpack_fetch_results(results, urls, parser=parser)
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path with error handling."""
for path in self.web_paths:
@ -65,33 +467,72 @@ class SafeWebBaseLoader(WebBaseLoader):
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
metadata = extract_metadata(soup, path)
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.error(f"Error loading {path}: {e}")
log.exception(e, "Error loading %s", path)
async def alazy_load(self) -> AsyncIterator[Document]:
"""Async lazy load text from the url(s) in web_path."""
results = await self.ascrape_all(self.web_paths)
for path, soup in zip(self.web_paths, results):
text = soup.get_text(**self.bs_get_text_kwargs)
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
yield Document(page_content=text, metadata=metadata)
async def aload(self) -> list[Document]:
"""Load data into Document objects."""
return [document async for document in self.alazy_load()]
RAG_WEB_LOADER_ENGINES = defaultdict(lambda: SafeWebBaseLoader)
RAG_WEB_LOADER_ENGINES["playwright"] = SafePlaywrightURLLoader
RAG_WEB_LOADER_ENGINES["safe_web"] = SafeWebBaseLoader
RAG_WEB_LOADER_ENGINES["firecrawl"] = SafeFireCrawlLoader
def get_web_loader(
urls: Union[str, Sequence[str]],
verify_ssl: bool = True,
requests_per_second: int = 2,
trust_env: bool = False,
):
# Check if the URL is valid
if not validate_url(urls):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return SafeWebBaseLoader(
urls,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,
# Check if the URLs are valid
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
web_loader_args = {
"web_paths": safe_urls,
"verify_ssl": verify_ssl,
"requests_per_second": requests_per_second,
"continue_on_failure": True,
"trust_env": trust_env,
}
if PLAYWRIGHT_WS_URI.value:
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URI.value
if RAG_WEB_LOADER_ENGINE.value == "firecrawl":
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
# Create the appropriate WebLoader based on the configuration
WebLoaderClass = RAG_WEB_LOADER_ENGINES[RAG_WEB_LOADER_ENGINE.value]
web_loader = WebLoaderClass(**web_loader_args)
log.debug(
"Using RAG_WEB_LOADER_ENGINE %s for %s URLs",
web_loader.__class__.__name__,
len(safe_urls),
)
return web_loader

View File

@ -11,6 +11,7 @@ from pydub.silence import split_on_silence
import aiohttp
import aiofiles
import requests
import mimetypes
from fastapi import (
Depends,
@ -36,6 +37,7 @@ from open_webui.config import (
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
ENV,
SRC_LOG_LEVELS,
DEVICE_TYPE,
@ -52,7 +54,7 @@ MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@ -69,7 +71,7 @@ from pydub.utils import mediainfo
def is_mp4_audio(file_path):
"""Check if the given file is an MP4 audio file."""
if not os.path.isfile(file_path):
print(f"File not found: {file_path}")
log.error(f"File not found: {file_path}")
return False
info = mediainfo(file_path)
@ -86,7 +88,7 @@ def convert_mp4_to_wav(file_path, output_path):
"""Convert MP4 audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format="mp4")
audio.export(output_path, format="wav")
print(f"Converted {file_path} to {output_path}")
log.info(f"Converted {file_path} to {output_path}")
def set_faster_whisper_model(model: str, auto_update: bool = False):
@ -138,6 +140,7 @@ class STTConfigForm(BaseModel):
ENGINE: str
MODEL: str
WHISPER_MODEL: str
DEEPGRAM_API_KEY: str
class AudioConfigUpdateForm(BaseModel):
@ -165,6 +168,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
},
}
@ -190,6 +194,7 @@ async def update_audio_config(
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
request.app.state.config.STT_MODEL = form_data.stt.MODEL
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model(
@ -214,6 +219,7 @@ async def update_audio_config(
"ENGINE": request.app.state.config.STT_ENGINE,
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
},
}
@ -260,8 +266,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
payload["model"] = request.app.state.config.TTS_MODEL
try:
# print(payload)
async with aiohttp.ClientSession() as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload,
@ -318,7 +326,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
)
try:
async with aiohttp.ClientSession() as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
json={
@ -375,7 +386,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
<voice name="{language}">{payload["input"]}</voice>
</speak>"""
async with aiohttp.ClientSession() as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
headers={
@ -453,7 +467,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
def transcribe(request: Request, file_path):
print("transcribe", file_path)
log.info(f"transcribe: {file_path}")
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
@ -521,6 +535,69 @@ def transcribe(request: Request, file_path):
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
elif request.app.state.config.STT_ENGINE == "deepgram":
try:
# Determine the MIME type of the file
mime, _ = mimetypes.guess_type(file_path)
if not mime:
mime = "audio/wav" # fallback to wav if undetectable
# Read the audio file
with open(file_path, "rb") as f:
file_data = f.read()
# Build headers and parameters
headers = {
"Authorization": f"Token {request.app.state.config.DEEPGRAM_API_KEY}",
"Content-Type": mime,
}
# Add model if specified
params = {}
if request.app.state.config.STT_MODEL:
params["model"] = request.app.state.config.STT_MODEL
# Make request to Deepgram API
r = requests.post(
"https://api.deepgram.com/v1/listen",
headers=headers,
params=params,
data=file_data,
)
r.raise_for_status()
response_data = r.json()
# Extract transcript from Deepgram response
try:
transcript = response_data["results"]["channels"][0]["alternatives"][
0
].get("transcript", "")
except (KeyError, IndexError) as e:
log.error(f"Malformed response from Deepgram: {str(e)}")
raise Exception(
"Failed to parse Deepgram response - unexpected response format"
)
data = {"text": transcript.strip()}
# Save transcript
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
return data
except Exception as e:
log.exception(e)
detail = None
if r is not None:
try:
res = r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:
@ -602,7 +679,22 @@ def transcription(
def get_available_models(request: Request) -> list[dict]:
available_models = []
if request.app.state.config.TTS_ENGINE == "openai":
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
# Use custom endpoint if not using the official OpenAI API URL
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
"https://api.openai.com"
):
try:
response = requests.get(
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
)
response.raise_for_status()
data = response.json()
available_models = data.get("models", [])
except Exception as e:
log.error(f"Error fetching models from custom endpoint: {str(e)}")
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
else:
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
response = requests.get(
@ -633,14 +725,37 @@ def get_available_voices(request) -> dict:
"""Returns {voice_id: voice_name} dict"""
available_voices = {}
if request.app.state.config.TTS_ENGINE == "openai":
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
# Use custom endpoint if not using the official OpenAI API URL
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
"https://api.openai.com"
):
try:
response = requests.get(
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
)
response.raise_for_status()
data = response.json()
voices_list = data.get("voices", [])
available_voices = {voice["id"]: voice["name"] for voice in voices_list}
except Exception as e:
log.error(f"Error fetching voices from custom endpoint: {str(e)}")
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
else:
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
available_voices = get_elevenlabs_voices(

View File

@ -25,16 +25,13 @@ from open_webui.env import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
SRC_LOG_LEVELS,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse, Response
from open_webui.config import (
OPENID_PROVIDER_URL,
ENABLE_OAUTH_SIGNUP,
)
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
from pydantic import BaseModel
from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.auth import (
@ -51,8 +48,10 @@ from open_webui.utils.access_control import get_permissions
from typing import Optional, List
from ssl import CERT_REQUIRED, PROTOCOL_TLS
from ldap3 import Server, Connection, NONE, Tls
from ldap3.utils.conv import escape_filter_chars
if ENABLE_LDAP.value:
from ldap3 import Server, Connection, NONE, Tls
from ldap3.utils.conv import escape_filter_chars
router = APIRouter()
@ -95,8 +94,8 @@ async def get_session_user(
value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
)
user_permissions = get_permissions(
@ -164,7 +163,7 @@ async def update_password(
############################
# LDAP Authentication
############################
@router.post("/ldap", response_model=SigninResponse)
@router.post("/ldap", response_model=SessionUserResponse)
async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
ENABLE_LDAP = request.app.state.config.ENABLE_LDAP
LDAP_SERVER_LABEL = request.app.state.config.LDAP_SERVER_LABEL
@ -231,9 +230,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
entry = connection_app.entries[0]
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
mail = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
if not mail or mail == "" or mail == "[]":
raise HTTPException(400, f"User {form_data.user} does not have mail.")
email = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
if not email or email == "" or email == "[]":
raise HTTPException(400, f"User {form_data.user} does not have email.")
else:
email = email.lower()
cn = str(entry["cn"])
user_dn = entry.entry_dn
@ -248,17 +250,22 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
if not connection_user.bind():
raise HTTPException(400, f"Authentication failed for {form_data.user}")
user = Users.get_user_by_email(mail)
user = Users.get_user_by_email(email)
if not user:
try:
user_count = Users.get_num_users()
role = (
"admin"
if Users.get_num_users() == 0
if user_count == 0
else request.app.state.config.DEFAULT_USER_ROLE
)
user = Auths.insert_new_auth(
email=mail, password=str(uuid.uuid4()), name=cn, role=role
email=email,
password=str(uuid.uuid4()),
name=cn,
role=role,
)
if not user:
@ -271,7 +278,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
user = Auths.authenticate_user_by_trusted_header(mail)
user = Auths.authenticate_user_by_trusted_header(email)
if user:
token = create_token(
@ -288,6 +295,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
httponly=True, # Ensures the cookie is not accessible via JavaScript
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return {
"token": token,
"token_type": "Bearer",
@ -296,6 +307,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
"name": user.name,
"role": user.role,
"profile_image_url": user.profile_image_url,
"permissions": user_permissions,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
@ -378,8 +390,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
)
user_permissions = get_permissions(
@ -408,6 +420,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm):
if WEBUI_AUTH:
if (
not request.app.state.config.ENABLE_SIGNUP
@ -422,6 +435,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
user_count = Users.get_num_users()
if not validate_email_format(form_data.email.lower()):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@ -432,12 +446,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
try:
role = (
"admin"
if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
)
if Users.get_num_users() == 0:
if user_count == 0:
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
@ -473,12 +485,13 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
)
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.WEBUI_NAME,
request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
@ -525,7 +538,8 @@ async def signout(request: Request, response: Response):
if logout_url:
response.delete_cookie("oauth_id_token")
return RedirectResponse(
url=f"{logout_url}?id_token_hint={oauth_id_token}"
headers=response.headers,
url=f"{logout_url}?id_token_hint={oauth_id_token}",
)
else:
raise HTTPException(
@ -591,7 +605,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None
print(admin_email, admin_name)
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
if admin_email:
admin = Users.get_user_by_email(admin_email)

View File

@ -192,7 +192,7 @@ async def get_channel_messages(
############################
async def send_notification(webui_url, channel, message, active_user_ids):
async def send_notification(name, webui_url, channel, message, active_user_ids):
users = get_users_with_access("read", channel.access_control)
for user in users:
@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):
if webhook_url:
post_webhook(
name,
webhook_url,
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
{
@ -302,6 +303,7 @@ async def post_new_message(
background_tasks.add_task(
send_notification,
request.app.state.WEBUI_NAME,
request.app.state.config.WEBUI_URL,
channel,
message,

View File

@ -444,15 +444,21 @@ async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
############################
class CloneForm(BaseModel):
title: Optional[str] = None
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
async def clone_chat_by_id(
form_data: CloneForm, id: str, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
updated_chat = {
**chat.chat,
"originalChatId": chat.id,
"branchPointMessageId": chat.chat["history"]["currentId"],
"title": f"Clone of {chat.title}",
"title": form_data.title if form_data.title else f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))

View File

@ -36,6 +36,140 @@ async def export_config(user=Depends(get_admin_user)):
return get_config()
############################
# Direct Connections Config
############################
class DirectConnectionsConfigForm(BaseModel):
ENABLE_DIRECT_CONNECTIONS: bool
@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
}
@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
async def set_direct_connections_config(
request: Request,
form_data: DirectConnectionsConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
form_data.ENABLE_DIRECT_CONNECTIONS
)
return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
}
############################
# CodeInterpreterConfig
############################
class CodeInterpreterConfigForm(BaseModel):
CODE_EXECUTION_ENGINE: str
CODE_EXECUTION_JUPYTER_URL: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
ENABLE_CODE_INTERPRETER: bool
CODE_INTERPRETER_ENGINE: str
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
CODE_INTERPRETER_JUPYTER_URL: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
return {
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
}
@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
async def set_code_execution_config(
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
form_data.CODE_EXECUTION_JUPYTER_URL
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
form_data.CODE_EXECUTION_JUPYTER_AUTH
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
)
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
)
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
form_data.CODE_INTERPRETER_JUPYTER_URL
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
form_data.CODE_INTERPRETER_JUPYTER_AUTH
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
)
return {
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
"CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
}
############################
# SetDefaultModels
############################

View File

@ -3,30 +3,23 @@ import os
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes
from urllib.parse import quote
from open_webui.storage.provider import Storage
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.routers.retrieval import process_file, ProcessFileForm
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.routers.audio import transcribe
from open_webui.storage.provider import Storage
from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -41,7 +34,10 @@ router = APIRouter()
@router.post("/", response_model=FileModelResponse)
def upload_file(
request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
file_metadata: dict = {},
):
log.info(f"file.content_type: {file.content_type}")
try:
@ -65,13 +61,29 @@ def upload_file(
"name": name,
"content_type": file.content_type,
"size": len(contents),
"data": file_metadata,
},
}
),
)
try:
process_file(request, ProcessFileForm(file_id=id))
if file.content_type in [
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/x-m4a",
]:
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path)
process_file(
request,
ProcessFileForm(file_id=id, content=result.get("text", "")),
user=user,
)
else:
process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
@ -126,7 +138,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
Storage.delete_all_files()
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")
log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@ -193,7 +205,9 @@ async def update_file_data_content_by_id(
if file and (file.user_id == user.id or user.role == "admin"):
try:
process_file(
request, ProcessFileForm(file_id=id, content=form_data.content)
request,
ProcessFileForm(file_id=id, content=form_data.content),
user=user,
)
file = Files.get_file_by_id(id=id)
except Exception as e:
@ -227,17 +241,24 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename) # RFC5987 encoding
content_type = file.meta.get("content_type")
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename)
headers = {}
if file.meta.get("content_type") not in [
"application/pdf",
"text/plain",
]:
headers = {
**headers,
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
}
return FileResponse(file_path, headers=headers)
if content_type == "application/pdf" or filename.lower().endswith(
".pdf"
):
headers["Content-Disposition"] = (
f"inline; filename*=UTF-8''{encoded_filename}"
)
content_type = "application/pdf"
elif content_type != "text/plain":
headers["Content-Disposition"] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
return FileResponse(file_path, headers=headers, media_type=content_type)
else:
raise HTTPException(
@ -246,7 +267,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@ -268,7 +289,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
log.info(f"file_path: {file_path}")
return FileResponse(file_path)
else:
raise HTTPException(
@ -277,7 +298,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@ -353,7 +374,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
Storage.delete_file(file.path)
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")
log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),

View File

@ -1,4 +1,5 @@
import os
import logging
from pathlib import Path
from typing import Optional
@ -13,6 +14,11 @@ from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@ -68,7 +74,7 @@ async def create_new_function(
function = Functions.insert_new_function(user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir = CACHE_DIR / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
if function:
@ -79,7 +85,7 @@ async def create_new_function(
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
)
except Exception as e:
print(e)
log.exception(f"Failed to create a new function: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@ -183,7 +189,7 @@ async def update_function_by_id(
FUNCTIONS[id] = function_module
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated)
log.debug(updated)
function = Functions.update_function_by_id(id, updated)
@ -299,7 +305,7 @@ async def update_function_valves_by_id(
Functions.update_function_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Error updating function values by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@ -388,7 +394,7 @@ async def update_function_user_valves_by_id(
)
return user_valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Error updating function user valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),

16
backend/open_webui/routers/groups.py Normal file → Executable file
View File

@ -1,7 +1,7 @@
import os
from pathlib import Path
from typing import Optional
import logging
from open_webui.models.users import Users
from open_webui.models.groups import (
@ -14,7 +14,13 @@ from open_webui.models.groups import (
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[GroupResponse])
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
try:
group = Groups.insert_new_group(user.id, form_data)
if group:
@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
)
except Exception as e:
print(e)
log.exception(f"Error creating a new group: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@ -94,7 +100,7 @@ async def update_group_by_id(
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
)
except Exception as e:
print(e)
log.exception(f"Error updating group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
)
except Exception as e:
print(e)
log.exception(f"Error deleting group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),

View File

@ -1,37 +1,31 @@
import asyncio
import base64
import io
import json
import logging
import mimetypes
import re
import uuid
from pathlib import Path
from typing import Optional
import requests
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm,
ComfyUIWorkflow,
comfyui_generate_image,
)
from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@ -61,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
"gemini": {
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
},
}
@ -84,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
COMFYUI_WORKFLOW_NODES: list[dict]
class GeminiConfigForm(BaseModel):
GEMINI_API_BASE_URL: str
GEMINI_API_KEY: str
class ConfigForm(BaseModel):
enabled: bool
engine: str
@ -91,6 +94,7 @@ class ConfigForm(BaseModel):
openai: OpenAIConfigForm
automatic1111: Automatic1111ConfigForm
comfyui: ComfyUIConfigForm
gemini: GeminiConfigForm
@router.post("/config/update")
@ -109,6 +113,11 @@ async def update_config(
)
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
form_data.gemini.GEMINI_API_BASE_URL
)
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
request.app.state.config.AUTOMATIC1111_BASE_URL = (
form_data.automatic1111.AUTOMATIC1111_BASE_URL
)
@ -135,6 +144,8 @@ async def update_config(
request.app.state.config.COMFYUI_BASE_URL = (
form_data.comfyui.COMFYUI_BASE_URL.strip("/")
)
request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
form_data.comfyui.COMFYUI_WORKFLOW_NODES
@ -161,6 +172,10 @@ async def update_config(
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
"gemini": {
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
},
}
@ -190,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
headers = None
if request.app.state.config.COMFYUI_API_KEY:
headers = {
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
try:
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
headers=headers,
)
r.raise_for_status()
return True
@ -230,6 +253,12 @@ def get_image_model(request):
if request.app.state.config.IMAGE_GENERATION_MODEL
else "dall-e-2"
)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return (
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL
else "imagen-3.0-generate-002"
)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
return (
request.app.state.config.IMAGE_GENERATION_MODEL
@ -271,7 +300,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
async def update_image_config(
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
set_image_model(request, form_data.MODEL)
pattern = r"^\d+x\d+$"
@ -306,6 +334,10 @@ def get_models(request: Request, user=Depends(get_verified_user)):
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return [
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
# TODO - get models from comfyui
headers = {
@ -329,7 +361,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
if model_node_id:
model_list_key = None
print(workflow[model_node_id]["class_type"])
log.info(workflow[model_node_id]["class_type"])
for key in info[workflow[model_node_id]["class_type"]]["input"][
"required"
]:
@ -383,40 +415,22 @@ class GenerateImageForm(BaseModel):
negative_prompt: Optional[str] = None
def save_b64_image(b64_str):
def load_b64_image_data(b64_str):
try:
image_id = str(uuid.uuid4())
if "," in b64_str:
header, encoded = b64_str.split(",", 1)
mime_type = header.split(";")[0]
img_data = base64.b64decode(encoded)
image_format = mimetypes.guess_extension(mime_type)
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
else:
image_filename = f"{image_id}.png"
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
mime_type = "image/png"
img_data = base64.b64decode(b64_str)
# Write the image data to a file
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
return img_data, mime_type
except Exception as e:
log.exception(f"Error saving image: {e}")
log.exception(f"Error loading image data: {e}")
return None
def save_url_image(url, headers=None):
image_id = str(uuid.uuid4())
def load_url_image_data(url, headers=None):
try:
if headers:
r = requests.get(url, headers=headers)
@ -426,18 +440,7 @@ def save_url_image(url, headers=None):
r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type)
if not image_format:
raise ValueError("Could not determine image type from MIME type")
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
with open(file_path, "wb") as image_file:
for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_filename
return r.content, mime_type
else:
log.error("Url does not point to an image.")
return None
@ -447,6 +450,20 @@ def save_url_image(url, headers=None):
return None
def upload_image(request, image_metadata, image_data, content_type, user):
image_format = mimetypes.guess_extension(content_type)
file = UploadFile(
file=io.BytesIO(image_data),
filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
headers={
"content-type": content_type,
},
)
file_item = upload_file(request, file, user, file_metadata=image_metadata)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url
@router.post("/generations")
async def image_generations(
request: Request,
@ -500,12 +517,49 @@ async def image_generations(
images = []
for image in res["data"]:
image_filename = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
if "url" in image:
image_data, content_type = load_url_image_data(
image["url"], headers
)
else:
image_data, content_type = load_b64_image_data(image["b64_json"])
with open(file_body_path, "w") as f:
json.dump(data, f)
url = upload_image(request, data, image_data, content_type, user)
images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
headers = {}
headers["Content-Type"] = "application/json"
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
model = get_image_model(request)
data = {
"instances": {"prompt": form_data.prompt},
"parameters": {
"sampleCount": form_data.n,
"outputOptions": {"mimeType": "image/png"},
},
}
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
json=data,
headers=headers,
)
r.raise_for_status()
res = r.json()
images = []
for image in res["predictions"]:
image_data, content_type = load_b64_image_data(
image["bytesBase64Encoded"]
)
url = upload_image(request, data, image_data, content_type, user)
images.append({"url": url})
return images
@ -552,14 +606,15 @@ async def image_generations(
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
image_filename = save_url_image(image["url"], headers)
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump(form_data.model_dump(exclude_none=True), f)
log.debug(f"images: {images}")
image_data, content_type = load_url_image_data(image["url"], headers)
url = upload_image(
request,
form_data.model_dump(exclude_none=True),
image_data,
content_type,
user,
)
images.append({"url": url})
return images
elif (
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
@ -604,13 +659,15 @@ async def image_generations(
images = []
for image in res["images"]:
image_filename = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump({**data, "info": res["info"]}, f)
image_data, content_type = load_b64_image_data(image)
url = upload_image(
request,
{**data, "info": res["info"]},
image_data,
content_type,
user,
)
images.append({"url": url})
return images
except Exception as e:
error = e

View File

@ -264,7 +264,11 @@ def add_file_to_knowledge_by_id(
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@ -285,7 +289,9 @@ def add_file_to_knowledge_by_id(
# Add content to the vector database
try:
process_file(
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
request,
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
user=user,
)
except Exception as e:
log.debug(e)
@ -342,7 +348,12 @@ def update_file_from_knowledge_by_id(
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@ -363,7 +374,9 @@ def update_file_from_knowledge_by_id(
# Add content to the vector database
try:
process_file(
request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
request,
ProcessFileForm(file_id=form_data.file_id, collection_name=id),
user=user,
)
except Exception as e:
raise HTTPException(
@ -406,7 +419,11 @@ def remove_file_from_knowledge_by_id(
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@ -429,10 +446,6 @@ def remove_file_from_knowledge_by_id(
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
# Delete physical file
if file.path:
Storage.delete_file(file.path)
# Delete file from database
Files.delete_file_by_id(form_data.file_id)
@ -484,7 +497,11 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@ -543,7 +560,11 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@ -582,14 +603,18 @@ def add_files_to_knowledge_batch(
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
# Get files content
print(f"files/batch/add - {len(form_data)} files")
log.info(f"files/batch/add - {len(form_data)} files")
files: List[FileModel] = []
for form in form_data:
file = Files.get_file_by_id(form.file_id)

View File

@ -57,7 +57,7 @@ async def add_memory(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
"metadata": {"created_at": memory.created_at},
}
],
@ -82,7 +82,7 @@ async def query_memory(
):
results = VECTOR_DB_CLIENT.search(
collection_name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
limit=form_data.k,
)
@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,
@ -160,7 +160,9 @@ async def update_memory_by_id(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user
),
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,

View File

@ -183,7 +183,11 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.NOT_FOUND,
)
if model.user_id != user.id and user.role != "admin":
if (
user.role != "admin"
and model.user_id != user.id
and not has_access(user.id, "write", model.access_control)
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,

View File

@ -11,11 +11,14 @@ import re
import time
from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
from open_webui.models.users import UserModel
from open_webui.env import (
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from fastapi import (
Depends,
@ -28,7 +31,7 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, validator
from starlette.background import BackgroundTask
@ -52,7 +55,7 @@ from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.constants import ERROR_MESSAGES
@ -68,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
##########################################
async def send_get_request(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
url,
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) as response:
return await response.json()
except Exception as e:
@ -98,6 +115,7 @@ async def send_post_request(
stream: bool = True,
key: Optional[str] = None,
content_type: Optional[str] = None,
user: UserModel = None,
):
r = None
@ -112,6 +130,16 @@ async def send_post_request(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@ -188,12 +216,24 @@ async def verify_connection(
key = form_data.key
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
) as session:
try:
async with session.get(
f"{url}/api/version",
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) as r:
if r.status != 200:
detail = f"HTTP Error: {r.status}"
@ -256,7 +296,7 @@ async def update_config(
@cached(ttl=3)
async def get_all_models(request: Request):
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = []
@ -264,7 +304,7 @@ async def get_all_models(request: Request):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
):
request_tasks.append(send_get_request(f"{url}/api/tags"))
request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
@ -277,7 +317,9 @@ async def get_all_models(request: Request):
key = api_config.get("key", None)
if enable:
request_tasks.append(send_get_request(f"{url}/api/tags", key))
request_tasks.append(
send_get_request(f"{url}/api/tags", key, user=user)
)
else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
@ -362,7 +404,7 @@ async def get_ollama_tags(
models = []
if url_idx is None:
models = await get_all_models(request)
models = await get_all_models(request, user=user)
else:
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@ -372,7 +414,19 @@ async def get_ollama_tags(
r = requests.request(
method="GET",
url=f"{url}/api/tags",
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@ -395,7 +449,7 @@ async def get_ollama_tags(
)
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
models["models"] = get_filtered_models(models, user)
models["models"] = await get_filtered_models(models, user)
return models
@ -479,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u
url, {}
), # Legacy support
).get("key", None),
user=user,
)
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
]
@ -511,6 +566,7 @@ async def pull_model(
url=f"{url}/api/pull",
payload=json.dumps(payload),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -529,7 +585,7 @@ async def push_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name in models:
@ -547,6 +603,7 @@ async def push_model(
url=f"{url}/api/push",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -573,6 +630,7 @@ async def create_model(
url=f"{url}/api/create",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -590,7 +648,7 @@ async def copy_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.source in models:
@ -611,6 +669,16 @@ async def copy_model(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@ -645,7 +713,7 @@ async def delete_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name in models:
@ -667,6 +735,16 @@ async def delete_model(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@ -695,7 +773,7 @@ async def delete_model(
async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name not in models:
@ -716,6 +794,16 @@ async def show_model_info(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@ -759,7 +847,7 @@ async def embed(
log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@ -785,6 +873,16 @@ async def embed(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@ -828,7 +926,7 @@ async def embeddings(
log.info(f"generate_ollama_embeddings {form_data}")
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@ -854,6 +952,16 @@ async def embeddings(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@ -903,7 +1011,7 @@ async def generate_completion(
user=Depends(get_verified_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@ -933,23 +1041,39 @@ async def generate_completion(
url=f"{url}/api/generate",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
class ChatMessage(BaseModel):
role: str
content: str
content: Optional[str] = None
tool_calls: Optional[list[dict]] = None
images: Optional[list[str]] = None
@validator("content", pre=True)
@classmethod
def check_at_least_one_field(cls, field_value, values, **kwargs):
# Raise an error if both 'content' and 'tool_calls' are None
if field_value is None and (
"tool_calls" not in values or values["tool_calls"] is None
):
raise ValueError(
"At least one of 'content' or 'tool_calls' must be provided"
)
return field_value
class GenerateChatCompletionForm(BaseModel):
model: str
messages: list[ChatMessage]
format: Optional[dict] = None
format: Optional[Union[dict, str]] = None
options: Optional[dict] = None
template: Optional[str] = None
stream: Optional[bool] = True
keep_alive: Optional[Union[int, str]] = None
tools: Optional[list[dict]] = None
async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
@ -977,6 +1101,7 @@ async def generate_chat_completion(
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
metadata = form_data.pop("metadata", None)
try:
form_data = GenerateChatCompletionForm(**form_data)
except Exception as e:
@ -1006,7 +1131,7 @@ async def generate_chat_completion(
payload["options"] = apply_model_params_to_body_ollama(
params, payload["options"]
)
payload = apply_model_system_prompt_to_body(params, payload, user)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
@ -1046,6 +1171,7 @@ async def generate_chat_completion(
stream=form_data.stream,
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
content_type="application/x-ndjson",
user=user,
)
@ -1148,6 +1274,7 @@ async def generate_openai_completion(
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -1159,6 +1286,8 @@ async def generate_openai_chat_completion(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
metadata = form_data.pop("metadata", None)
try:
completion_form = OpenAIChatCompletionForm(**form_data)
except Exception as e:
@ -1185,7 +1314,7 @@ async def generate_openai_chat_completion(
if params:
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
# Check if user has access to the model
if user.role == "user":
@ -1224,6 +1353,7 @@ async def generate_openai_chat_completion(
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@ -1237,7 +1367,7 @@ async def get_openai_models(
models = []
if url_idx is None:
model_list = await get_all_models(request)
model_list = await get_all_models(request, user=user)
models = [
{
"id": model["model"],
@ -1405,9 +1535,10 @@ async def download_model(
return None
# TODO: Progress bar does not reflect size & duration of upload.
@router.post("/models/upload")
@router.post("/models/upload/{url_idx}")
def upload_model(
async def upload_model(
request: Request,
file: UploadFile = File(...),
url_idx: Optional[int] = None,
@ -1416,59 +1547,85 @@ def upload_model(
if url_idx is None:
url_idx = 0
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = os.path.join(UPLOAD_DIR, file.filename)
os.makedirs(UPLOAD_DIR, exist_ok=True)
file_path = f"{UPLOAD_DIR}/{file.filename}"
# --- P1: save file locally ---
chunk_size = 1024 * 1024 * 2 # 2 MB chunks
with open(file_path, "wb") as out_f:
while True:
chunk = file.file.read(chunk_size)
# log.info(f"Chunk: {str(chunk)}") # DEBUG
if not chunk:
break
out_f.write(chunk)
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
async def file_process_stream():
nonlocal ollama_url
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
log.info(f"Total Model Size: {str(total_size)}") # DEBUG
# --- P2: SSE progress + calculate sha256 hash ---
file_hash = calculate_sha256(file_path, chunk_size)
log.info(f"Model Hash: {str(file_hash)}") # DEBUG
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
bytes_read = 0
while chunk := f.read(chunk_size):
bytes_read += len(chunk)
progress = round(bytes_read / total_size * 100, 2)
data_msg = {
"progress": progress,
"total": total_size,
"completed": total,
"completed": bytes_read,
}
yield f"data: {json.dumps(res)}\n\n"
yield f"data: {json.dumps(data_msg)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
# --- P3: Upload to ollama /api/blobs ---
with open(file_path, "rb") as f:
url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
response = requests.post(url, data=f)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
log.info(f"Uploaded to /api/blobs") # DEBUG
# Remove local file
os.remove(file_path)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
# Create model in ollama
model_name, ext = os.path.splitext(file.filename)
log.info(f"Created Model: {model_name}") # DEBUG
create_payload = {
"model": model_name,
# Reference the file by its original name => the uploaded blob's digest
"files": {file.filename: f"sha256:{file_hash}"},
}
log.info(f"Model Payload: {create_payload}") # DEBUG
# Call ollama /api/create
# https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
create_resp = requests.post(
url=f"{ollama_url}/api/create",
headers={"Content-Type": "application/json"},
data=json.dumps(create_payload),
)
if create_resp.ok:
log.info(f"API SUCCESS!") # DEBUG
done_msg = {
"done": True,
"blob": f"sha256:{file_hash}",
"name": file.filename,
"model_created": model_name,
}
yield f"data: {json.dumps(done_msg)}\n\n"
else:
raise Exception(
f"Failed to create model in Ollama. {create_resp.text}"
)
else:
raise Exception("Ollama: Could not create blob, Please try again.")
except Exception as e:
res = {"error": str(e)}

View File

@ -22,10 +22,11 @@ from open_webui.config import (
)
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
ENABLE_FORWARD_USER_INFO_HEADERS,
BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.models.users import UserModel
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
##########################################
async def send_get_request(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
url,
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
) as response:
return await response.json()
except Exception as e:
@ -75,18 +89,24 @@ async def cleanup_response(
await session.close()
def openai_o1_handler(payload):
def openai_o1_o3_handler(payload):
"""
Handle O1 specific parameters
Handle o1, o3 specific parameters
"""
if "max_tokens" in payload:
# Remove "max_tokens" from the payload
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
# Fix: o1 and o3 do not support the "system" role directly.
# For older models like "o1-mini" or "o1-preview", use role "user".
# For newer o1/o3 models, replace "system" with "developer".
if payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
model_lower = payload["model"].lower()
if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
payload["messages"][0]["role"] = "user"
else:
payload["messages"][0]["role"] = "developer"
return payload
@ -172,7 +192,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body()
name = hashlib.sha256(body).hexdigest()
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
@ -247,7 +267,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def get_all_models_responses(request: Request) -> list:
async def get_all_models_responses(request: Request, user: UserModel) -> list:
if not request.app.state.config.ENABLE_OPENAI_API:
return []
@ -271,7 +291,9 @@ async def get_all_models_responses(request: Request) -> list:
):
request_tasks.append(
send_get_request(
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
)
)
else:
@ -291,6 +313,7 @@ async def get_all_models_responses(request: Request) -> list:
send_get_request(
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
)
)
else:
@ -352,13 +375,13 @@ async def get_filtered_models(models, user):
@cached(ttl=3)
async def get_all_models(request: Request) -> dict[str, list]:
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")
if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []}
responses = await get_all_models_responses(request)
responses = await get_all_models_responses(request, user=user)
def extract_data(response):
if response and "data" in response:
@ -418,16 +441,14 @@ async def get_models(
}
if url_idx is None:
models = await get_all_models(request)
models = await get_all_models(request, user=user)
else:
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
r = None
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
) as session:
try:
async with session.get(
@ -489,7 +510,7 @@ async def get_models(
raise HTTPException(status_code=500, detail=error_detail)
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
models["data"] = get_filtered_models(models, user)
models["data"] = await get_filtered_models(models, user)
return models
@ -507,7 +528,7 @@ async def verify_connection(
key = form_data.key
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
) as session:
try:
async with session.get(
@ -515,6 +536,16 @@ async def verify_connection(
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as r:
if r.status != 200:
@ -551,9 +582,9 @@ async def generate_chat_completion(
bypass_filter = True
idx = 0
payload = {**form_data}
if "metadata" in payload:
del payload["metadata"]
metadata = payload.pop("metadata", None)
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
@ -566,7 +597,7 @@ async def generate_chat_completion(
params = model_info.params.model_dump()
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
@ -587,7 +618,7 @@ async def generate_chat_completion(
detail="Model not found",
)
await get_all_models(request)
await get_all_models(request, user=user)
model = request.app.state.OPENAI_MODELS.get(model_id)
if model:
idx = model["urlIdx"]
@ -621,10 +652,10 @@ async def generate_chat_completion(
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1 = payload["model"].lower().startswith("o1-")
if is_o1:
payload = openai_o1_handler(payload)
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
if is_o1_o3:
payload = openai_o1_o3_handler(payload)
elif "api.openai.com" not in url:
# Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload:
@ -777,7 +808,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if r is not None:
try:
res = await r.json()
print(res)
log.error(res)
if "error" in res:
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception:

View File

@ -9,6 +9,7 @@ from fastapi import (
status,
APIRouter,
)
import aiohttp
import os
import logging
import shutil
@ -56,96 +57,103 @@ def get_sorted_filters(model_id, models):
return sorted_filters
def process_pipeline_inlet_filter(request, payload, user, models):
async def process_pipeline_inlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id]
if "pipeline" in model:
sorted_filters.append(model)
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
async with aiohttp.ClientSession() as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
if urlIdx is None:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key == "":
if not key:
continue
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": payload,
},
)
request_data = {
"user": user,
"body": payload,
}
r.raise_for_status()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
res = r.json()
try:
async with session.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
except aiohttp.ClientResponseError as e:
res = (
await response.json()
if response.content_type == "application/json"
else {}
)
if "detail" in res:
raise Exception(r.status_code, res["detail"])
raise Exception(response.status, res["detail"])
except Exception as e:
log.exception(f"Connection error: {e}")
return payload
def process_pipeline_outlet_filter(request, payload, user, models):
async def process_pipeline_outlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
async with aiohttp.ClientSession() as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
if urlIdx is None:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
r = requests.post(
if not key:
continue
headers = {"Authorization": f"Bearer {key}"}
request_data = {
"user": user,
"body": payload,
}
try:
async with session.post(
f"{url}/{filter['id']}/filter/outlet",
headers={"Authorization": f"Bearer {key}"},
json={
"user": user,
"body": payload,
},
)
r.raise_for_status()
data = r.json()
payload = data
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
headers=headers,
json=request_data,
) as response:
response.raise_for_status()
payload = await response.json()
except aiohttp.ClientResponseError as e:
try:
res = r.json()
res = (
await response.json()
if "application/json" in response.content_type
else {}
)
if "detail" in res:
return Exception(r.status_code, res)
raise Exception(response.status, res)
except Exception:
pass
else:
pass
except Exception as e:
log.exception(f"Connection error: {e}")
return payload
@ -161,7 +169,7 @@ router = APIRouter()
@router.get("/list")
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
responses = await get_all_models_responses(request)
responses = await get_all_models_responses(request, user)
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
urlIdxs = [
@ -188,7 +196,7 @@ async def upload_pipeline(
file: UploadFile = File(...),
user=Depends(get_admin_user),
):
print("upload_pipeline", urlIdx, file.filename)
log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
# Check if the uploaded file is a python file
if not (file.filename and file.filename.endswith(".py")):
raise HTTPException(
@ -223,7 +231,7 @@ async def upload_pipeline(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
status_code = status.HTTP_404_NOT_FOUND
@ -274,7 +282,7 @@ async def add_pipeline(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@ -319,7 +327,7 @@ async def delete_pipeline(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@ -353,7 +361,7 @@ async def get_pipelines(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@ -392,7 +400,7 @@ async def get_pipeline_valves(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@ -432,7 +440,7 @@ async def get_pipeline_valves_spec(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@ -474,7 +482,7 @@ async def update_pipeline_valves(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None

View File

@ -147,7 +147,11 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
detail=ERROR_MESSAGES.NOT_FOUND,
)
if prompt.user_id != user.id and user.role != "admin":
if (
prompt.user_id != user.id
and not has_access(user.id, "write", prompt.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,

View File

@ -21,6 +21,7 @@ from fastapi import (
APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
import tiktoken
@ -45,17 +46,20 @@ from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.web.brave import search_brave
from open_webui.retrieval.web.kagi import search_kagi
from open_webui.retrieval.web.mojeek import search_mojeek
from open_webui.retrieval.web.bocha import search_bocha
from open_webui.retrieval.web.duckduckgo import search_duckduckgo
from open_webui.retrieval.web.google_pse import search_google_pse
from open_webui.retrieval.web.jina_search import search_jina
from open_webui.retrieval.web.searchapi import search_searchapi
from open_webui.retrieval.web.serpapi import search_serpapi
from open_webui.retrieval.web.searxng import search_searxng
from open_webui.retrieval.web.serper import search_serper
from open_webui.retrieval.web.serply import search_serply
from open_webui.retrieval.web.serpstack import search_serpstack
from open_webui.retrieval.web.tavily import search_tavily
from open_webui.retrieval.web.bing import search_bing
from open_webui.retrieval.web.exa import search_exa
from open_webui.retrieval.web.perplexity import search_perplexity
from open_webui.retrieval.utils import (
get_embedding_function,
@ -347,11 +351,18 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
"enable_google_drive_integration": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"enable_onedrive_integration": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
"content_extraction": {
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
"document_intelligence_config": {
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
},
},
"chunk": {
"text_splitter": request.app.state.config.TEXT_SPLITTER,
@ -368,10 +379,12 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
},
"web": {
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"search": {
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
"drive": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"onedrive": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
@ -379,6 +392,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
"bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
"serper_api_key": request.app.state.config.SERPER_API_KEY,
@ -386,11 +400,17 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"tavily_api_key": request.app.state.config.TAVILY_API_KEY,
"searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
"searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
"serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
"serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
"jina_api_key": request.app.state.config.JINA_API_KEY,
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"exa_api_key": request.app.state.config.EXA_API_KEY,
"perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
},
}
@ -401,10 +421,16 @@ class FileConfig(BaseModel):
max_count: Optional[int] = None
class DocumentIntelligenceConfigForm(BaseModel):
endpoint: str
key: str
class ContentExtractionConfig(BaseModel):
engine: str = ""
tika_server_url: Optional[str] = None
docling_server_url: Optional[str] = None
document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None
class ChunkParamUpdateForm(BaseModel):
@ -428,6 +454,7 @@ class WebSearchConfig(BaseModel):
brave_search_api_key: Optional[str] = None
kagi_search_api_key: Optional[str] = None
mojeek_search_api_key: Optional[str] = None
bocha_search_api_key: Optional[str] = None
serpstack_api_key: Optional[str] = None
serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None
@ -435,21 +462,31 @@ class WebSearchConfig(BaseModel):
tavily_api_key: Optional[str] = None
searchapi_api_key: Optional[str] = None
searchapi_engine: Optional[str] = None
serpapi_api_key: Optional[str] = None
serpapi_engine: Optional[str] = None
jina_api_key: Optional[str] = None
bing_search_v7_endpoint: Optional[str] = None
bing_search_v7_subscription_key: Optional[str] = None
exa_api_key: Optional[str] = None
perplexity_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
trust_env: Optional[bool] = None
domain_filter_list: Optional[List[str]] = []
class WebConfig(BaseModel):
search: WebSearchConfig
web_loader_ssl_verification: Optional[bool] = None
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
class ConfigUpdateForm(BaseModel):
RAG_FULL_CONTEXT: Optional[bool] = None
BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
pdf_extract_images: Optional[bool] = None
enable_google_drive_integration: Optional[bool] = None
enable_onedrive_integration: Optional[bool] = None
file: Optional[FileConfig] = None
content_extraction: Optional[ContentExtractionConfig] = None
chunk: Optional[ChunkParamUpdateForm] = None
@ -467,18 +504,38 @@ async def update_rag_config(
else request.app.state.config.PDF_EXTRACT_IMAGES
)
request.app.state.config.RAG_FULL_CONTEXT = (
form_data.RAG_FULL_CONTEXT
if form_data.RAG_FULL_CONTEXT is not None
else request.app.state.config.RAG_FULL_CONTEXT
)
request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL = (
form_data.BYPASS_EMBEDDING_AND_RETRIEVAL
if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None
else request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
)
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
form_data.enable_google_drive_integration
if form_data.enable_google_drive_integration is not None
else request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION
)
request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION = (
form_data.enable_onedrive_integration
if form_data.enable_onedrive_integration is not None
else request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION
)
if form_data.file is not None:
request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count
if form_data.content_extraction is not None:
log.info(f"Updating text settings: {form_data.content_extraction}")
log.info(
f"Updating content extraction: {request.app.state.config.CONTENT_EXTRACTION_ENGINE} to {form_data.content_extraction.engine}"
)
request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
form_data.content_extraction.engine
)
@ -488,6 +545,13 @@ async def update_rag_config(
request.app.state.config.DOCLING_SERVER_URL = (
form_data.content_extraction.docling_server_url
)
if form_data.content_extraction.document_intelligence_config is not None:
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
form_data.content_extraction.document_intelligence_config.endpoint
)
request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = (
form_data.content_extraction.document_intelligence_config.key
)
if form_data.chunk is not None:
request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
@ -502,11 +566,16 @@ async def update_rag_config(
if form_data.web is not None:
request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
# Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
form_data.web.web_loader_ssl_verification
form_data.web.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
request.app.state.config.SEARXNG_QUERY_URL = (
form_data.web.search.searxng_query_url
)
@ -525,6 +594,9 @@ async def update_rag_config(
request.app.state.config.MOJEEK_SEARCH_API_KEY = (
form_data.web.search.mojeek_search_api_key
)
request.app.state.config.BOCHA_SEARCH_API_KEY = (
form_data.web.search.bocha_search_api_key
)
request.app.state.config.SERPSTACK_API_KEY = (
form_data.web.search.serpstack_api_key
)
@ -539,6 +611,9 @@ async def update_rag_config(
form_data.web.search.searchapi_engine
)
request.app.state.config.SERPAPI_API_KEY = form_data.web.search.serpapi_api_key
request.app.state.config.SERPAPI_ENGINE = form_data.web.search.serpapi_engine
request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
request.app.state.config.BING_SEARCH_V7_ENDPOINT = (
form_data.web.search.bing_search_v7_endpoint
@ -547,16 +622,30 @@ async def update_rag_config(
form_data.web.search.bing_search_v7_subscription_key
)
request.app.state.config.EXA_API_KEY = form_data.web.search.exa_api_key
request.app.state.config.PERPLEXITY_API_KEY = (
form_data.web.search.perplexity_api_key
)
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = (
form_data.web.search.result_count
)
request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
)
request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV = (
form_data.web.search.trust_env
)
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = (
form_data.web.search.domain_filter_list
)
return {
"status": True,
"pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
"RAG_FULL_CONTEXT": request.app.state.config.RAG_FULL_CONTEXT,
"BYPASS_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL,
"file": {
"max_size": request.app.state.config.FILE_MAX_SIZE,
"max_count": request.app.state.config.FILE_MAX_COUNT,
@ -565,6 +654,10 @@ async def update_rag_config(
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
"docling_server_url": request.app.state.config.DOCLING_SERVER_URL,
"document_intelligence_config": {
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
},
},
"chunk": {
"text_splitter": request.app.state.config.TEXT_SPLITTER,
@ -577,7 +670,8 @@ async def update_rag_config(
"translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"search": {
"enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
@ -587,18 +681,25 @@ async def update_rag_config(
"brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
"kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
"mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
"bocha_search_api_key": request.app.state.config.BOCHA_SEARCH_API_KEY,
"serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
"serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
"serper_api_key": request.app.state.config.SERPER_API_KEY,
"serply_api_key": request.app.state.config.SERPLY_API_KEY,
"serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
"searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
"serpapi_api_key": request.app.state.config.SERPAPI_API_KEY,
"serpapi_engine": request.app.state.config.SERPAPI_ENGINE,
"tavily_api_key": request.app.state.config.TAVILY_API_KEY,
"jina_api_key": request.app.state.config.JINA_API_KEY,
"bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"exa_api_key": request.app.state.config.EXA_API_KEY,
"perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
},
},
}
@ -666,6 +767,7 @@ def save_docs_to_vector_db(
overwrite: bool = False,
split: bool = True,
add: bool = False,
user=None,
) -> bool:
def _get_docs_info(docs: list[Document]) -> str:
docs_info = set()
@ -746,7 +848,11 @@ def save_docs_to_vector_db(
# for meta-data so convert them to string.
for metadata in metadatas:
for key, value in metadata.items():
if isinstance(value, datetime):
if (
isinstance(value, datetime)
or isinstance(value, list)
or isinstance(value, dict)
):
metadata[key] = str(value)
try:
@ -781,7 +887,7 @@ def save_docs_to_vector_db(
)
embeddings = embedding_function(
list(map(lambda x: x.replace("\n", " "), texts))
list(map(lambda x: x.replace("\n", " "), texts)), user=user
)
items = [
@ -829,7 +935,12 @@ def process_file(
# Update the content in the file
# Usage: /files/{file_id}/data/content/update
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
try:
# /files/{file_id}/data/content/update
VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
except:
# Audio file upload pipeline
pass
docs = [
Document(
@ -887,6 +998,8 @@ def process_file(
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
)
docs = loader.load(
file.filename, file.meta.get("content_type"), file_path
@ -929,35 +1042,45 @@ def process_file(
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
try:
result = save_docs_to_vector_db(
request,
docs=docs,
collection_name=collection_name,
metadata={
"file_id": file.id,
"name": file.filename,
"hash": hash,
},
add=(True if form_data.collection_name else False),
)
if result:
Files.update_file_metadata_by_id(
file.id,
{
"collection_name": collection_name,
if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
try:
result = save_docs_to_vector_db(
request,
docs=docs,
collection_name=collection_name,
metadata={
"file_id": file.id,
"name": file.filename,
"hash": hash,
},
add=(True if form_data.collection_name else False),
user=user,
)
return {
"status": True,
"collection_name": collection_name,
"filename": file.filename,
"content": text_content,
}
except Exception as e:
raise e
if result:
Files.update_file_metadata_by_id(
file.id,
{
"collection_name": collection_name,
},
)
return {
"status": True,
"collection_name": collection_name,
"filename": file.filename,
"content": text_content,
}
except Exception as e:
raise e
else:
return {
"status": True,
"collection_name": None,
"filename": file.filename,
"content": text_content,
}
except Exception as e:
log.exception(e)
if "No pandoc was found" in str(e):
@ -997,7 +1120,7 @@ def process_text(
text_content = form_data.content
log.debug(f"text_content: {text_content}")
result = save_docs_to_vector_db(request, docs, collection_name)
result = save_docs_to_vector_db(request, docs, collection_name, user=user)
if result:
return {
"status": True,
@ -1030,7 +1153,9 @@ def process_youtube_video(
content = " ".join([doc.page_content for doc in docs])
log.debug(f"text_content: {content}")
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
save_docs_to_vector_db(
request, docs, collection_name, overwrite=True, user=user
)
return {
"status": True,
@ -1071,7 +1196,13 @@ def process_web(
content = " ".join([doc.page_content for doc in docs])
log.debug(f"text_content: {content}")
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
save_docs_to_vector_db(
request, docs, collection_name, overwrite=True, user=user
)
else:
collection_name = None
return {
"status": True,
@ -1083,6 +1214,7 @@ def process_web(
},
"meta": {
"name": form_data.url,
"source": form_data.url,
},
},
}
@ -1102,11 +1234,15 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
- BRAVE_SEARCH_API_KEY
- KAGI_SEARCH_API_KEY
- MOJEEK_SEARCH_API_KEY
- BOCHA_SEARCH_API_KEY
- SERPSTACK_API_KEY
- SERPER_API_KEY
- SERPLY_API_KEY
- TAVILY_API_KEY
- EXA_API_KEY
- PERPLEXITY_API_KEY
- SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
- SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`)
Args:
query (str): The query to search for
"""
@ -1168,6 +1304,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
elif engine == "bocha":
if request.app.state.config.BOCHA_SEARCH_API_KEY:
return search_bocha(
request.app.state.config.BOCHA_SEARCH_API_KEY,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
elif engine == "serpstack":
if request.app.state.config.SERPSTACK_API_KEY:
return search_serpstack(
@ -1211,6 +1357,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.TAVILY_API_KEY,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
@ -1225,6 +1372,17 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No SEARCHAPI_API_KEY found in environment variables")
elif engine == "serpapi":
if request.app.state.config.SERPAPI_API_KEY:
return search_serpapi(
request.app.state.config.SERPAPI_API_KEY,
request.app.state.config.SERPAPI_ENGINE,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPAPI_API_KEY found in environment variables")
elif engine == "jina":
return search_jina(
request.app.state.config.JINA_API_KEY,
@ -1240,12 +1398,26 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "exa":
return search_exa(
request.app.state.config.EXA_API_KEY,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "perplexity":
return search_perplexity(
request.app.state.config.PERPLEXITY_API_KEY,
query,
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No search engine API key found in environment variables")
@router.post("/process/web/search")
def process_web_search(
async def process_web_search(
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
):
try:
@ -1277,15 +1449,40 @@ def process_web_search(
urls,
verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
)
docs = loader.load()
save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
docs = await loader.aload()
return {
"status": True,
"collection_name": collection_name,
"filenames": urls,
}
if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
return {
"status": True,
"collection_name": None,
"filenames": urls,
"docs": [
{
"content": doc.page_content,
"metadata": doc.metadata,
}
for doc in docs
],
"loaded_count": len(docs),
}
else:
await run_in_threadpool(
save_docs_to_vector_db,
request,
docs,
collection_name,
overwrite=True,
user=user,
)
return {
"status": True,
"collection_name": collection_name,
"filenames": urls,
"loaded_count": len(docs),
}
except Exception as e:
log.exception(e)
raise HTTPException(
@ -1313,7 +1510,9 @@ def query_doc_handler(
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=request.app.state.EMBEDDING_FUNCTION,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
query, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
r=(
@ -1321,12 +1520,16 @@ def query_doc_handler(
if form_data.r
else request.app.state.config.RELEVANCE_THRESHOLD
),
user=user,
)
else:
return query_doc(
collection_name=form_data.collection_name,
query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
query_embedding=request.app.state.EMBEDDING_FUNCTION(
form_data.query, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
user=user,
)
except Exception as e:
log.exception(e)
@ -1355,7 +1558,9 @@ def query_collection_handler(
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=request.app.state.EMBEDDING_FUNCTION,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
query, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf,
r=(
@ -1368,7 +1573,9 @@ def query_collection_handler(
return query_collection(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=request.app.state.EMBEDDING_FUNCTION,
embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
query, user=user
),
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
)
@ -1432,11 +1639,11 @@ def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
log.exception(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"The directory {folder} does not exist")
log.warning(f"The directory {folder} does not exist")
except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}")
log.exception(f"Failed to process the directory {folder}. Reason: {e}")
return True
@ -1516,6 +1723,7 @@ def process_files_batch(
docs=all_docs,
collection_name=collection_name,
add=True,
user=user,
)
# Update all files with collection name

View File

@ -4,6 +4,7 @@ from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from typing import Optional
import logging
import re
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
@ -19,6 +20,10 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.utils.task import get_task_model_id
from open_webui.config import (
@ -57,6 +62,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@ -67,6 +73,7 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
ENABLE_TITLE_GENERATION: bool
TITLE_GENERATION_PROMPT_TEMPLATE: str
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
ENABLE_AUTOCOMPLETE_GENERATION: bool
@ -85,10 +92,15 @@ async def update_task_config(
):
request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
form_data.ENABLE_AUTOCOMPLETE_GENERATION
)
@ -117,6 +129,7 @@ async def update_task_config(
return {
"TASK_MODEL": request.app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
"TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
@ -134,7 +147,19 @@ async def update_task_config(
async def generate_title(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if not request.app.state.config.ENABLE_TITLE_GENERATION:
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"detail": "Title generation is disabled"},
)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@ -161,9 +186,20 @@ async def generate_title(
else:
template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
messages = form_data["messages"]
# Remove reasoning details from the messages
for message in messages:
message["content"] = re.sub(
r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
"",
message["content"],
flags=re.S,
).strip()
content = title_generation_template(
template,
form_data["messages"],
messages,
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
@ -175,19 +211,26 @@ async def generate_title(
"messages": [{"role": "user", "content": content}],
"stream": False,
**(
{"max_tokens": 50}
if models[task_model_id]["owned_by"] == "ollama"
{"max_tokens": 1000}
if models[task_model_id].get("owned_by") == "ollama"
else {
"max_completion_tokens": 50,
"max_completion_tokens": 1000,
}
),
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.TITLE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@ -209,7 +252,12 @@ async def generate_chat_tags(
content={"detail": "Tags generation is disabled"},
)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@ -245,12 +293,19 @@ async def generate_chat_tags(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.TAGS_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@ -265,7 +320,12 @@ async def generate_chat_tags(
async def generate_image_prompt(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@ -305,12 +365,19 @@ async def generate_image_prompt(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.IMAGE_PROMPT_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@ -340,7 +407,12 @@ async def generate_queries(
detail=f"Query generation is disabled",
)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@ -376,12 +448,19 @@ async def generate_queries(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.QUERY_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@ -415,7 +494,12 @@ async def generate_autocompletion(
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
)
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@ -451,12 +535,19 @@ async def generate_autocompletion(
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.AUTOCOMPLETE_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@ -472,7 +563,12 @@ async def generate_emoji(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@ -509,15 +605,25 @@ async def generate_emoji(
"stream": False,
**(
{"max_tokens": 4}
if models[task_model_id]["owned_by"] == "ollama"
if models[task_model_id].get("owned_by") == "ollama"
else {
"max_completion_tokens": 4,
}
),
"chat_id": form_data.get("chat_id", None),
"metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.EMOJI_GENERATION),
"task_body": form_data,
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@ -532,7 +638,13 @@ async def generate_moa_response(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
models = request.app.state.MODELS
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
@ -565,12 +677,19 @@ async def generate_moa_response(
"messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False),
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"chat_id": form_data.get("chat_id", None),
"task": str(TASKS.MOA_RESPONSE_GENERATION),
"task_body": form_data,
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:

View File

@ -1,3 +1,4 @@
import logging
from pathlib import Path
from typing import Optional
@ -15,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@ -100,7 +105,7 @@ async def create_new_tools(
specs = get_tools_specs(TOOLS[form_data.id])
tools = Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True)
if tools:
@ -111,7 +116,7 @@ async def create_new_tools(
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
)
except Exception as e:
print(e)
log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
@ -193,7 +198,7 @@ async def update_tools_by_id(
"specs": specs,
}
print(updated)
log.debug(updated)
tools = Tools.update_tool_by_id(id, updated)
if tools:
@ -227,7 +232,11 @@ async def delete_tools_by_id(
detail=ERROR_MESSAGES.NOT_FOUND,
)
if tools.user_id != user.id and user.role != "admin":
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and user.role != "admin"
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
@ -339,7 +348,7 @@ async def update_tools_valves_by_id(
Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Failed to update tool valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
@ -417,7 +426,7 @@ async def update_tools_user_valves_by_id(
)
return user_valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Failed to update user valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),

View File

@ -79,6 +79,7 @@ class ChatPermissions(BaseModel):
class FeaturesPermissions(BaseModel):
web_search: bool = True
image_generation: bool = True
code_interpreter: bool = True
class UserPermissions(BaseModel):
@ -152,7 +153,7 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user)
):
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
if user:
return user.settings
else:

View File

@ -1,48 +1,84 @@
import black
import logging
import markdown
from open_webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel
from starlette.responses import FileResponse
from open_webui.utils.misc import get_gravatar_url
from open_webui.utils.pdf_generator import PDFGenerator
from open_webui.utils.auth import get_admin_user
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@router.get("/gravatar")
async def get_gravatar(
email: str,
):
async def get_gravatar(email: str, user=Depends(get_verified_user)):
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
class CodeForm(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
formatted_code = black.format_str(form_data.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
return {"code": form_data.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/code/execute")
async def execute_code(
request: Request, form_data: CodeForm, user=Depends(get_verified_user)
):
if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
output = await execute_code_jupyter(
request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
form_data.code,
(
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
else None
),
(
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
else None
),
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
)
return output
else:
raise HTTPException(
status_code=400,
detail="Code execution engine not supported",
)
class MarkdownForm(BaseModel):
md: str
@router.post("/markdown")
async def get_html_from_markdown(
form_data: MarkdownForm,
form_data: MarkdownForm, user=Depends(get_verified_user)
):
return {"html": markdown.markdown(form_data.md)}
@ -54,7 +90,7 @@ class ChatForm(BaseModel):
@router.post("/pdf")
async def download_chat_as_pdf(
form_data: ChatTitleMessagesForm,
form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
):
try:
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
@ -65,7 +101,7 @@ async def download_chat_as_pdf(
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
)
except Exception as e:
print(e)
log.exception(f"Error generating PDF: {e}")
raise HTTPException(status_code=400, detail=str(e))

View File

@ -12,6 +12,7 @@ from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT,
WEBSOCKET_MANAGER,
WEBSOCKET_REDIS_URL,
WEBSOCKET_REDIS_LOCK_TIMEOUT,
)
from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict, RedisLock
@ -61,7 +62,7 @@ if WEBSOCKET_MANAGER == "redis":
clean_up_lock = RedisLock(
redis_url=WEBSOCKET_REDIS_URL,
lock_name="usage_cleanup_lock",
timeout_secs=TIMEOUT_DURATION * 2,
timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
)
aquire_func = clean_up_lock.aquire_lock
renew_func = clean_up_lock.renew_lock
@ -279,8 +280,8 @@ def get_event_emitter(request_info):
await sio.emit(
"chat-events",
{
"chat_id": request_info["chat_id"],
"message_id": request_info["message_id"],
"chat_id": request_info.get("chat_id", None),
"message_id": request_info.get("message_id", None),
"data": event_data,
},
to=session_id,
@ -325,19 +326,22 @@ def get_event_emitter(request_info):
def get_event_call(request_info):
async def __event_call__(event_data):
async def __event_caller__(event_data):
response = await sio.call(
"chat-events",
{
"chat_id": request_info["chat_id"],
"message_id": request_info["message_id"],
"chat_id": request_info.get("chat_id", None),
"message_id": request_info.get("message_id", None),
"data": event_data,
},
to=request_info["session_id"],
)
return response
return __event_call__
return __event_caller__
get_event_caller = get_event_call
def get_user_id_from_session_pool(sid):

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 14 KiB

View File

View File

@ -0,0 +1,21 @@
{
"name": "Open WebUI",
"short_name": "WebUI",
"icons": [
{
"src": "/static/web-app-manifest-192x192.png",
"sizes": "192x192",
"type": "image/png",
"purpose": "maskable"
},
{
"src": "/static/web-app-manifest-512x512.png",
"sizes": "512x512",
"type": "image/png",
"purpose": "maskable"
}
],
"theme_color": "#ffffff",
"background_color": "#ffffff",
"display": "standalone"
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.3 KiB

View File

@ -9308,5 +9308,3 @@
.json-schema-2020-12__title:first-of-type {
font-size: 16px;
}
/*# sourceMappingURL=swagger-ui.css.map*/

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

View File

@ -1,25 +1,41 @@
import os
import shutil
import json
import logging
from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from open_webui.config import (
S3_ACCESS_KEY_ID,
S3_BUCKET_NAME,
S3_ENDPOINT_URL,
S3_KEY_PREFIX,
S3_REGION_NAME,
S3_SECRET_ACCESS_KEY,
S3_USE_ACCELERATE_ENDPOINT,
S3_ADDRESSING_STYLE,
GCS_BUCKET_NAME,
GOOGLE_APPLICATION_CREDENTIALS_JSON,
AZURE_STORAGE_ENDPOINT,
AZURE_STORAGE_CONTAINER_NAME,
AZURE_STORAGE_KEY,
STORAGE_PROVIDER,
UPLOAD_DIR,
)
from google.cloud import storage
from google.cloud.exceptions import GoogleCloudError, NotFound
from open_webui.constants import ERROR_MESSAGES
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azure.core.exceptions import ResourceNotFoundError
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
class StorageProvider(ABC):
@ -64,7 +80,7 @@ class LocalStorageProvider(StorageProvider):
if os.path.isfile(file_path):
os.remove(file_path)
else:
print(f"File {file_path} not found in local storage.")
log.warning(f"File {file_path} not found in local storage.")
@staticmethod
def delete_all_files() -> None:
@ -78,30 +94,52 @@ class LocalStorageProvider(StorageProvider):
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
log.exception(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"Directory {UPLOAD_DIR} not found in local storage.")
log.warning(f"Directory {UPLOAD_DIR} not found in local storage.")
class S3StorageProvider(StorageProvider):
def __init__(self):
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config = Config(
s3={
"use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT,
"addressing_style": S3_ADDRESSING_STYLE,
},
)
# If access key and secret are provided, use them for authentication
if S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY:
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config=config,
)
else:
# If no explicit credentials are provided, fall back to default AWS credentials
# This supports workload identity (IAM roles for EC2, EKS, etc.)
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
config=config,
)
self.bucket_name = S3_BUCKET_NAME
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
_, file_path = LocalStorageProvider.upload_file(file, filename)
try:
self.s3_client.upload_file(file_path, self.bucket_name, filename)
s3_key = os.path.join(self.key_prefix, filename)
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
return (
open(file_path, "rb").read(),
"s3://" + self.bucket_name + "/" + filename,
"s3://" + self.bucket_name + "/" + s3_key,
)
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
@ -109,18 +147,18 @@ class S3StorageProvider(StorageProvider):
def get_file(self, file_path: str) -> str:
"""Handles downloading of the file from S3 storage."""
try:
bucket_name, key = file_path.split("//")[1].split("/")
local_file_path = f"{UPLOAD_DIR}/{key}"
self.s3_client.download_file(bucket_name, key, local_file_path)
s3_key = self._extract_s3_key(file_path)
local_file_path = self._get_local_file_path(s3_key)
self.s3_client.download_file(self.bucket_name, s3_key, local_file_path)
return local_file_path
except ClientError as e:
raise RuntimeError(f"Error downloading file from S3: {e}")
def delete_file(self, file_path: str) -> None:
"""Handles deletion of the file from S3 storage."""
filename = file_path.split("/")[-1]
try:
self.s3_client.delete_object(Bucket=self.bucket_name, Key=filename)
s3_key = self._extract_s3_key(file_path)
self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
except ClientError as e:
raise RuntimeError(f"Error deleting file from S3: {e}")
@ -133,6 +171,10 @@ class S3StorageProvider(StorageProvider):
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
# Skip objects that were not uploaded from open-webui in the first place
if not content["Key"].startswith(self.key_prefix):
continue
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
@ -142,6 +184,13 @@ class S3StorageProvider(StorageProvider):
# Always delete from local storage
LocalStorageProvider.delete_all_files()
# The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
def _extract_s3_key(self, full_file_path: str) -> str:
return "/".join(full_file_path.split("//")[1].split("/")[1:])
def _get_local_file_path(self, s3_key: str) -> str:
return f"{UPLOAD_DIR}/{s3_key.split('/')[-1]}"
class GCSStorageProvider(StorageProvider):
def __init__(self):
@ -207,6 +256,74 @@ class GCSStorageProvider(StorageProvider):
LocalStorageProvider.delete_all_files()
class AzureStorageProvider(StorageProvider):
def __init__(self):
self.endpoint = AZURE_STORAGE_ENDPOINT
self.container_name = AZURE_STORAGE_CONTAINER_NAME
storage_key = AZURE_STORAGE_KEY
if storage_key:
# Configure using the Azure Storage Account Endpoint and Key
self.blob_service_client = BlobServiceClient(
account_url=self.endpoint, credential=storage_key
)
else:
# Configure using the Azure Storage Account Endpoint and DefaultAzureCredential
# If the key is not configured, then the DefaultAzureCredential will be used to support Managed Identity authentication
self.blob_service_client = BlobServiceClient(
account_url=self.endpoint, credential=DefaultAzureCredential()
)
self.container_client = self.blob_service_client.get_container_client(
self.container_name
)
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
"""Handles uploading of the file to Azure Blob Storage."""
contents, file_path = LocalStorageProvider.upload_file(file, filename)
try:
blob_client = self.container_client.get_blob_client(filename)
blob_client.upload_blob(contents, overwrite=True)
return contents, f"{self.endpoint}/{self.container_name}/{filename}"
except Exception as e:
raise RuntimeError(f"Error uploading file to Azure Blob Storage: {e}")
def get_file(self, file_path: str) -> str:
"""Handles downloading of the file from Azure Blob Storage."""
try:
filename = file_path.split("/")[-1]
local_file_path = f"{UPLOAD_DIR}/{filename}"
blob_client = self.container_client.get_blob_client(filename)
with open(local_file_path, "wb") as download_file:
download_file.write(blob_client.download_blob().readall())
return local_file_path
except ResourceNotFoundError as e:
raise RuntimeError(f"Error downloading file from Azure Blob Storage: {e}")
def delete_file(self, file_path: str) -> None:
"""Handles deletion of the file from Azure Blob Storage."""
try:
filename = file_path.split("/")[-1]
blob_client = self.container_client.get_blob_client(filename)
blob_client.delete_blob()
except ResourceNotFoundError as e:
raise RuntimeError(f"Error deleting file from Azure Blob Storage: {e}")
# Always delete from local storage
LocalStorageProvider.delete_file(file_path)
def delete_all_files(self) -> None:
"""Handles deletion of all files from Azure Blob Storage."""
try:
blobs = self.container_client.list_blobs()
for blob in blobs:
self.container_client.delete_blob(blob.name)
except Exception as e:
raise RuntimeError(f"Error deleting all files from Azure Blob Storage: {e}")
# Always delete from local storage
LocalStorageProvider.delete_all_files()
def get_storage_provider(storage_provider: str):
if storage_provider == "local":
Storage = LocalStorageProvider()
@ -214,6 +331,8 @@ def get_storage_provider(storage_provider: str):
Storage = S3StorageProvider()
elif storage_provider == "gcs":
Storage = GCSStorageProvider()
elif storage_provider == "azure":
Storage = AzureStorageProvider()
else:
raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
return Storage

View File

@ -7,6 +7,8 @@ from moto import mock_aws
from open_webui.storage import provider
from gcp_storage_emulator.server import create_server
from google.cloud import storage
from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient
from unittest.mock import MagicMock
def mock_upload_dir(monkeypatch, tmp_path):
@ -22,6 +24,7 @@ def test_imports():
provider.LocalStorageProvider
provider.S3StorageProvider
provider.GCSStorageProvider
provider.AzureStorageProvider
provider.Storage
@ -32,6 +35,8 @@ def test_get_storage_provider():
assert isinstance(Storage, provider.S3StorageProvider)
Storage = provider.get_storage_provider("gcs")
assert isinstance(Storage, provider.GCSStorageProvider)
Storage = provider.get_storage_provider("azure")
assert isinstance(Storage, provider.AzureStorageProvider)
with pytest.raises(RuntimeError):
provider.get_storage_provider("invalid")
@ -48,6 +53,7 @@ def test_class_instantiation():
provider.LocalStorageProvider()
provider.S3StorageProvider()
provider.GCSStorageProvider()
provider.AzureStorageProvider()
class TestLocalStorageProvider:
@ -181,6 +187,17 @@ class TestS3StorageProvider:
assert not (upload_dir / self.filename).exists()
assert not (upload_dir / self.filename_extra).exists()
def test_init_without_credentials(self, monkeypatch):
"""Test that S3StorageProvider can initialize without explicit credentials."""
# Temporarily unset the environment variables
monkeypatch.setattr(provider, "S3_ACCESS_KEY_ID", None)
monkeypatch.setattr(provider, "S3_SECRET_ACCESS_KEY", None)
# Should not raise an exception
storage = provider.S3StorageProvider()
assert storage.s3_client is not None
assert storage.bucket_name == provider.S3_BUCKET_NAME
class TestGCSStorageProvider:
Storage = provider.GCSStorageProvider()
@ -272,3 +289,147 @@ class TestGCSStorageProvider:
assert not (upload_dir / self.filename_extra).exists()
assert self.Storage.bucket.get_blob(self.filename) == None
assert self.Storage.bucket.get_blob(self.filename_extra) == None
class TestAzureStorageProvider:
def __init__(self):
super().__init__()
@pytest.fixture(scope="class")
def setup_storage(self, monkeypatch):
# Create mock Blob Service Client and related clients
mock_blob_service_client = MagicMock()
mock_container_client = MagicMock()
mock_blob_client = MagicMock()
# Set up return values for the mock
mock_blob_service_client.get_container_client.return_value = (
mock_container_client
)
mock_container_client.get_blob_client.return_value = mock_blob_client
# Monkeypatch the Azure classes to return our mocks
monkeypatch.setattr(
azure.storage.blob,
"BlobServiceClient",
lambda *args, **kwargs: mock_blob_service_client,
)
monkeypatch.setattr(
azure.storage.blob,
"ContainerClient",
lambda *args, **kwargs: mock_container_client,
)
monkeypatch.setattr(
azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client
)
self.Storage = provider.AzureStorageProvider()
self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
self.Storage.container_name = "my-container"
self.file_content = b"test content"
self.filename = "test.txt"
self.filename_extra = "test_extra.txt"
self.file_bytesio_empty = io.BytesIO()
# Apply mocks to the Storage instance
self.Storage.blob_service_client = mock_blob_service_client
self.Storage.container_client = mock_container_client
def test_upload_file(self, monkeypatch, tmp_path):
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
# Simulate an error when container does not exist
self.Storage.container_client.get_blob_client.side_effect = Exception(
"Container does not exist"
)
with pytest.raises(Exception):
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
# Reset side effect and create container
self.Storage.container_client.get_blob_client.side_effect = None
self.Storage.create_container()
contents, azure_file_path = self.Storage.upload_file(
io.BytesIO(self.file_content), self.filename
)
# Assertions
self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with(
self.file_content, overwrite=True
)
assert contents == self.file_content
assert (
azure_file_path
== f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
)
assert (upload_dir / self.filename).exists()
assert (upload_dir / self.filename).read_bytes() == self.file_content
with pytest.raises(ValueError):
self.Storage.upload_file(self.file_bytesio_empty, self.filename)
def test_get_file(self, monkeypatch, tmp_path):
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
self.Storage.create_container()
# Mock upload behavior
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
# Mock blob download behavior
self.Storage.container_client.get_blob_client().download_blob().readall.return_value = (
self.file_content
)
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
file_path = self.Storage.get_file(file_url)
assert file_path == str(upload_dir / self.filename)
assert (upload_dir / self.filename).exists()
assert (upload_dir / self.filename).read_bytes() == self.file_content
def test_delete_file(self, monkeypatch, tmp_path):
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
self.Storage.create_container()
# Mock file upload
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
# Mock deletion
self.Storage.container_client.get_blob_client().delete_blob.return_value = None
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
self.Storage.delete_file(file_url)
self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
assert not (upload_dir / self.filename).exists()
def test_delete_all_files(self, monkeypatch, tmp_path):
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
self.Storage.create_container()
# Mock file uploads
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
# Mock listing and deletion behavior
self.Storage.container_client.list_blobs.return_value = [
{"name": self.filename},
{"name": self.filename_extra},
]
self.Storage.container_client.get_blob_client().delete_blob.return_value = None
self.Storage.delete_all_files()
self.Storage.container_client.list_blobs.assert_called_once()
self.Storage.container_client.get_blob_client().delete_blob.assert_any_call()
assert not (upload_dir / self.filename).exists()
assert not (upload_dir / self.filename_extra).exists()
def test_get_file_not_found(self, monkeypatch):
self.Storage.create_container()
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
# Mock behavior to raise an error for missing blobs
self.Storage.container_client.get_blob_client().download_blob.side_effect = (
Exception("Blob not found")
)
with pytest.raises(Exception, match="Blob not found"):
self.Storage.get_file(file_url)

View File

@ -0,0 +1,249 @@
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from enum import Enum
import re
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
MutableMapping,
Optional,
cast,
)
import uuid
from asgiref.typing import (
ASGI3Application,
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendCallable,
ASGISendEvent,
Scope as ASGIScope,
)
from loguru import logger
from starlette.requests import Request
from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
from open_webui.models.users import UserModel
if TYPE_CHECKING:
from loguru import Logger
@dataclass(frozen=True)
class AuditLogEntry:
# `Metadata` audit level properties
id: str
user: dict[str, Any]
audit_level: str
verb: str
request_uri: str
user_agent: Optional[str] = None
source_ip: Optional[str] = None
# `Request` audit level properties
request_object: Any = None
# `Request Response` level
response_object: Any = None
response_status_code: Optional[int] = None
class AuditLevel(str, Enum):
NONE = "NONE"
METADATA = "METADATA"
REQUEST = "REQUEST"
REQUEST_RESPONSE = "REQUEST_RESPONSE"
class AuditLogger:
"""
A helper class that encapsulates audit logging functionality. It uses Logurus logger with an auditable binding to ensure that audit log entries are filtered correctly.
Parameters:
logger (Logger): An instance of Logurus logger.
"""
def __init__(self, logger: "Logger"):
self.logger = logger.bind(auditable=True)
def write(
self,
audit_entry: AuditLogEntry,
*,
log_level: str = "INFO",
extra: Optional[dict] = None,
):
entry = asdict(audit_entry)
if extra:
entry["extra"] = extra
self.logger.log(
log_level,
"",
**entry,
)
class AuditContext:
"""
Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage.
Attributes:
request_body (bytearray): Accumulated request payload.
response_body (bytearray): Accumulated response payload.
max_body_size (int): Maximum number of bytes to capture.
metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.).
"""
def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
self.request_body = bytearray()
self.response_body = bytearray()
self.max_body_size = max_body_size
self.metadata: Dict[str, Any] = {}
def add_request_chunk(self, chunk: bytes):
if len(self.request_body) < self.max_body_size:
self.request_body.extend(
chunk[: self.max_body_size - len(self.request_body)]
)
def add_response_chunk(self, chunk: bytes):
if len(self.response_body) < self.max_body_size:
self.response_body.extend(
chunk[: self.max_body_size - len(self.response_body)]
)
class AuditLoggingMiddleware:
"""
ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.
"""
AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}
def __init__(
self,
app: ASGI3Application,
*,
excluded_paths: Optional[list[str]] = None,
max_body_size: int = MAX_BODY_LOG_SIZE,
audit_level: AuditLevel = AuditLevel.NONE,
) -> None:
self.app = app
self.audit_logger = AuditLogger(logger)
self.excluded_paths = excluded_paths or []
self.max_body_size = max_body_size
self.audit_level = audit_level
async def __call__(
self,
scope: ASGIScope,
receive: ASGIReceiveCallable,
send: ASGISendCallable,
) -> None:
if scope["type"] != "http":
return await self.app(scope, receive, send)
request = Request(scope=cast(MutableMapping, scope))
if self._should_skip_auditing(request):
return await self.app(scope, receive, send)
async with self._audit_context(request) as context:
async def send_wrapper(message: ASGISendEvent) -> None:
if self.audit_level == AuditLevel.REQUEST_RESPONSE:
await self._capture_response(message, context)
await send(message)
original_receive = receive
async def receive_wrapper() -> ASGIReceiveEvent:
nonlocal original_receive
message = await original_receive()
if self.audit_level in (
AuditLevel.REQUEST,
AuditLevel.REQUEST_RESPONSE,
):
await self._capture_request(message, context)
return message
await self.app(scope, receive_wrapper, send_wrapper)
@asynccontextmanager
async def _audit_context(
self, request: Request
) -> AsyncGenerator[AuditContext, None]:
"""
async context manager that ensures that an audit log entry is recorded after the request is processed.
"""
context = AuditContext()
try:
yield context
finally:
await self._log_audit_entry(request, context)
async def _get_authenticated_user(self, request: Request) -> UserModel:
auth_header = request.headers.get("Authorization")
assert auth_header
user = get_current_user(request, None, get_http_authorization_cred(auth_header))
return user
def _should_skip_auditing(self, request: Request) -> bool:
if (
request.method not in {"POST", "PUT", "PATCH", "DELETE"}
or AUDIT_LOG_LEVEL == "NONE"
or not request.headers.get("authorization")
):
return True
# match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
pattern = re.compile(
r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
)
if pattern.match(request.url.path):
return True
return False
async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
if message["type"] == "http.request":
body = message.get("body", b"")
context.add_request_chunk(body)
async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
if message["type"] == "http.response.start":
context.metadata["response_status_code"] = message["status"]
elif message["type"] == "http.response.body":
body = message.get("body", b"")
context.add_response_chunk(body)
async def _log_audit_entry(self, request: Request, context: AuditContext):
try:
user = await self._get_authenticated_user(request)
entry = AuditLogEntry(
id=str(uuid.uuid4()),
user=user.model_dump(include={"id", "name", "email", "role"}),
audit_level=self.audit_level.value,
verb=request.method,
request_uri=str(request.url),
response_status_code=context.metadata.get("response_status_code", None),
source_ip=request.client.host if request.client else None,
user_agent=request.headers.get("user-agent"),
request_object=context.request_body.decode("utf-8", errors="replace"),
response_object=context.response_body.decode("utf-8", errors="replace"),
)
self.audit_logger.write(entry)
except Exception as e:
logger.error(f"Failed to log audit entry: {str(e)}")

View File

@ -1,6 +1,12 @@
import logging
import uuid
import jwt
import base64
import hmac
import hashlib
import requests
import os
from datetime import UTC, datetime, timedelta
from typing import Optional, Union, List, Dict
@ -8,14 +14,22 @@ from typing import Optional, Union, List, Dict
from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY
from open_webui.env import (
WEBUI_SECRET_KEY,
TRUSTED_SIGNATURE_KEY,
STATIC_DIR,
SRC_LOG_LEVELS,
)
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
logging.getLogger("passlib").setLevel(logging.ERROR)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
SESSION_SECRET = WEBUI_SECRET_KEY
ALGORITHM = "HS256"
@ -24,6 +38,67 @@ ALGORITHM = "HS256"
# Auth Utils
##############
def verify_signature(payload: str, signature: str) -> bool:
"""
Verifies the HMAC signature of the received payload.
"""
try:
expected_signature = base64.b64encode(
hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
).decode()
# Compare securely to prevent timing attacks
return hmac.compare_digest(expected_signature, signature)
except Exception:
return False
def override_static(path: str, content: str):
# Ensure path is safe
if "/" in path or ".." in path:
log.error(f"Invalid path: {path}")
return
file_path = os.path.join(STATIC_DIR, path)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
f.write(base64.b64decode(content)) # Convert Base64 back to raw binary
def get_license_data(app, key):
if key:
try:
res = requests.post(
"https://api.openwebui.com/api/v1/license",
json={"key": key, "version": "1"},
timeout=5,
)
if getattr(res, "ok", False):
payload = getattr(res, "json", lambda: {})()
for k, v in payload.items():
if k == "resources":
for p, c in v.items():
globals().get("override_static", lambda a, b: None)(p, c)
elif k == "count":
setattr(app.state, "USER_COUNT", v)
elif k == "name":
setattr(app.state, "WEBUI_NAME", v)
elif k == "metadata":
setattr(app.state, "LICENSE_METADATA", v)
return True
else:
log.error(
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
)
except Exception as ex:
log.exception(f"License: Uncaught Exception: {ex}")
return False
bearer_security = HTTPBearer(auto_error=False)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@ -76,6 +151,7 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user(
request: Request,
background_tasks: BackgroundTasks,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
token = None
@ -128,7 +204,10 @@ def get_current_user(
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
Users.update_user_last_active_by_id(user.id)
# Refresh the user's last active timestamp asynchronously
# to prevent blocking the request
if background_tasks:
background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
return user
else:
raise HTTPException(

View File

@ -7,14 +7,17 @@ from typing import Any, Optional
import random
import json
import inspect
import uuid
import asyncio
from fastapi import Request
from starlette.responses import Response, StreamingResponse
from fastapi import Request, status
from starlette.responses import Response, StreamingResponse, JSONResponse
from open_webui.models.users import UserModel
from open_webui.socket.main import (
sio,
get_event_call,
get_event_emitter,
)
@ -44,6 +47,10 @@ from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
@ -53,108 +60,224 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def generate_direct_chat_completion(
request: Request,
form_data: dict,
user: Any,
models: dict,
):
log.info("generate_direct_chat_completion")
metadata = form_data.pop("metadata", {})
user_id = metadata.get("user_id")
session_id = metadata.get("session_id")
request_id = str(uuid.uuid4()) # Generate a unique request ID
event_caller = get_event_call(metadata)
channel = f"{user_id}:{session_id}:{request_id}"
if form_data.get("stream"):
q = asyncio.Queue()
async def message_listener(sid, data):
"""
Handle received socket messages and push them into the queue.
"""
await q.put(data)
# Register the listener
sio.on(channel, message_listener)
# Start processing chat completion in background
res = await event_caller(
{
"type": "request:chat:completion",
"data": {
"form_data": form_data,
"model": models[form_data["model"]],
"channel": channel,
"session_id": session_id,
},
}
)
log.info(f"res: {res}")
if res.get("status", False):
# Define a generator to stream responses
async def event_generator():
nonlocal q
try:
while True:
data = await q.get() # Wait for new messages
if isinstance(data, dict):
if "done" in data and data["done"]:
break # Stop streaming when 'done' is received
yield f"data: {json.dumps(data)}\n\n"
elif isinstance(data, str):
yield data
except Exception as e:
log.debug(f"Error in event generator: {e}")
pass
# Define a background task to run the event generator
async def background():
try:
del sio.handlers["/"][channel]
except Exception as e:
pass
# Return the streaming response
return StreamingResponse(
event_generator(), media_type="text/event-stream", background=background
)
else:
raise Exception(str(res))
else:
res = await event_caller(
{
"type": "request:chat:completion",
"data": {
"form_data": form_data,
"model": models[form_data["model"]],
"channel": channel,
"session_id": session_id,
},
}
)
if "error" in res and res["error"]:
raise Exception(res["error"])
return res
async def generate_chat_completion(
request: Request,
form_data: dict,
user: Any,
bypass_filter: bool = False,
):
log.debug(f"generate_chat_completion: {form_data}")
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
models = request.app.state.MODELS
if hasattr(request.state, "metadata"):
if "metadata" not in form_data:
form_data["metadata"] = request.state.metadata
else:
form_data["metadata"] = {
**form_data["metadata"],
**request.state.metadata,
}
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
log.debug(f"direct connection to model: {models}")
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
raise Exception("Model not found")
# Process the form_data through the pipeline
try:
form_data = process_pipeline_inlet_filter(request, form_data, user, models)
except Exception as e:
raise e
model = models[model_id]
# Check if user has access to the model
if not bypass_filter and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator),
media_type="text/event-stream",
background=response.background,
)
else:
return {
**(
await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
return await generate_function_chat_completion(
if getattr(request.state, "direct", False):
return await generate_direct_chat_completion(
request, form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
)
if form_data.get("stream"):
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
background=response.background,
)
else:
return convert_response_ollama_to_openai(response)
else:
return await generate_openai_chat_completion(
request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
try:
check_model_access(user, model)
except Exception as e:
raise e
if model.get("owned_by") == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in list(request.app.state.MODELS.values())
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator),
media_type="text/event-stream",
background=response.background,
)
else:
return {
**(
await generate_chat_completion(
request, form_data, user, bypass_filter=True
)
),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
return await generate_function_chat_completion(
request, form_data, user=user, models=models
)
if model.get("owned_by") == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
request=request,
form_data=form_data,
user=user,
bypass_filter=bypass_filter,
)
if form_data.get("stream"):
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
background=response.background,
)
else:
return convert_response_ollama_to_openai(response)
else:
return await generate_openai_chat_completion(
request=request,
form_data=form_data,
user=user,
bypass_filter=bypass_filter,
)
chat_completion = generate_chat_completion
@ -162,8 +285,14 @@ chat_completion = generate_chat_completion
async def chat_completed(request: Request, form_data: dict, user: Any):
if not request.app.state.MODELS:
await get_all_models(request)
models = request.app.state.MODELS
await get_all_models(request, user=user)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
data = form_data
model_id = data["model"]
@ -173,120 +302,47 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
model = models[model_id]
try:
data = process_pipeline_outlet_filter(request, data, user, models)
data = await process_pipeline_outlet_filter(request, data, user, models)
except Exception as e:
return Exception(f"Error: {e}")
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
)
metadata = {
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
"user_id": user.id,
}
)
extra_params = {
"__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata),
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
}
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include vavles
return (function.valves if function.valves else {}).get("priority", 0)
return 0
try:
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(model)
]
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if not filter:
continue
if filter_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
request.app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "outlet"):
continue
try:
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__request__": request,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
return Exception(f"Error: {e}")
return data
result, _ = await process_filter_functions(
request=request,
filter_functions=filter_functions,
filter_type="outlet",
form_data=data,
extra_params=extra_params,
)
return result
except Exception as e:
return Exception(f"Error: {e}")
async def chat_action(request: Request, action_id: str, form_data: dict, user: Any):
@ -300,8 +356,14 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
raise Exception(f"Action not found: {action_id}")
if not request.app.state.MODELS:
await get_all_models(request)
models = request.app.state.MODELS
await get_all_models(request, user=user)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
data = form_data
model_id = data["model"]
@ -375,7 +437,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
)
)
except Exception as e:
print(e)
log.exception(f"Failed to get user values: {e}")
params = {**params, "__user__": __user__}

View File

@ -0,0 +1,210 @@
import asyncio
import json
import logging
import uuid
from typing import Optional
import aiohttp
import websockets
from pydantic import BaseModel
from open_webui.env import SRC_LOG_LEVELS
logger = logging.getLogger(__name__)
logger.setLevel(SRC_LOG_LEVELS["MAIN"])
class ResultModel(BaseModel):
"""
Execute Code Result Model
"""
stdout: Optional[str] = ""
stderr: Optional[str] = ""
result: Optional[str] = ""
class JupyterCodeExecuter:
"""
Execute code in jupyter notebook
"""
def __init__(
self,
base_url: str,
code: str,
token: str = "",
password: str = "",
timeout: int = 60,
):
"""
:param base_url: Jupyter server URL (e.g., "http://localhost:8888")
:param code: Code to execute
:param token: Jupyter authentication token (optional)
:param password: Jupyter password (optional)
:param timeout: WebSocket timeout in seconds (default: 60s)
"""
self.base_url = base_url.rstrip("/")
self.code = code
self.token = token
self.password = password
self.timeout = timeout
self.kernel_id = ""
self.session = aiohttp.ClientSession(base_url=self.base_url)
self.params = {}
self.result = ResultModel()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.kernel_id:
try:
async with self.session.delete(
f"/api/kernels/{self.kernel_id}", params=self.params
) as response:
response.raise_for_status()
except Exception as err:
logger.exception("close kernel failed, %s", err)
await self.session.close()
async def run(self) -> ResultModel:
try:
await self.sign_in()
await self.init_kernel()
await self.execute_code()
except Exception as err:
logger.exception("execute code failed, %s", err)
self.result.stderr = f"Error: {err}"
return self.result
async def sign_in(self) -> None:
# password authentication
if self.password and not self.token:
async with self.session.get("/login") as response:
response.raise_for_status()
xsrf_token = response.cookies["_xsrf"].value
if not xsrf_token:
raise ValueError("_xsrf token not found")
self.session.cookie_jar.update_cookies(response.cookies)
self.session.headers.update({"X-XSRFToken": xsrf_token})
async with self.session.post(
"/login",
data={"_xsrf": xsrf_token, "password": self.password},
allow_redirects=False,
) as response:
response.raise_for_status()
self.session.cookie_jar.update_cookies(response.cookies)
# token authentication
if self.token:
self.params.update({"token": self.token})
async def init_kernel(self) -> None:
async with self.session.post(
url="/api/kernels", params=self.params
) as response:
response.raise_for_status()
kernel_data = await response.json()
self.kernel_id = kernel_data["id"]
def init_ws(self) -> (str, dict):
ws_base = self.base_url.replace("http", "ws")
ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
ws_headers = {}
if self.password and not self.token:
ws_headers = {
"Cookie": "; ".join(
[
f"{cookie.key}={cookie.value}"
for cookie in self.session.cookie_jar
]
),
**self.session.headers,
}
return websocket_url, ws_headers
async def execute_code(self) -> None:
# initialize ws
websocket_url, ws_headers = self.init_ws()
# execute
async with websockets.connect(
websocket_url, additional_headers=ws_headers
) as ws:
await self.execute_in_jupyter(ws)
async def execute_in_jupyter(self, ws) -> None:
# send message
msg_id = uuid.uuid4().hex
await ws.send(
json.dumps(
{
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "user",
"session": uuid.uuid4().hex,
"date": "",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": self.code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"channel": "shell",
}
)
)
# parse message
stdout, stderr, result = "", "", []
while True:
try:
# wait for message
message = await asyncio.wait_for(ws.recv(), self.timeout)
message_data = json.loads(message)
# msg id not match, skip
if message_data.get("parent_header", {}).get("msg_id") != msg_id:
continue
# check message type
msg_type = message_data.get("msg_type")
match msg_type:
case "stream":
if message_data["content"]["name"] == "stdout":
stdout += message_data["content"]["text"]
elif message_data["content"]["name"] == "stderr":
stderr += message_data["content"]["text"]
case "execute_result" | "display_data":
data = message_data["content"]["data"]
if "image/png" in data:
result.append(f"data:image/png;base64,{data['image/png']}")
elif "text/plain" in data:
result.append(data["text/plain"])
case "error":
stderr += "\n".join(message_data["content"]["traceback"])
case "status":
if message_data["content"]["execution_state"] == "idle":
break
except asyncio.TimeoutError:
stderr += "\nExecution timed out."
break
self.result.stdout = stdout.strip()
self.result.stderr = stderr.strip()
self.result.result = "\n".join(result).strip() if result else ""
async def execute_code_jupyter(
base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
async with JupyterCodeExecuter(
base_url, code, token, password, timeout
) as executor:
result = await executor.run()
return result.model_dump()

View File

@ -0,0 +1,111 @@
import inspect
import logging
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.models.functions import Functions
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_sorted_filter_ids(model: dict):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include vavles
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
filter_ids.sort(key=get_priority)
return filter_ids
async def process_filter_functions(
request, filter_functions, filter_type, form_data, extra_params
):
skip_files = None
for function in filter_functions:
filter = function
filter_id = function.id
if not filter:
continue
if filter_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
request.app.state.FUNCTIONS[filter_id] = function_module
# Prepare handler function
handler = getattr(function_module, filter_type, None)
if not handler:
continue
# Check if the function has a file_handler variable
if filter_type == "inlet" and hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
# Apply valves to the function
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try:
# Prepare parameters
sig = inspect.signature(handler)
params = {"body": form_data}
if filter_type == "stream":
params = {"event": form_data}
params = params | {
k: v
for k, v in {
**extra_params,
"__id__": filter_id,
}.items()
if k in sig.parameters
}
# Handle user parameters
if "__user__" in sig.parameters:
if hasattr(function_module, "UserValves"):
try:
params["__user__"]["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, params["__user__"]["id"]
)
)
except Exception as e:
log.exception(f"Failed to get user values: {e}")
# Execute handler
if inspect.iscoroutinefunction(handler):
form_data = await handler(**params)
else:
form_data = handler(**params)
except Exception as e:
log.exception(f"Error in {filter_type} handler {filter_id}: {e}")
raise e
# Handle file cleanup for inlet
if skip_files and "files" in form_data.get("metadata", {}):
del form_data["metadata"]["files"]
return form_data, {}

View File

@ -161,7 +161,7 @@ async def comfyui_generate_image(
seed = (
payload.seed
if payload.seed
else random.randint(0, 18446744073709551614)
else random.randint(0, 1125899906842624)
)
for node_id in node.node_ids:
workflow[node_id]["inputs"][node.key] = seed

View File

@ -0,0 +1,140 @@
import json
import logging
import sys
from typing import TYPE_CHECKING
from loguru import logger
from open_webui.env import (
AUDIT_LOG_FILE_ROTATION_SIZE,
AUDIT_LOG_LEVEL,
AUDIT_LOGS_FILE_PATH,
GLOBAL_LOG_LEVEL,
)
if TYPE_CHECKING:
from loguru import Record
def stdout_format(record: "Record") -> str:
"""
Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON).
Parameters:
record (Record): A Loguru record that contains logging details including time, level, name, function, line, message, and any extra context.
Returns:
str: A formatted log string intended for stdout.
"""
record["extra"]["extra_json"] = json.dumps(record["extra"])
return (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
"<level>{message}</level> - {extra[extra_json]}"
"\n{exception}"
)
class InterceptHandler(logging.Handler):
"""
Intercepts log records from Python's standard logging module
and redirects them to Loguru's logger.
"""
def emit(self, record):
"""
Called by the standard logging module for each log event.
It transforms the standard `LogRecord` into a format compatible with Loguru
and passes it to Loguru's logger.
"""
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
frame, depth = sys._getframe(6), 6
while frame and frame.f_code.co_filename == logging.__file__:
frame = frame.f_back
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)
def file_format(record: "Record"):
"""
Formats audit log records into a structured JSON string for file output.
Parameters:
record (Record): A Loguru record containing extra audit data.
Returns:
str: A JSON-formatted string representing the audit data.
"""
audit_data = {
"id": record["extra"].get("id", ""),
"timestamp": int(record["time"].timestamp()),
"user": record["extra"].get("user", dict()),
"audit_level": record["extra"].get("audit_level", ""),
"verb": record["extra"].get("verb", ""),
"request_uri": record["extra"].get("request_uri", ""),
"response_status_code": record["extra"].get("response_status_code", 0),
"source_ip": record["extra"].get("source_ip", ""),
"user_agent": record["extra"].get("user_agent", ""),
"request_object": record["extra"].get("request_object", b""),
"response_object": record["extra"].get("response_object", b""),
"extra": record["extra"].get("extra", {}),
}
record["extra"]["file_extra"] = json.dumps(audit_data, default=str)
return "{extra[file_extra]}\n"
def start_logger():
"""
Initializes and configures Loguru's logger with distinct handlers:
A console (stdout) handler for general log messages (excluding those marked as auditable).
An optional file handler for audit logs if audit logging is enabled.
Additionally, this function reconfigures Pythons standard logging to route through Loguru and adjusts logging levels for Uvicorn.
Parameters:
enable_audit_logging (bool): Determines whether audit-specific log entries should be recorded to file.
"""
logger.remove()
logger.add(
sys.stdout,
level=GLOBAL_LOG_LEVEL,
format=stdout_format,
filter=lambda record: "auditable" not in record["extra"],
)
if AUDIT_LOG_LEVEL != "NONE":
try:
logger.add(
AUDIT_LOGS_FILE_PATH,
level="INFO",
rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
compression="zip",
format=file_format,
filter=lambda record: record["extra"].get("auditable") is True,
)
except Exception as e:
logger.error(f"Failed to initialize audit log file handler: {str(e)}")
logging.basicConfig(
handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
)
for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = []
for uvicorn_logger_name in ["uvicorn.access"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = [InterceptHandler()]
logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")

File diff suppressed because it is too large Load Diff

View File

@ -2,9 +2,27 @@ import hashlib
import re
import time
import uuid
import logging
from datetime import timedelta
from pathlib import Path
from typing import Callable, Optional
import json
import collections.abc
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def deep_update(d, u):
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = deep_update(d.get(k, {}), v)
else:
d[k] = v
return d
def get_message_list(messages, message_id):
@ -20,7 +38,7 @@ def get_message_list(messages, message_id):
current_message = messages.get(message_id)
if not current_message:
return f"Message ID {message_id} not found in the history."
return None
# Reconstruct the chain by following the parentId links
message_list = []
@ -131,6 +149,44 @@ def add_or_update_system_message(content: str, messages: list[dict]):
return messages
def add_or_update_user_message(content: str, messages: list[dict]):
"""
Adds a new user message at the end of the messages list
or updates the existing user message at the end.
:param msg: The message to be added or appended.
:param messages: The list of message dictionaries.
:return: The updated list of message dictionaries.
"""
if messages and messages[-1].get("role") == "user":
messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
else:
# Insert at the end
messages.append({"role": "user", "content": content})
return messages
def append_or_update_assistant_message(content: str, messages: list[dict]):
"""
Adds a new assistant message at the end of the messages list
or updates the existing assistant message at the end.
:param msg: The message to be added or appended.
:param messages: The list of message dictionaries.
:return: The updated list of message dictionaries.
"""
if messages and messages[-1].get("role") == "assistant":
messages[-1]["content"] = f"{messages[-1]['content']}\n{content}"
else:
# Insert at the end
messages.append({"role": "assistant", "content": content})
return messages
def openai_chat_message_template(model: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
@ -141,13 +197,24 @@ def openai_chat_message_template(model: str):
def openai_chat_chunk_message_template(
model: str, message: Optional[str] = None, usage: Optional[dict] = None
model: str,
content: Optional[str] = None,
tool_calls: Optional[list[dict]] = None,
usage: Optional[dict] = None,
) -> dict:
template = openai_chat_message_template(model)
template["object"] = "chat.completion.chunk"
if message:
template["choices"][0]["delta"] = {"content": message}
else:
template["choices"][0]["index"] = 0
template["choices"][0]["delta"] = {}
if content:
template["choices"][0]["delta"]["content"] = content
if tool_calls:
template["choices"][0]["delta"]["tool_calls"] = tool_calls
if not content and not tool_calls:
template["choices"][0]["finish_reason"] = "stop"
if usage:
@ -156,12 +223,20 @@ def openai_chat_chunk_message_template(
def openai_chat_completion_message_template(
model: str, message: Optional[str] = None, usage: Optional[dict] = None
model: str,
message: Optional[str] = None,
tool_calls: Optional[list[dict]] = None,
usage: Optional[dict] = None,
) -> dict:
template = openai_chat_message_template(model)
template["object"] = "chat.completion"
if message is not None:
template["choices"][0]["message"] = {"content": message, "role": "assistant"}
template["choices"][0]["message"] = {
"content": message,
"role": "assistant",
**({"tool_calls": tool_calls} if tool_calls else {}),
}
template["choices"][0]["finish_reason"] = "stop"
if usage:
@ -183,11 +258,12 @@ def get_gravatar_url(email):
return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
def calculate_sha256(file):
def calculate_sha256(file_path, chunk_size):
# Compute SHA-256 hash of a file efficiently in chunks
sha256 = hashlib.sha256()
# Read the file in chunks to efficiently handle large files
for chunk in iter(lambda: file.read(8192), b""):
sha256.update(chunk)
with open(file_path, "rb") as f:
while chunk := f.read(chunk_size):
sha256.update(chunk)
return sha256.hexdigest()
@ -342,7 +418,7 @@ def parse_ollama_modelfile(model_text):
elif param_type is bool:
value = value.lower() == "true"
except Exception as e:
print(e)
log.exception(f"Failed to parse parameter {param}: {e}")
continue
data["params"][param] = value
@ -375,3 +451,15 @@ def parse_ollama_modelfile(model_text):
data["params"]["messages"] = messages
return data
def convert_logit_bias_input_to_json(user_input):
logit_bias_pairs = user_input.split(",")
logit_bias_json = {}
for pair in logit_bias_pairs:
token, bias = pair.split(":")
token = str(token.strip())
bias = int(bias.strip())
bias = 100 if bias > 100 else -100 if bias < -100 else bias
logit_bias_json[token] = bias
return json.dumps(logit_bias_json)

View File

@ -22,6 +22,7 @@ from open_webui.config import (
)
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
from open_webui.models.users import UserModel
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@ -29,17 +30,17 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def get_all_base_models(request: Request):
async def get_all_base_models(request: Request, user: UserModel = None):
function_models = []
openai_models = []
ollama_models = []
if request.app.state.config.ENABLE_OPENAI_API:
openai_models = await openai.get_all_models(request)
openai_models = await openai.get_all_models(request, user=user)
openai_models = openai_models["data"]
if request.app.state.config.ENABLE_OLLAMA_API:
ollama_models = await ollama.get_all_models(request)
ollama_models = await ollama.get_all_models(request, user=user)
ollama_models = [
{
"id": model["model"],
@ -58,8 +59,8 @@ async def get_all_base_models(request: Request):
return models
async def get_all_models(request):
models = await get_all_base_models(request)
async def get_all_models(request, user: UserModel = None):
models = await get_all_base_models(request, user=user)
# If there are no models, return an empty list
if len(models) == 0:
@ -142,7 +143,7 @@ async def get_all_models(request):
custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0]
):
owned_by = model["owned_by"]
owned_by = model.get("owned_by", "unknown owner")
if "pipe" in model:
pipe = model["pipe"]
break

View File

@ -1,6 +1,7 @@
import base64
import logging
import mimetypes
import sys
import uuid
import aiohttp
@ -35,12 +36,20 @@ from open_webui.config import (
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE
from open_webui.env import (
WEBUI_NAME,
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
)
from open_webui.utils.misc import parse_duration
from open_webui.utils.auth import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
auth_manager_config = AppConfig()
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
@ -61,8 +70,9 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
class OAuthManager:
def __init__(self):
def __init__(self, app):
self.oauth = OAuth()
self.app = app
for _, provider_config in OAUTH_PROVIDERS.items():
provider_config["register"](self.oauth)
@ -72,17 +82,21 @@ class OAuthManager:
def get_user_role(self, user, user_data):
if user and Users.get_num_users() == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
log.debug("Assigning the only user the admin role")
return "admin"
if not user and Users.get_num_users() == 0:
# If there are no users, assign the role "admin", as the first user will be an admin
log.debug("Assigning the first user the admin role")
return "admin"
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
log.debug("Running OAUTH Role management")
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None
role = "pending" # Default/fallback role if no matching roles are found
# Default/fallback role if no matching roles are found
role = auth_manager_config.DEFAULT_USER_ROLE
# Next block extracts the roles from the user data, accepting nested claims of any depth
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
@ -92,17 +106,24 @@ class OAuthManager:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
log.debug(f"Oauth Roles claim: {oauth_claim}")
log.debug(f"User roles from oauth: {oauth_roles}")
log.debug(f"Accepted user roles: {oauth_allowed_roles}")
log.debug(f"Accepted admin roles: {oauth_admin_roles}")
# If any roles are found, check if they match the allowed or admin roles
if oauth_roles:
# If role management is enabled, and matching roles are provided, use the roles
for allowed_role in oauth_allowed_roles:
# If the user has any of the allowed roles, assign the role "user"
if allowed_role in oauth_roles:
log.debug("Assigned user the user role")
role = "user"
break
for admin_role in oauth_admin_roles:
# If the user has any of the admin roles, assign the role "admin"
if admin_role in oauth_roles:
log.debug("Assigned user the admin role")
role = "admin"
break
else:
@ -116,16 +137,34 @@ class OAuthManager:
return role
def update_user_groups(self, user, user_data, default_permissions):
log.debug("Running OAUTH Group management")
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
user_oauth_groups: list[str] = user_data.get(oauth_claim, list())
# Nested claim search for groups claim
if oauth_claim:
claim_data = user_data
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
user_oauth_groups = claim_data if isinstance(claim_data, list) else []
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
all_available_groups: list[GroupModel] = Groups.get_groups()
log.debug(f"Oauth Groups claim: {oauth_claim}")
log.debug(f"User oauth groups: {user_oauth_groups}")
log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
log.debug(
f"All groups available in OpenWebUI: {[g.name for g in all_available_groups]}"
)
# Remove groups that user is no longer a part of
for group_model in user_current_groups:
if group_model.name not in user_oauth_groups:
# Remove group from user
log.debug(
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
)
user_ids = group_model.user_ids
user_ids = [i for i in user_ids if i != user.id]
@ -151,6 +190,9 @@ class OAuthManager:
gm.name == group_model.name for gm in user_current_groups
):
# Add user to group
log.debug(
f"Adding user to group {group_model.name} as it was found in their oauth groups"
)
user_ids = group_model.user_ids
user_ids.append(user.id)
@ -170,7 +212,7 @@ class OAuthManager:
id=group_model.id, form_data=update_form, overwrite=False
)
async def handle_login(self, provider, request):
async def handle_login(self, request, provider):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
@ -182,7 +224,7 @@ class OAuthManager:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
async def handle_callback(self, provider, request, response):
async def handle_callback(self, request, provider, response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = self.get_client(provider)
@ -192,7 +234,7 @@ class OAuthManager:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token.get("userinfo")
if not user_data:
if not user_data or auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data:
user_data: UserInfo = await client.userinfo(token=token)
if not user_data:
log.warning(f"OAuth callback failed, user data is missing: {token}")
@ -204,11 +246,46 @@ class OAuthManager:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
email = user_data.get(email_claim, "")
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# If the provider is GitHub,and public email is not provided, we can use the access token to fetch the user's email
if provider == "github":
try:
access_token = token.get("access_token")
headers = {"Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.github.com/user/emails", headers=headers
) as resp:
if resp.ok:
emails = await resp.json()
# use the primary email as the user's email
primary_email = next(
(e["email"] for e in emails if e.get("primary")),
None,
)
if primary_email:
email = primary_email
else:
log.warning(
"No primary email found in GitHub response"
)
raise HTTPException(
400, detail=ERROR_MESSAGES.INVALID_CRED
)
else:
log.warning("Failed to fetch GitHub email")
raise HTTPException(
400, detail=ERROR_MESSAGES.INVALID_CRED
)
except Exception as e:
log.warning(f"Error fetching GitHub email: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
else:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
email = email.lower()
if (
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
@ -236,12 +313,12 @@ class OAuthManager:
Users.update_user_role_by_id(user.id, determined_role)
if not user:
user_count = Users.get_num_users()
# If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(
user_data.get("email", "").lower()
)
existing_user = Users.get_user_by_email(email)
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@ -260,24 +337,35 @@ class OAuthManager:
}
async with aiohttp.ClientSession() as session:
async with session.get(picture_url, **get_kwargs) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(
picture
).decode("utf-8")
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
if resp.ok:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(
picture
).decode("utf-8")
guessed_mime_type = mimetypes.guess_type(
picture_url
)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
else:
picture_url = "/user.png"
except Exception as e:
log.error(
f"Error downloading profile image '{picture_url}': {e}"
)
picture_url = ""
picture_url = "/user.png"
if not picture_url:
picture_url = "/user.png"
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
name = user_data.get(username_claim)
if not name:
log.warning("Username claim is missing, using email as name")
name = email
role = self.get_user_role(None, user_data)
user = Auths.insert_new_auth(
@ -285,7 +373,7 @@ class OAuthManager:
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get(username_claim, "User"),
name=name,
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
@ -293,6 +381,7 @@ class OAuthManager:
if auth_manager_config.WEBHOOK_URL:
post_webhook(
WEBUI_NAME,
auth_manager_config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
@ -323,8 +412,8 @@ class OAuthManager:
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
)
if ENABLE_OAUTH_SIGNUP.value:
@ -333,12 +422,9 @@ class OAuthManager:
key="oauth_id_token",
value=oauth_id_token,
httponly=True,
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url, headers=response.headers)
oauth_manager = OAuthManager()

View File

@ -1,17 +1,27 @@
from open_webui.utils.task import prompt_template
from open_webui.utils.task import prompt_template, prompt_variables_template
from open_webui.utils.misc import (
add_or_update_system_message,
)
from typing import Callable, Optional
import json
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
def apply_model_system_prompt_to_body(
params: dict, form_data: dict, metadata: Optional[dict] = None, user=None
) -> dict:
system = params.get("system", None)
if not system:
return form_data
# Metadata (WebUI Usage)
if metadata:
variables = metadata.get("variables", {})
if variables:
system = prompt_variables_template(system, variables)
# Legacy (API Usage)
if user:
template_params = {
"user_name": user.name,
@ -19,7 +29,9 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di
}
else:
template_params = {}
system = prompt_template(system, **template_params)
form_data["messages"] = add_or_update_system_message(
system, form_data.get("messages", [])
)
@ -50,43 +62,55 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
"reasoning_effort": str,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
"logit_bias": lambda x: x,
}
return apply_model_params_to_body(params, form_data, mappings)
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
opts = [
"temperature",
"top_p",
"seed",
"mirostat",
"mirostat_eta",
"mirostat_tau",
"num_ctx",
"num_batch",
"num_keep",
"repeat_last_n",
"tfs_z",
"top_k",
"min_p",
"use_mmap",
"use_mlock",
"num_thread",
"num_gpu",
]
mappings = {i: lambda x: x for i in opts}
form_data = apply_model_params_to_body(params, form_data, mappings)
# Convert OpenAI parameter names to Ollama parameter names if needed.
name_differences = {
"max_tokens": "num_predict",
"frequency_penalty": "repeat_penalty",
}
for key, value in name_differences.items():
if (param := params.get(key, None)) is not None:
form_data[value] = param
# Copy the parameter to new name then delete it, to prevent Ollama warning of invalid option provided
params[value] = params[key]
del params[key]
return form_data
# See https://github.com/ollama/ollama/blob/main/docs/api.md#request-8
mappings = {
"temperature": float,
"top_p": float,
"seed": lambda x: x,
"mirostat": int,
"mirostat_eta": float,
"mirostat_tau": float,
"num_ctx": int,
"num_batch": int,
"num_keep": int,
"num_predict": int,
"repeat_last_n": int,
"top_k": int,
"min_p": float,
"typical_p": float,
"repeat_penalty": float,
"presence_penalty": float,
"frequency_penalty": float,
"penalize_newline": bool,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
"numa": bool,
"num_gpu": int,
"main_gpu": int,
"low_vram": bool,
"vocab_only": bool,
"use_mmap": bool,
"use_mlock": bool,
"num_thread": int,
}
return apply_model_params_to_body(params, form_data, mappings)
def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
@ -97,11 +121,38 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
new_message = {"role": message["role"]}
content = message.get("content", [])
tool_calls = message.get("tool_calls", None)
tool_call_id = message.get("tool_call_id", None)
# Check if the content is a string (just a simple message)
if isinstance(content, str):
if isinstance(content, str) and not tool_calls:
# If the content is a string, it's pure text
new_message["content"] = content
# If message is a tool call, add the tool call id to the message
if tool_call_id:
new_message["tool_call_id"] = tool_call_id
elif tool_calls:
# If tool calls are present, add them to the message
ollama_tool_calls = []
for tool_call in tool_calls:
ollama_tool_call = {
"index": tool_call.get("index", 0),
"id": tool_call.get("id", None),
"function": {
"name": tool_call.get("function", {}).get("name", ""),
"arguments": json.loads(
tool_call.get("function", {}).get("arguments", {})
),
},
}
ollama_tool_calls.append(ollama_tool_call)
new_message["tool_calls"] = ollama_tool_calls
# Put the content to empty string (Ollama requires an empty string for tool calls)
new_message["content"] = ""
else:
# Otherwise, assume the content is a list of dicts, e.g., text followed by an image URL
content_text = ""
@ -155,37 +206,38 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
)
ollama_payload["stream"] = openai_payload.get("stream", False)
if "tools" in openai_payload:
ollama_payload["tools"] = openai_payload["tools"]
if "format" in openai_payload:
ollama_payload["format"] = openai_payload["format"]
# If there are advanced parameters in the payload, format them in Ollama's options field
ollama_options = {}
if openai_payload.get("options"):
ollama_payload["options"] = openai_payload["options"]
ollama_options = openai_payload["options"]
# Handle parameters which map directly
for param in ["temperature", "top_p", "seed"]:
if param in openai_payload:
ollama_options[param] = openai_payload[param]
# Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
if "max_tokens" in ollama_options:
ollama_options["num_predict"] = ollama_options["max_tokens"]
del ollama_options[
"max_tokens"
] # To prevent Ollama warning of invalid option provided
# Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
if "max_completion_tokens" in openai_payload:
ollama_options["num_predict"] = openai_payload["max_completion_tokens"]
elif "max_tokens" in openai_payload:
ollama_options["num_predict"] = openai_payload["max_tokens"]
# Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down.
if "system" in ollama_options:
ollama_payload["system"] = ollama_options["system"]
del ollama_options[
"system"
] # To prevent Ollama warning of invalid option provided
# Handle frequency / presence_penalty, which needs renaming and checking
if "frequency_penalty" in openai_payload:
ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"]
if "presence_penalty" in openai_payload and "penalty" not in ollama_options:
# We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists.
ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"]
# Add options to payload if any have been set
if ollama_options:
# If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
if "stop" in openai_payload:
ollama_options = ollama_payload.get("options", {})
ollama_options["stop"] = openai_payload.get("stop")
ollama_payload["options"] = ollama_options
if "metadata" in openai_payload:
ollama_payload["metadata"] = openai_payload["metadata"]
return ollama_payload

View File

@ -2,6 +2,7 @@ from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Dict, Any, List
from html import escape
from markdown import markdown
@ -41,13 +42,13 @@ class PDFGenerator:
def _build_html_message(self, message: Dict[str, Any]) -> str:
"""Build HTML for a single message."""
role = message.get("role", "user")
content = message.get("content", "")
role = escape(message.get("role", "user"))
content = escape(message.get("content", ""))
timestamp = message.get("timestamp")
model = message.get("model") if role == "assistant" else ""
model = escape(message.get("model") if role == "assistant" else "")
date_str = self.format_timestamp(timestamp) if timestamp else ""
date_str = escape(self.format_timestamp(timestamp) if timestamp else "")
# extends pymdownx extension to convert markdown to html.
# - https://facelessuser.github.io/pymdown-extensions/usage_notes/
@ -76,6 +77,7 @@ class PDFGenerator:
def _generate_html_body(self) -> str:
"""Generate the full HTML body for the PDF."""
escaped_title = escape(self.form_data.title)
return f"""
<html>
<head>
@ -84,7 +86,7 @@ class PDFGenerator:
<body>
<div>
<div>
<h2>{self.form_data.title}</h2>
<h2>{escaped_title}</h2>
{self.messages_html}
</div>
</div>
@ -108,7 +110,7 @@ class PDFGenerator:
# When running using `pip install -e .` the static directory is in the site packages.
# This path only works if `open-webui serve` is run from the root of this project.
if not FONTS_DIR.exists():
FONTS_DIR = Path("./backend/static/fonts")
FONTS_DIR = Path(".") / "backend" / "static" / "fonts"
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")

View File

@ -45,7 +45,7 @@ def extract_frontmatter(content):
frontmatter[key.strip()] = value.strip()
except Exception as e:
print(f"An error occurred: {e}")
log.exception(f"Failed to extract frontmatter: {e}")
return {}
return frontmatter
@ -167,9 +167,14 @@ def load_function_module_by_id(function_id, content=None):
def install_frontmatter_requirements(requirements):
if requirements:
req_list = [req.strip() for req in requirements.split(",")]
for req in req_list:
log.info(f"Installing requirement: {req}")
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
try:
req_list = [req.strip() for req in requirements.split(",")]
for req in req_list:
log.info(f"Installing requirement: {req}")
subprocess.check_call([sys.executable, "-m", "pip", "install", req])
except Exception as e:
log.error(f"Error installing package: {req}")
raise e
else:
log.info("No requirements found in frontmatter.")

View File

@ -1,15 +1,101 @@
import json
from uuid import uuid4
from open_webui.utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
def convert_ollama_tool_call_to_openai(tool_calls: dict) -> dict:
openai_tool_calls = []
for tool_call in tool_calls:
openai_tool_call = {
"index": tool_call.get("index", 0),
"id": tool_call.get("id", f"call_{str(uuid4())}"),
"type": "function",
"function": {
"name": tool_call.get("function", {}).get("name", ""),
"arguments": json.dumps(
tool_call.get("function", {}).get("arguments", {})
),
},
}
openai_tool_calls.append(openai_tool_call)
return openai_tool_calls
def convert_ollama_usage_to_openai(data: dict) -> dict:
return {
"response_token/s": (
round(
(
(
data.get("eval_count", 0)
/ ((data.get("eval_duration", 0) / 10_000_000))
)
* 100
),
2,
)
if data.get("eval_duration", 0) > 0
else "N/A"
),
"prompt_token/s": (
round(
(
(
data.get("prompt_eval_count", 0)
/ ((data.get("prompt_eval_duration", 0) / 10_000_000))
)
* 100
),
2,
)
if data.get("prompt_eval_duration", 0) > 0
else "N/A"
),
"total_duration": data.get("total_duration", 0),
"load_duration": data.get("load_duration", 0),
"prompt_eval_count": data.get("prompt_eval_count", 0),
"prompt_tokens": int(
data.get("prompt_eval_count", 0)
), # This is the OpenAI compatible key
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
"eval_count": data.get("eval_count", 0),
"completion_tokens": int(
data.get("eval_count", 0)
), # This is the OpenAI compatible key
"eval_duration": data.get("eval_duration", 0),
"approximate_total": (lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s")(
(data.get("total_duration", 0) or 0) // 1_000_000_000
),
"total_tokens": int( # This is the OpenAI compatible key
data.get("prompt_eval_count", 0) + data.get("eval_count", 0)
),
"completion_tokens_details": { # This is the OpenAI compatible key
"reasoning_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0,
},
}
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
model = ollama_response.get("model", "ollama")
message_content = ollama_response.get("message", {}).get("content", "")
tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
openai_tool_calls = None
response = openai_chat_completion_message_template(model, message_content)
if tool_calls:
openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
data = ollama_response
usage = convert_ollama_usage_to_openai(data)
response = openai_chat_completion_message_template(
model, message_content, openai_tool_calls, usage
)
return response
@ -18,53 +104,21 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
data = json.loads(data)
model = data.get("model", "ollama")
message_content = data.get("message", {}).get("content", "")
message_content = data.get("message", {}).get("content", None)
tool_calls = data.get("message", {}).get("tool_calls", None)
openai_tool_calls = None
if tool_calls:
openai_tool_calls = convert_ollama_tool_call_to_openai(tool_calls)
done = data.get("done", False)
usage = None
if done:
usage = {
"response_token/s": (
round(
(
(
data.get("eval_count", 0)
/ ((data.get("eval_duration", 0) / 10_000_000))
)
* 100
),
2,
)
if data.get("eval_duration", 0) > 0
else "N/A"
),
"prompt_token/s": (
round(
(
(
data.get("prompt_eval_count", 0)
/ ((data.get("prompt_eval_duration", 0) / 10_000_000))
)
* 100
),
2,
)
if data.get("prompt_eval_duration", 0) > 0
else "N/A"
),
"total_duration": data.get("total_duration", 0),
"load_duration": data.get("load_duration", 0),
"prompt_eval_count": data.get("prompt_eval_count", 0),
"prompt_eval_duration": data.get("prompt_eval_duration", 0),
"eval_count": data.get("eval_count", 0),
"eval_duration": data.get("eval_duration", 0),
"approximate_total": (
lambda s: f"{s // 3600}h{(s % 3600) // 60}m{s % 60}s"
)((data.get("total_duration", 0) or 0) // 1_000_000_000),
}
usage = convert_ollama_usage_to_openai(data)
data = openai_chat_chunk_message_template(
model, message_content if not done else None, usage
model, message_content, openai_tool_calls, usage
)
line = f"data: {json.dumps(data)}\n\n"

View File

@ -22,7 +22,7 @@ def get_task_model_id(
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if models[task_model_id]["owned_by"] == "ollama":
if models[task_model_id].get("owned_by") == "ollama":
if task_model and task_model in models:
task_model_id = task_model
else:
@ -32,6 +32,12 @@ def get_task_model_id(
return task_model_id
def prompt_variables_template(template: str, variables: dict[str, str]) -> str:
for variable, value in variables.items():
template = template.replace(variable, value)
return template
def prompt_template(
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
) -> str:
@ -98,7 +104,7 @@ def replace_prompt_variable(template: str, prompt: str) -> str:
def replace_messages_variable(
template: str, messages: Optional[list[str]] = None
template: str, messages: Optional[list[dict]] = None
) -> str:
def replacement_function(match):
full_match = match.group(0)

View File

@ -61,6 +61,12 @@ def get_tools(
)
for spec in tools.specs:
# TODO: Fix hack for OpenAI API
# Some times breaks OpenAI but others don't. Leaving the comment
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
# Remove internal parameters
spec["parameters"]["properties"] = {
key: val
@ -73,6 +79,13 @@ def get_tools(
# convert to function that takes only model params and inserts custom params
original_func = getattr(module, function_name)
callable = apply_extra_params_to_tool_function(original_func, extra_params)
if callable.__doc__ and callable.__doc__.strip() != "":
s = re.split(":(param|return)", callable.__doc__, 1)
spec["description"] = s[0]
else:
spec["description"] = function_name
# TODO: This needs to be a pydantic model
tool_dict = {
"toolkit_id": tool_id,

View File

@ -2,14 +2,14 @@ import json
import logging
import requests
from open_webui.config import WEBUI_FAVICON_URL, WEBUI_NAME
from open_webui.config import WEBUI_FAVICON_URL
from open_webui.env import SRC_LOG_LEVELS, VERSION
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
def post_webhook(url: str, message: str, event_data: dict) -> bool:
def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
try:
log.debug(f"post_webhook: {url}, {message}, {event_data}")
payload = {}
@ -39,7 +39,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
"sections": [
{
"activityTitle": message,
"activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}",
"activitySubtitle": f"{name} ({VERSION}) - {action}",
"activityImage": WEBUI_FAVICON_URL,
"facts": facts,
"markdown": True,

View File

@ -1,29 +1,26 @@
fastapi==0.111.0
uvicorn[standard]==0.30.6
pydantic==2.9.2
fastapi==0.115.7
uvicorn[standard]==0.34.0
pydantic==2.10.6
python-multipart==0.0.18
Flask==3.1.0
Flask-Cors==5.0.0
python-socketio==5.11.3
python-jose==3.3.0
python-jose==3.4.0
passlib[bcrypt]==1.7.4
requests==2.32.3
aiohttp==3.11.8
aiohttp==3.11.11
async-timeout
aiocache
aiofiles
sqlalchemy==2.0.32
sqlalchemy==2.0.38
alembic==1.14.0
peewee==3.17.8
peewee==3.17.9
peewee-migrate==1.12.2
psycopg2-binary==2.9.9
pgvector==0.3.5
PyMySQL==1.1.1
bcrypt==4.2.0
bcrypt==4.3.0
pymongo
redis
@ -32,20 +29,27 @@ boto3==1.35.53
argon2-cffi==23.1.0
APScheduler==3.10.4
RestrictedPython==8.0
loguru==0.7.2
asgiref==3.8.1
# AI libraries
openai
anthropic
google-generativeai==0.7.2
tiktoken
langchain==0.3.7
langchain-community==0.3.7
langchain==0.3.19
langchain-community==0.3.18
fake-useragent==1.5.1
chromadb==0.6.2
pymilvus==2.5.0
qdrant-client~=1.12.0
opensearch-py==2.7.1
opensearch-py==2.8.0
playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
elasticsearch==8.17.1
transformers
@ -57,10 +61,10 @@ einops==0.8.0
ftfy==6.2.3
pypdf==4.3.1
fpdf2==2.8.2
pymdown-extensions==10.11.2
pymdown-extensions==10.14.2
docx2txt==0.8
python-pptx==1.0.0
unstructured==0.15.9
unstructured==0.16.17
nltk==3.9.1
Markdown==3.7
pypandoc==1.13
@ -71,25 +75,26 @@ xlrd==2.0.1
validators==0.34.0
psutil
sentencepiece
soundfile==0.12.1
soundfile==0.13.1
azure-ai-documentintelligence==1.0.0
opencv-python-headless==4.10.0.84
opencv-python-headless==4.11.0.86
rapidocr-onnxruntime==1.3.24
rank-bm25==0.2.2
faster-whisper==1.0.3
faster-whisper==1.1.1
PyJWT[crypto]==2.10.1
authlib==1.3.2
authlib==1.4.1
black==24.8.0
black==25.1.0
langfuse==2.44.0
youtube-transcript-api==0.6.3
pytube==15.0.0
extract_msg
pydub
duckduckgo-search~=7.2.1
duckduckgo-search~=7.3.2
## Google Drive
google-api-python-client
@ -104,5 +109,12 @@ pytest-docker~=3.1.1
googleapis-common-protos==1.63.2
google-cloud-storage==2.19.0
azure-identity==1.20.0
azure-storage-blob==12.24.1
## LDAP
ldap3==2.9.1
## Firecrawl
firecrawl-py==1.12.0

View File

@ -3,6 +3,17 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd "$SCRIPT_DIR" || exit
# Add conditional Playwright browser installation
if [[ "${RAG_WEB_LOADER_ENGINE,,}" == "playwright" ]]; then
if [[ -z "${PLAYWRIGHT_WS_URI}" ]]; then
echo "Installing Playwright browsers..."
playwright install chromium
playwright install-deps chromium
fi
python -c "import nltk; nltk.download('punkt_tab')"
fi
KEY_FILE=.webui_secret_key
PORT="${PORT:-8080}"

View File

@ -6,6 +6,17 @@ SETLOCAL ENABLEDELAYEDEXPANSION
SET "SCRIPT_DIR=%~dp0"
cd /d "%SCRIPT_DIR%" || exit /b
:: Add conditional Playwright browser installation
IF /I "%RAG_WEB_LOADER_ENGINE%" == "playwright" (
IF "%PLAYWRIGHT_WS_URI%" == "" (
echo Installing Playwright browsers...
playwright install chromium
playwright install-deps chromium
)
python -c "import nltk; nltk.download('punkt_tab')"
)
SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0

View File

@ -0,0 +1,10 @@
services:
playwright:
image: mcr.microsoft.com/playwright:v1.49.1-noble # Version must match requirements.txt
container_name: playwright
command: npx -y playwright@1.49.1 run-server --port 3000 --host 0.0.0.0
open-webui:
environment:
- 'RAG_WEB_LOADER_ENGINE=playwright'
- 'PLAYWRIGHT_WS_URI=ws://playwright:3000'

1602
package-lock.json generated

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More